This article collects typical usage examples of the Python function tensorflow.python.framework.ops.get_default_graph. If you have been wondering what exactly get_default_graph does, how to use it, or where to find examples, the hand-picked code samples below may help.
The following presents 20 code examples of the get_default_graph function, sorted by popularity by default. You can upvote the examples you like or find useful; your ratings help our system recommend better Python code samples.
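Before the examples, here is a minimal sketch of what get_default_graph returns and how the default-graph stack behaves. This sketch is illustrative, assumes TF 1.x graph mode, and the variable names are made up:

import tensorflow as tf
from tensorflow.python.framework import ops

outer = ops.get_default_graph()  # the module-global default graph
g = ops.Graph()
with g.as_default():
  # Inside the block, `g` is the default graph that new ops attach to.
  assert ops.get_default_graph() is g
# On exit, the previous default graph is restored.
assert ops.get_default_graph() is outer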
Example 1: testParallelApplyGradMean
def testParallelApplyGradMean(self):
  # We need each thread to keep its own device stack or the device scopes
  # won't be properly nested.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    q = data_flow_ops.SparseConditionalAccumulator(
        dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))
    elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
    accum_ops = []
    for x in elems:
      x = _indexedslice(np.array([[x, 0], [0, x]]).astype(np.float32))
      accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0))
    takeg_t = q.take_indexed_slices_grad(1)

    def apply_indexed_slices_grad(accum_op):
      self.evaluate(accum_op)

    threads = [
        self.checkedThread(
            target=apply_indexed_slices_grad, args=(o,)) for o in accum_ops
    ]

    for thread in threads:
      thread.start()
    for thread in threads:
      thread.join()

    val = self.evaluate(takeg_t)

    expected_val = sum(elems) / len(elems)
    self._assertEqual_nparray(
        np.array([[expected_val, 0], [0, expected_val]]).astype(np.float32),
        val, sess)
Author: adit-chandra, Project: tensorflow, Lines: 33, Source: sparse_conditional_accumulator_test.py
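Examples 1, 5, 10, 12, 13, and 14 all start with the same switch_to_thread_local() call. As a hedged aside, this is roughly what it changes; the device string and worker body below are illustrative, not part of the tests:

import threading

from tensorflow.python.framework import ops

g = ops.get_default_graph()
g.switch_to_thread_local()  # stacks such as the device stack become per-thread

def worker():
  # After the switch, a device scope entered here stays local to this
  # thread instead of leaking into ops built concurrently by other threads.
  with g.device("/cpu:0"):
    pass  # build thread-specific ops here

threads = [threading.Thread(target=worker) for _ in range(2)]
for t in threads:
  t.start()
for t in threads:
  t.join()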
Example 2: before_run
def before_run(self, run_context):
  """Dumps graphs and loads a checkpoint if one exists.

  Called before each call to run().

  Args:
    run_context: A `SessionRunContext` object.

  Returns: A `SessionRunArgs` object containing global_step.
  """
  # We write the graph and saver_def only at the first call of before_run.
  # We cannot do this in begin, since we let other hooks change the graph and
  # add variables in begin. The graph is finalized after all begin calls.
  if self._is_chief and self._first_call:
    training_util.write_graph(
        ops.get_default_graph().as_graph_def(add_shapes=True),
        self._checkpoint_dir,
        "graph.pbtxt")
    # dump model details to "model_analysis.txt"
    dump_model_analysis(self._checkpoint_dir)  # dump model configs
    graph = ops.get_default_graph()
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True),
        saver_def=self._saver.saver_def)
    if self._summary_writer is not None:
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)
    tf.logging.info("CheckpointSaverHook (before_run): dump graph...")
  self._first_call = False
  return tf.train.SessionRunArgs(self._global_step)
Author: KIngpon, Project: NJUNMT-tf, Lines: 30, Source: hooks.py
Example 3: wrapped_body
def wrapped_body(loop_counter, *args):
  """Loop body augmented with counter update.

  Args:
    loop_counter: Loop counter which needs to be incremented in the body.
    *args: List of args

  Returns:
    A list of tensors the same length as args.
  """
  # Capture the tensors already captured in cond_graph so that they appear
  # in the same order in body_graph.external_captures.
  for t in cond_graph.external_captures:
    ops.get_default_graph().capture(t)

  # Convert the flow variables in `args` to TensorArrays. `args` should
  # already have the same structure as `orig_loop_vars` but currently there
  # is no nest.zip so we call `_pack_sequence_as` which flattens both
  # `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
  # and packs it into the structure of `orig_loop_vars`.
  outputs = body(*_pack_sequence_as(orig_loop_vars, args))
  if not nest.is_sequence(outputs):
    outputs = [outputs]

  # Compare the structure of input and output of body converting the
  # top-level tuples to list to be compatible with legacy while_loop.
  nest.assert_same_structure(list(outputs), list(orig_loop_vars))

  outputs = _tensor_array_to_flow(outputs)

  # TODO(srbs): Update lowering code to create _Enter nodes with
  # is_constant=True for inputs that are directly passed to outputs.
  return [loop_counter + 1] + list(outputs)
Author: ziky90, Project: tensorflow, Lines: 32, Source: while_v2.py
Example 4: get_seed
def get_seed(op_seed):
  """Returns the local seeds an operation should use given an op-specific seed.

  Given an operation-specific seed, `op_seed`, this helper function returns
  two seeds derived from graph-level and op-level seeds. Many random
  operations internally use the two seeds to allow the user to change the
  seed globally for a graph, or for only specific operations.

  For details on how the graph-level seed interacts with op seeds, see
  @{tf.set_random_seed}.

  Args:
    op_seed: integer.

  Returns:
    A tuple of two integers that should be used for the local seed of this
    operation.
  """
  graph_seed = ops.get_default_graph().seed
  if graph_seed is not None:
    if op_seed is None:
      # pylint: disable=protected-access
      op_seed = ops.get_default_graph()._last_id
    seeds = _truncate_seed(graph_seed), _truncate_seed(op_seed)
  else:
    if op_seed is not None:
      seeds = DEFAULT_GRAPH_SEED, _truncate_seed(op_seed)
    else:
      seeds = None, None
  # Avoid (0, 0) as the C++ ops interpret it as nondeterminism, which would
  # be unexpected since Python docs say nondeterminism is (None, None).
  if seeds == (0, 0):
    return (0, _MAXINT32)
  return seeds
Author: 1000sprites, Project: tensorflow, Lines: 34, Source: random_seed.py
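To make the two-level seed scheme concrete, here is a hedged usage sketch. It assumes TF 1.x graph mode with a fresh default graph; the seed values 42 and 7 are arbitrary:

from tensorflow.python.framework import random_seed

# Without a graph-level seed, an op with no explicit seed gets (None, None),
# which the runtime treats as "pick nondeterministic seeds".
assert random_seed.get_seed(None) == (None, None)

# With a graph-level seed, both local seeds become deterministic: the first
# derives from the graph seed, the second from the op seed (or a per-graph
# counter when no op seed is given).
random_seed.set_random_seed(42)
graph_level, op_level = random_seed.get_seed(7)
assert graph_level is not None and op_level is not None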
Example 5: testParallelUpdateWithoutLocking
def testParallelUpdateWithoutLocking(self):
  # We need each thread to keep its own device stack or the device scopes
  # won't be properly nested.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    ones_t = array_ops.fill([1024, 1024], 1.0)
    p = variables.Variable(array_ops.zeros([1024, 1024]))
    adds = [
        state_ops.assign_add(
            p, ones_t, use_locking=False) for _ in range(20)
    ]
    self.evaluate(variables.global_variables_initializer())

    def run_add(add_op):
      self.evaluate(add_op)

    threads = [
        self.checkedThread(
            target=run_add, args=(add_op,)) for add_op in adds
    ]
    for t in threads:
      t.start()
    for t in threads:
      t.join()

    vals = self.evaluate(p)
    ones = np.ones((1024, 1024)).astype(np.float32)
    self.assertTrue((vals >= ones).all())
    self.assertTrue((vals <= ones * 20).all())
Author: adit-chandra, Project: tensorflow, Lines: 29, Source: dense_update_ops_no_tsan_test.py
Example 6: testIteratorStringHandleReuseTensorObject
def testIteratorStringHandleReuseTensorObject(self):
  dataset = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
  one_shot_iterator = dataset_ops.make_one_shot_iterator(dataset)
  initializable_iterator = dataset_ops.make_initializable_iterator(dataset)
  structure_iterator = iterator_ops.Iterator.from_structure(
      dataset.output_types)

  created_ops = len(ops.get_default_graph().get_operations())

  self.assertIs(one_shot_iterator.string_handle(),
                one_shot_iterator.string_handle())
  self.assertIs(initializable_iterator.string_handle(),
                initializable_iterator.string_handle())
  self.assertIs(structure_iterator.string_handle(),
                structure_iterator.string_handle())

  # Assert that getting the (default) string handle creates no ops.
  self.assertEqual(created_ops, len(ops.get_default_graph().get_operations()))

  # Specifying an explicit name will create a new op.
  handle_with_name = one_shot_iterator.string_handle(name="foo")
  self.assertEqual("foo", handle_with_name.op.name)
  self.assertIsNot(one_shot_iterator.string_handle(), handle_with_name)

  handle_with_same_name = one_shot_iterator.string_handle(name="foo")
  self.assertEqual("foo_1", handle_with_same_name.op.name)
  self.assertIsNot(handle_with_name, handle_with_same_name)
Author: kylin9872, Project: tensorflow, Lines: 27, Source: iterator_test.py
Example 7: _testDefaultGraphInThread
def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
  with session.Session() as s:
    self.assertEqual(ops.get_default_graph(), s.graph)
    a = constant_op.constant(1.0, shape=[1, 2])
    b = constant_op.constant(2.0, shape=[2, 3])
    c = math_ops.matmul(a, b)
    v = variables.Variable(c, name='var_%d' % i)

    # Block here until all threads have constructed their graph.
    constructed_event.set()
    continue_event.wait()

    assign_c_to_v = state_ops.assign(v, c)
    v.initializer.run()
    assign_c_to_v.eval()
    v_val = v.eval()
    self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
    d = constant_op.constant(3.0, shape=[2, 3])
    e = math_ops.matmul(a, d)
    assign_e_to_v = state_ops.assign(v, e)
    e_val = e.eval()
    self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
    v_val = v.eval()
    self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
    s.run(assign_e_to_v)
    v_val = v.eval()
    self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
    self.assertEqual(ops.get_default_graph(), s.graph)
Author: agouwin, Project: udacity_deep_learning_homework, Lines: 28, Source: session_test.py
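Example 7 works because the default-graph stack itself is per-thread. A minimal sketch of that property; the thread bookkeeping here is illustrative:

import threading

from tensorflow.python.framework import ops

def build_in_private_graph(results, idx):
  g = ops.Graph()
  with g.as_default():
    # Each thread sees only the graph it pushed onto its own stack.
    results[idx] = ops.get_default_graph() is g

results = [False, False]
threads = [threading.Thread(target=build_in_private_graph, args=(results, i))
           for i in range(2)]
for t in threads:
  t.start()
for t in threads:
  t.join()
assert all(results)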
Example 8: test_assign_stays_in_true_dtype
def test_assign_stays_in_true_dtype(self, distribute):
  with get_distribute_scope(distribute):
    x = get_var(1., dtypes.float32)
    x = get_autocast_var(x, distribute)
    self.evaluate(x.initializer)

    # small_val is a value such that 1.0 + small_val == 1.0 in fp16, but not
    # in fp32
    small_val = np.finfo('float16').eps / 2
    small_tensor = constant_op.constant(small_val, dtype=dtypes.float32)
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      # Variable should be increased, despite it appearing to be the same
      # float16 value.
      self.assertEqual(1. + small_val,
                       self.evaluate(x.assign(1. + small_tensor)))
      self.assertEqual(1., self.evaluate(x.value()))
    self.assertEqual(1. + small_val, self.evaluate(x.value()))

    self.evaluate(x.assign(1.))
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      self.assertEqual(1. + small_val,
                       self.evaluate(x.assign_add(small_tensor)))
      self.assertEqual(1., self.evaluate(x.value()))
    self.assertEqual(1. + small_val, self.evaluate(x.value()))
Author: adit-chandra, Project: tensorflow, Lines: 25, Source: autocast_variable_test.py
Example 9: test_read
def test_read(self, distribute):
  with get_distribute_scope(distribute):
    x = get_var(1., dtypes.float32)
    x = get_autocast_var(x, distribute)
    self.evaluate(x.initializer)

    # outside of auto cast scope.
    self.assertEqual(x.dtype, dtypes.float32)
    self.assertEqual(x.value().dtype, dtypes.float32)
    self.assertEqual(x.read_value().dtype, dtypes.float32)
    self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)

    # within auto cast scope of different dtype
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      self.assertEqual(x.dtype, dtypes.float16)
      self.assertEqual(x.value().dtype, dtypes.float16)
      self.assertEqual(x.read_value().dtype, dtypes.float16)
      self.assertEqual(array_ops.identity(x).dtype, dtypes.float16)

    # within auto cast scope of same dtype
    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float32):
      self.assertEqual(x.dtype, dtypes.float32)
      self.assertEqual(x.value().dtype, dtypes.float32)
      self.assertEqual(x.read_value().dtype, dtypes.float32)
      self.assertEqual(array_ops.identity(x).dtype, dtypes.float32)
Author: adit-chandra, Project: tensorflow, Lines: 27, Source: autocast_variable_test.py
Example 10: testAccumulatorApplyAndBlockingTake
def testAccumulatorApplyAndBlockingTake(self):
  # We need each thread to keep its own device stack or the device scopes
  # won't be properly nested.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    q = data_flow_ops.ConditionalAccumulator(
        dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))

    elems = [10.0, 20.0, 30.0]
    elems_ave = sum(elems) / len(elems)
    accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
    takeg_t = q.take_grad(3)

    def apply_grad():
      time.sleep(1.0)
      for accum_op in accum_ops:
        self.evaluate(accum_op)

    return_array = []

    def take_grad():
      return_array.append(self.evaluate(takeg_t))

    accum_thread = self.checkedThread(target=apply_grad)
    takeg_thread = self.checkedThread(target=take_grad)
    accum_thread.start()
    takeg_thread.start()
    accum_thread.join()
    takeg_thread.join()

    self.assertEqual([elems_ave], return_array)
Author: adit-chandra, Project: tensorflow, Lines: 31, Source: conditional_accumulator_test.py
Example 11: test_operator_overloads
def test_operator_overloads(self, distribute):
  with get_distribute_scope(distribute):
    x = get_var(1., dtypes.float32)
    x = get_autocast_var(x, distribute)
    self.evaluate(x.initializer)

    v1 = constant_op.constant(2., dtype=dtypes.float32)
    v2 = constant_op.constant(2., dtype=dtypes.float16)

    # Because autocast variables do not yet define operator overloads, the
    # operator is defined by the non-variable tensor

    # Test variable as the LHS. Currently, this is not supported with
    # distributed autocast variables
    if not distribute:
      self.assertEqual(self.evaluate(x + v1), 3.)

      with ops.get_default_graph()._enable_auto_casting_variables(
          dtypes.float16):
        self.assertEqual(self.evaluate(x + v2), 3.)

    # Test variable as the RHS
    self.assertEqual(self.evaluate(v1 + x), 3.)

    with ops.get_default_graph()._enable_auto_casting_variables(
        dtypes.float16):
      self.assertEqual(self.evaluate(v2 + x), 3.)
Author: adit-chandra, Project: tensorflow, Lines: 27, Source: autocast_variable_test.py
Example 12: testParallelTakeGrad
def testParallelTakeGrad(self):
  # We need each thread to keep its own device stack or the device scopes
  # won't be properly nested.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    q = data_flow_ops.ConditionalAccumulator(
        dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))

    elems = [e for e in range(10)]
    accum_ops = [q.apply_grad((np.float32(e),), local_step=e) for e in elems]
    takeg_t = q.take_grad(1)

    def apply_grad():
      for accum_op in accum_ops:
        time.sleep(1.0)
        self.evaluate(accum_op)

    apply_grad_thread = self.checkedThread(target=apply_grad)

    results = []

    def take_grad():
      results.append(self.evaluate(takeg_t))

    threads = [self.checkedThread(target=take_grad) for _ in range(10)]

    for thread in threads:
      thread.start()
    apply_grad_thread.start()

    for thread in threads:
      thread.join()
    apply_grad_thread.join()

    self.assertItemsEqual(elems, results)
Author: adit-chandra, Project: tensorflow, Lines: 34, Source: conditional_accumulator_test.py
Example 13: testParallelApplyGrad
def testParallelApplyGrad(self):
  # We need each thread to keep its own device stack or the device scopes
  # won't be properly nested.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    q = data_flow_ops.ConditionalAccumulator(
        dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([1]))
    elems = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0, 100.0]
    accum_ops = [q.apply_grad((x,), local_step=0) for x in elems]
    takeg_t = q.take_grad(1)

    def apply_grad(accum_op):
      self.evaluate(accum_op)

    threads = [
        self.checkedThread(
            target=apply_grad, args=(o,)) for o in accum_ops
    ]

    for thread in threads:
      thread.start()
    for thread in threads:
      thread.join()

    val = self.evaluate(takeg_t)

    self.assertEqual(val, sum(elems) / len(elems))
Author: adit-chandra, Project: tensorflow, Lines: 27, Source: conditional_accumulator_test.py
Example 14: testAccumulatorApplyAndBlockingTake
def testAccumulatorApplyAndBlockingTake(self):
  # We need each thread to keep its own device stack or the device scopes
  # won't be properly nested.
  ops.get_default_graph().switch_to_thread_local()
  with self.cached_session() as sess:
    q = data_flow_ops.SparseConditionalAccumulator(
        dtypes_lib.float32, name="Q", shape=tensor_shape.TensorShape([2, 2]))

    elems = [10.0, 20.0, 30.0]
    elems_ave = sum(elems) / len(elems)
    accum_ops = []
    for x in elems:
      x = _indexedslice(np.array([[0, x], [0, 0]]).astype(np.float32))
      accum_ops.append(q.apply_indexed_slices_grad(x, local_step=0))
    takeg_t = q.take_indexed_slices_grad(3)

    results = []

    def apply_indexed_slices_grad():
      for accum_op in accum_ops:
        self.evaluate(accum_op)

    def take_grad():
      results.append(self.evaluate(takeg_t))

    accum_thread = self.checkedThread(target=apply_indexed_slices_grad)
    takeg_thread = self.checkedThread(target=take_grad)
    accum_thread.start()
    takeg_thread.start()
    accum_thread.join()
    takeg_thread.join()

    self._assertEqual_nparray([[0, elems_ave], [0, 0]], results[0], sess)
Author: adit-chandra, Project: tensorflow, Lines: 33, Source: sparse_conditional_accumulator_test.py
Example 15: finalize
def finalize(self):
  """Creates operations if needed and finalizes the graph."""
  if self._global_step_tensor is None:
    self._global_step_tensor = contrib_variables.get_or_create_global_step()
  if self._init_op is None:
    self._init_op = Scaffold._get_or_default(
        'init_op', ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
  if self._ready_op is None:
    self._ready_op = Scaffold._get_or_default(
        'ready_op', ops.GraphKeys.READY_OP,
        variables.report_uninitialized_variables)
  if self._local_init_op is None:
    self._local_init_op = Scaffold._get_or_default(
        'local_init_op', ops.GraphKeys.LOCAL_INIT_OP,
        Scaffold._default_local_init_op)
  if self._summary_op is None:
    self._summary_op = Scaffold._get_or_default(
        'summary_op', ops.GraphKeys.SUMMARY_OP,
        logging_ops.merge_all_summaries)
  # pylint: disable=g-long-lambda
  if self._saver is None:
    self._saver = Scaffold._get_or_default(
        'saver',
        ops.GraphKeys.SAVERS,
        lambda: training_saver.Saver(sharded=True,
                                     max_to_keep=self._keep_checkpoint_max))
  # pylint: enable=g-long-lambda
  ops.get_default_graph().finalize()
Author: 10imaging, Project: tensorflow, Lines: 29, Source: supervised_session.py
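The closing ops.get_default_graph().finalize() call is what makes later, accidental graph mutation fail fast. A minimal sketch of that behavior (the constant is arbitrary):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

g = ops.get_default_graph()
g.finalize()  # marks the graph read-only
try:
  constant_op.constant(1.0)  # adding any new op now raises
except RuntimeError:
  pass  # "Graph is finalized and cannot be modified."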
Example 16: container
def container(self, container_name):
  """Returns a context manager that specifies the resource container to use.

  Overridden from `tf.Graph` to update both the init_scope container
  and the present inner container. This is necessary to make sure setting
  containers applies correctly both to created variables and to stateful
  ops.

  Args:
    container_name: container name string.

  Returns:
    A context manager for defining resource containers for stateful ops,
    yields the container name.
  """
  original_container = self._container
  # pylint: disable=protected-access
  with ops.init_scope():
    original_init_container = ops.get_default_graph()._container
  try:
    self._container = container_name
    with ops.init_scope():
      ops.get_default_graph()._container = container_name
    yield self._container
  finally:
    self._container = original_container
    with ops.init_scope():
      ops.get_default_graph()._container = original_init_container
  # pylint: enable=protected-access
Author: JonathanRaiman, Project: tensorflow, Lines: 28, Source: function.py
Example 17: test_graph_replace_gradients
def test_graph_replace_gradients(self):
  ops.reset_default_graph()
  w = variables.VariableV1(0.0, name="w")
  y = math_ops.multiply(math_ops.multiply(w, w, name="mul1"), w, name="mul2")
  g = gradients_impl.gradients(y, w, name="grad")[0]

  # Extract the operations.
  replacement_ts = {w.value(): g}
  original_mul1_grad = (ops.get_default_graph().
                        get_operation_by_name("grad/mul1_grad/Mul_1"))

  # Should not raise exception.
  res = ge.graph_replace(g, replacement_ts, dst_scope="res")

  # Extract the operations after graph_replace.
  result_mul1_grad = (ops.get_default_graph().
                      get_operation_by_name("res/grad/mul1_grad/Mul_1"))

  # Make sure _original_ops are as expected.
  self.assertEqual(original_mul1_grad._original_op.name, u"mul1")
  self.assertEqual(result_mul1_grad._original_op.name, u"res/mul1")
  self.assertNotEqual(res.name, g.name)

  with session.Session() as sess:
    sess.run(variables.global_variables_initializer())
    g_val, res_val = sess.run([g, res])
  self.assertNear(g_val, 0.0, ERROR_TOLERANCE)
  self.assertNear(res_val, 0.0, ERROR_TOLERANCE)
Author: Ajaycs99, Project: tensorflow, Lines: 27, Source: transform_test.py
Example 18: copy_scoped_meta_graph
def copy_scoped_meta_graph(from_scope, to_scope,
                           from_graph=None, to_graph=None):
  """Copies a sub-meta_graph from one scope to another.

  Args:
    from_scope: `String` name scope containing the subgraph to be copied.
    to_scope: `String` name scope under which the copied subgraph will reside.
    from_graph: Optional `Graph` from which to copy the subgraph. If `None`,
      the default graph is used.
    to_graph: Optional `Graph` to which to copy the subgraph. If `None`, the
      default graph is used.

  Returns:
    A dictionary of `Variables` that has been copied into `to_scope`.

  Raises:
    ValueError: If `from_scope` and `to_scope` are the same while
      `from_graph` and `to_graph` are also the same.
  """
  from_graph = from_graph or ops.get_default_graph()
  to_graph = to_graph or ops.get_default_graph()

  if from_graph == to_graph and from_scope == to_scope:
    raise ValueError("'from_scope' and 'to_scope' need to be different "
                     "when performing copy in the same graph.")

  orig_meta_graph, var_list = export_scoped_meta_graph(
      export_scope=from_scope, graph=from_graph)
  var_list = import_scoped_meta_graph(orig_meta_graph,
                                      graph=to_graph,
                                      import_scope=to_scope)
  return var_list
Author: AndrewTwinz, Project: tensorflow, Lines: 32, Source: meta_graph.py
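A hedged usage sketch for copy_scoped_meta_graph; the scope names and the variable "w" are made up for illustration:

from tensorflow.python.framework import meta_graph
from tensorflow.python.ops import variable_scope

with variable_scope.variable_scope("hidden"):
  variable_scope.get_variable("w", shape=[2, 2])

# Copy everything under "hidden" into "hidden_copy" within the same
# (default) graph; returns the `Variables` created under the new scope.
copied_vars = meta_graph.copy_scoped_meta_graph(from_scope="hidden",
                                                to_scope="hidden_copy")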
Example 19: testPartitionConcatenatesAlongCorrectAxis
def testPartitionConcatenatesAlongCorrectAxis(self):

  def _part_axis_0(**unused_kwargs):
    return (2, 1, 1)

  def _part_axis_1(**unused_kwargs):
    return (1, 2, 1)

  with variable_scope.variable_scope("root"):
    v0 = variable_scope.get_variable(
        "n0", shape=(2, 2, 2), partitioner=_part_axis_0)
    v1 = variable_scope.get_variable(
        "n1", shape=(2, 2, 2), partitioner=_part_axis_1)

  self.assertEqual(v0.get_shape(), (2, 2, 2))
  self.assertEqual(v1.get_shape(), (2, 2, 2))

  n0_0 = ops.get_default_graph().get_tensor_by_name("root/n0/part_0:0")
  n0_1 = ops.get_default_graph().get_tensor_by_name("root/n0/part_1:0")
  self.assertEqual(n0_0.get_shape(), (1, 2, 2))
  self.assertEqual(n0_1.get_shape(), (1, 2, 2))

  n1_0 = ops.get_default_graph().get_tensor_by_name("root/n1/part_0:0")
  n1_1 = ops.get_default_graph().get_tensor_by_name("root/n1/part_1:0")
  self.assertEqual(n1_0.get_shape(), (2, 1, 2))
  self.assertEqual(n1_1.get_shape(), (2, 1, 2))
Author: Y-owen, Project: tensorflow, Lines: 26, Source: variable_scope_test.py
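Example 19 fetches each partition with get_tensor_by_name; as a reminder, this is the "op_name:output_index" convention it relies on (the op name "c" below is arbitrary):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops

c = constant_op.constant(1.0, name="c")
# Tensor names are "<op_name>:<output_index>"; fetching by name returns
# the same Tensor object that was created above.
assert ops.get_default_graph().get_tensor_by_name("c:0") is c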
Example 20: _parse_kwargs_as_attrs
def _parse_kwargs_as_attrs(func_name, **kwargs):
  """Parses **kwargs into a node's attributes."""
  attrs = {}

  noinline = kwargs.pop("noinline", None)
  if noinline is not None:
    attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))

  compiled = kwargs.pop("compiled", None)
  separate_compiled_gradients = kwargs.pop("separate_compiled_gradients",
                                           None)
  if compiled is not None:
    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
    attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
        b=bool(separate_compiled_gradients))

    # Forward _XlaScope from enclosing context (if set), otherwise create new.
    # pylint: disable=protected-access
    if "_XlaScope" in ops.get_default_graph()._attr_scope_map:
      attrs["_XlaScope"] = ops.get_default_graph()._attr_scope_map["_XlaScope"]
    else:
      attrs["_XlaScope"] = attr_value_pb2.AttrValue(
          s=("function_%s" % func_name).encode())
    # pylint: enable=protected-access

  if kwargs:
    raise ValueError("Unknown keyword arguments: %s" % kwargs.keys())
  return attrs
Author: Crazyonxh, Project: tensorflow, Lines: 26, Source: function.py
Note: The tensorflow.python.framework.ops.get_default_graph examples in this article were compiled by 纯净天空 from source-code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by many developers, and copyright remains with the original authors. For redistribution and use, refer to the corresponding project's License; do not repost without permission.