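本文整理汇总了 Python 中 tensorflow.python.eager.context.in_graph_mode 函数的典型用法代码示例。如果您正苦于以下问题：Python in_graph_mode 函数的具体用法是什么？Python in_graph_mode 怎么用？有哪些 Python in_graph_mode 的使用示例？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。注意：这些示例来自 TensorFlow 1.x 的早期版本；在较新的版本中，该函数已被 `not context.executing_eagerly()`（即 `not tf.executing_eagerly()`）取代。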
下文一共展示了 in_graph_mode 函数的 20 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于我们的系统推荐更好的 Python 代码示例。
示例1: testAddWeight
def testAddWeight(self):
    """Checks variable bookkeeping performed by `Layer.add_variable`.

    Verifies that trainable and non-trainable variables land in the right
    layer lists and, in graph mode, that only the trainable one is added to
    the graph's TRAINABLE_VARIABLES collection. In graph mode it also checks
    that a variable created with a regularizer contributes exactly one entry
    to `layer.losses`.
    """
    layer = base_layers.Layer(name='my_layer')

    # Test basic variable creation.
    variable = layer.add_variable(
        'my_var', [2, 2], initializer=init_ops.zeros_initializer())
    self.assertEqual(variable.name, 'my_layer/my_var:0')
    self.assertListEqual(layer.variables, [variable])
    self.assertListEqual(layer.trainable_variables, [variable])
    self.assertListEqual(layer.non_trainable_variables, [])
    if context.in_graph_mode():
        # Graph collections are only consulted in graph mode.
        self.assertListEqual(
            layer.variables,
            ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

    # Test non-trainable variable creation.
    # layer.add_variable should work even outside `build` and `call`.
    variable_2 = layer.add_variable(
        'non_trainable_var', [2, 2],
        initializer=init_ops.zeros_initializer(),
        trainable=False)
    self.assertListEqual(layer.variables, [variable, variable_2])
    self.assertListEqual(layer.trainable_variables, [variable])
    self.assertListEqual(layer.non_trainable_variables, [variable_2])
    if context.in_graph_mode():
        # The non-trainable variable must not appear in the trainable
        # collection, so it still holds only `variable`.
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)), 1)

        # Regularizers are only supported in GRAPH mode, so this section
        # lives inside the graph-mode branch.
        def regularizer(x):
            return math_ops.reduce_sum(x) * 1e-3

        layer.add_variable(
            'reg_var', [2, 2],
            initializer=init_ops.zeros_initializer(),
            regularizer=regularizer)
        self.assertEqual(len(layer.losses), 1)
开发者ID:DjangoPeng,项目名称:tensorflow,代码行数:35,代码来源:base_test.py
示例2: testAddVariable
def testAddVariable(self):
    """Checks naming, initialization and checkpoint keys of `add_variable`.

    Covers:
      * rejecting an explicit `shape` together with the constant initializer
        `1` ("do not specify shape");
      * a constant initializer, an initializer instance, and the uncalled
        `init_ops.zeros_initializer`;
      * duplicate names within one object raising ValueError, while a
        same-named variable outside the object does not conflict;
      * the checkpoint keys produced by `_serialize_object_graph`, which
        ignore variable scopes and graph-level name uniquification.
    """
    obj = NonLayerCheckpointable()
    # `six.assertRaisesRegex` dispatches to `assertRaisesRegexp` on Python 2
    # and `assertRaisesRegex` on Python 3. The `...Regexp` alias is deprecated
    # on Python 3 and was removed in Python 3.12.
    with six.assertRaisesRegex(self, ValueError, "do not specify shape"):
        checkpointable_utils.add_variable(
            obj, name="shape_specified_twice", shape=[], initializer=1)
    constant_initializer = checkpointable_utils.add_variable(
        obj, name="constant_initializer", initializer=1)
    with variable_scope.variable_scope("some_variable_scope"):
        ones_initializer = checkpointable_utils.add_variable(
            obj,
            name="ones_initializer",
            shape=[2],
            initializer=init_ops.ones_initializer(dtype=dtypes.float32))
    bare_initializer = checkpointable_utils.add_variable(
        obj,
        name="bare_initializer",
        shape=[2, 2],
        dtype=dtypes.float64,
        initializer=init_ops.zeros_initializer)

    # Even in graph mode, there are no naming conflicts between objects, only
    # naming conflicts within an object.
    other_duplicate = resource_variable_ops.ResourceVariable(
        name="duplicate", initial_value=1.)
    duplicate = checkpointable_utils.add_variable(
        obj, name="duplicate", shape=[])
    with six.assertRaisesRegex(self, ValueError, "'duplicate' already exists"):
        checkpointable_utils.add_variable(obj, name="duplicate", shape=[])

    if context.in_graph_mode():
        self.evaluate(variables.global_variables_initializer())
    self.assertEqual("constant_initializer:0", constant_initializer.name)
    self.assertEqual(1, self.evaluate(constant_initializer))
    self.assertEqual("some_variable_scope/ones_initializer:0",
                     ones_initializer.name)
    self.assertAllEqual([1, 1], self.evaluate(ones_initializer))
    self.assertAllEqual([[0., 0.],
                         [0., 0.]], self.evaluate(bare_initializer))
    self.assertEqual("a_variable:0", obj.a_variable.name)
    self.assertEqual("duplicate:0", other_duplicate.name)
    if context.in_graph_mode():
        # The .name attribute may be globally influenced, but the checkpoint
        # name won't be (tested below).
        self.assertEqual("duplicate_1:0", duplicate.name)
    else:
        # When executing eagerly, there's no uniquification of variable
        # names. The checkpoint name will be the same.
        self.assertEqual("duplicate:0", duplicate.name)
    named_variables, _ = checkpointable_utils._serialize_object_graph(obj)
    # Keys are derived from attribute names on `obj`; note there is no
    # "some_variable_scope/" prefix and no "_1" suffix.
    expected_checkpoint_names = (
        "a_variable/.ATTRIBUTES/VARIABLE_VALUE",
        "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE",
        "constant_initializer/.ATTRIBUTES/VARIABLE_VALUE",
        "duplicate/.ATTRIBUTES/VARIABLE_VALUE",
        "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE",
    )
    six.assertCountEqual(
        self, expected_checkpoint_names, named_variables.keys())
开发者ID:dananjayamahesh,项目名称:tensorflow,代码行数:58,代码来源:checkpointable_utils_test.py
示例3: testDeferredSlotRestoration
def testDeferredSlotRestoration(self):
    """Checks deferred restoration of a variable and its optimizer slots.

    Writes one checkpoint before the optimizer is attached to `root` (value
    12, no slots) and one after (value 13, slot "m" = 14). Both are restored
    into a fresh object graph before any variables exist, and the test checks
    that values are applied once the matching variables are created. Slot
    variables are also covered; how they get created differs between graph
    mode and eager execution.
    """
    checkpoint_directory = self.get_temp_dir()
    root = checkpointable.Checkpointable()
    root.var = checkpointable_utils.add_variable(
        root, name="var", initializer=0.)
    optimizer = CheckpointableAdam(0.1)
    # Take one optimizer step so the optimizer creates its slot variables
    # for `root.var`.
    if context.in_graph_mode():
        train_op = optimizer.minimize(root.var)
        self.evaluate(variables.global_variables_initializer())
        self.evaluate(train_op)
    else:
        optimizer.minimize(root.var.read_value)
    self.evaluate(state_ops.assign(root.var, 12.))
    # `optimizer` is not yet attached to `root`, so this checkpoint is
    # written before any slot variables are reachable from `root`.
    no_slots_path = checkpointable_utils.Saver(root).save(
        os.path.join(checkpoint_directory, "no_slots"))
    root.optimizer = optimizer
    self.evaluate(state_ops.assign(root.var, 13.))
    self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                   14.))
    slots_path = checkpointable_utils.Saver(root).save(
        os.path.join(checkpoint_directory, "with_slots"))
    new_root = checkpointable.Checkpointable()
    # Load the slot-containing checkpoint (deferred), then immediately overwrite
    # the non-slot variable (also deferred).
    slot_status = checkpointable_utils.Saver(new_root).restore(slots_path)
    no_slot_status = checkpointable_utils.Saver(new_root).restore(no_slots_path)
    # `new_root.var` does not exist yet, so the restore is not consumed.
    with self.assertRaises(AssertionError):
        no_slot_status.assert_consumed()
    new_root.var = checkpointable_utils.add_variable(
        new_root, name="var", shape=[])
    no_slot_status.assert_consumed()
    no_slot_status.run_restore_ops()
    # The later restore ("no_slots", value 12) wins over "with_slots" (13).
    self.assertEqual(12., self.evaluate(new_root.var))
    new_root.optimizer = CheckpointableAdam(0.1)
    # The with-slots restore is expected to be incomplete, with a message
    # mentioning "beta1_power" (presumably optimizer state not created yet).
    with self.assertRaisesRegexp(AssertionError, "beta1_power"):
        slot_status.assert_consumed()
    self.assertEqual(12., self.evaluate(new_root.var))
    if context.in_eager_mode():
        # Slot variables are only created with restoring initializers when
        # executing eagerly.
        self.assertEqual(14., self.evaluate(
            new_root.optimizer.get_slot(name="m", var=new_root.var)))
    else:
        self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
                      None)
    if context.in_graph_mode():
        train_op = new_root.optimizer.minimize(new_root.var)
        # The slot variable now exists; restore() didn't create it, but we should
        # now have a restore op for it.
        slot_status.run_restore_ops()
        self.assertEqual(14., self.evaluate(
            new_root.optimizer.get_slot(name="m", var=new_root.var)))
        self.evaluate(train_op)
    else:
        new_root.optimizer.minimize(new_root.var.read_value)
    # After training created every optimizer variable, the restore is complete.
    slot_status.assert_consumed()
开发者ID:keithc61,项目名称:tensorflow,代码行数:57,代码来源:checkpointable_utils_test.py
示例4: testActivation
def testActivation(self):
    """Checks the name of the output op of `Dense` with and without activation.

    With `activation=nn_ops.relu` the output should come from a Relu op.
    Without an activation it should come from the BiasAdd op. Op names are
    only checked in graph mode.
    """
    cases = (
        (dict(activation=nn_ops.relu, name='dense1'), 'dense1/Relu'),
        (dict(name='dense2'), 'dense2/BiasAdd'),
    )
    for layer_kwargs, expected_op_name in cases:
        layer = core_layers.Dense(2, **layer_kwargs)
        random_inputs = random_ops.random_uniform((5, 3), seed=1)
        result = layer(random_inputs)
        if context.in_graph_mode():
            self.assertEqual(result.op.name, expected_op_name)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:12,代码来源:core_test.py
示例5: test_variable_reuse_exception_nested
def test_variable_reuse_exception_nested(self):
    """Checks that a nested IsolateTest cannot read an outer variable.

    The variable is created (and, in graph mode, initialized) inside the outer
    isolated context. Reading it from inside a nested isolated context must
    raise RuntimeError in graph mode and ValueError when executing eagerly.
    """
    with test_util.IsolateTest(), session.Session():
        first_container_variable = resource_variable_ops.ResourceVariable(
            name="first_container_variable",
            initial_value=1)
        if context.in_graph_mode():
            self.evaluate([variables.global_variables_initializer()])
        # The nested context is entered while the outer one is still active.
        with test_util.IsolateTest(), session.Session():
            if context.in_graph_mode():
                with self.assertRaises(RuntimeError):
                    self.evaluate(first_container_variable.read_value())
            else:
                with self.assertRaises(ValueError):
                    first_container_variable.read_value()
开发者ID:SylChan,项目名称:tensorflow,代码行数:14,代码来源:test_util_test.py
示例6: test_name_scopes_for_variable_scopes
def test_name_scopes_for_variable_scopes(self):
    """Checks name-scope uniquification for templates.

    Name scopes should not be unnecessarily uniquified, but should still be
    uniquified when necessary. The first application of a template reuses
    its variable scope's name for ops; later applications, and templates
    whose name is already taken, get fresh uniquified scopes. Op names are
    only checked in graph mode.
    """

    def linear_module(x, output_size):
        w = variable_scope.get_variable(
            "w", shape=[x.get_shape()[1], output_size],
            initializer=init_ops.zeros_initializer())
        b = variable_scope.get_variable(
            "b", shape=[output_size],
            initializer=init_ops.zeros_initializer())
        return (math_ops.matmul(x, w) + b), w

    def make_linear_module(output_size, name):
        return template.make_template(
            name,
            linear_module,
            output_size=output_size,
            create_scope_now_=True)

    inputs = array_ops.ones((3, 4))

    # `assertEqual` is used throughout: `assertEquals` is a deprecated alias
    # that was removed in Python 3.12.
    linear1 = make_linear_module(output_size=2, name="foo")
    outputs_a, w1 = linear1(inputs)
    outputs_b, _ = linear1(inputs)
    self.assertEqual("foo", linear1.variable_scope.name)
    self.assertEqual("foo/w:0", w1.name)
    if context.in_graph_mode():
        self.assertEqual("foo/add:0", outputs_a.name,
                         "First application of template should get "
                         "same name scope as variables.")
        self.assertEqual("foo_1/add:0", outputs_b.name,
                         "Second application of template should get "
                         "a freshly uniquified name scope.")

    linear2 = make_linear_module(output_size=2, name="foo")
    outputs_c, w2 = linear2(inputs)
    outputs_d, _ = linear2(inputs)
    self.assertEqual("foo_1", linear2.variable_scope.name,
                     "New template gets a freshly uniquified variable scope "
                     "because 'foo' is already taken.")
    self.assertEqual("foo_1/w:0", w2.name)
    if context.in_graph_mode():
        self.assertEqual("foo_1_1/add:0", outputs_c.name,
                         "First application of template would get "
                         "same name scope as variables, but 'foo_1' is already "
                         "a name scope.")
        self.assertEqual("foo_1_2/add:0", outputs_d.name,
                         "Second application of template should also get "
                         "a freshly uniquified name scope.")
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:49,代码来源:template_test.py
示例7: __call__
def __call__(self, *args):
    """Calls the wrapped function on the tensors found in `args`.

    Only the `ops.Tensor` leaves of `nest.flatten(args)` are passed to the
    function, followed by `self._extra_inputs`. The call is dispatched as
    follows:

      * If the gradient tape is recording any of those inputs, the call goes
        through `self._backprop_call` (building the backward function the
        first time).
      * In graph mode, a "FunctionCall" op is added to the default graph.
        The function definition is registered there first if needed, and
        static output shapes are set from `self._output_shapes`.
      * When executing eagerly, the function is run immediately via
        `execute.execute`.

    Args:
      *args: Arbitrarily nested call arguments; non-tensor leaves are ignored.

    Returns:
      `self._build_call_outputs(self._returns, result)`, where `result` holds
      the raw outputs of the call.
    """
    tensor_inputs = [
        x for x in nest.flatten(args) if isinstance(x, ops.Tensor)]
    if (tape.should_record(tensor_inputs) or
            tape.should_record(self._extra_inputs)):
        if not self._has_backprop:
            # Build the backward function lazily, on first recorded call.
            self._compute_backprop()
        return self._backprop_call(tensor_inputs)

    # Everything fed to the function: call-site tensors, then extra inputs.
    # A dedicated name avoids rebinding the `args` parameter.
    all_inputs = list(tensor_inputs) + self._extra_inputs
    if context.in_graph_mode():
        g = ops.get_default_graph()
        if self._fdef.name not in g._functions:  # pylint: disable=protected-access
            g._add_function(self._fdef)  # pylint: disable=protected-access
        signature = self._fdef.definition.signature
        op = g.create_op(
            signature.name, [ops.convert_to_tensor(x) for x in all_inputs],
            [dtypes.DType(x.type) for x in signature.output_arg],
            op_def=signature,
            name="FunctionCall",
            compute_shapes=False)
        result = op.outputs
        # Shapes are not inferred (compute_shapes=False); apply the recorded
        # output shapes instead.
        for i, shape in enumerate(self._output_shapes):
            result[i].set_shape(shape)
    else:
        result = execute.execute(
            str(self._func_name),
            num_outputs=self._num_outputs,
            inputs=all_inputs)
    return self._build_call_outputs(self._returns, result)
开发者ID:allanbian1017,项目名称:tensorflow,代码行数:34,代码来源:function.py
示例8: test_no_sharing
def test_no_sharing(self):
    """Checks that same-named variables in nested IsolateTests don't collide.

    Two ResourceVariables named "same_name" are created in an outer and a
    nested isolated context. Each must keep its own value (1 and 2), and the
    outer one must still read 1 after the nested context exits.
    """
    with test_util.IsolateTest(), session.Session():
        first_container_variable = resource_variable_ops.ResourceVariable(
            name="same_name",
            initial_value=1)
        if context.in_graph_mode():
            self.evaluate([variables.global_variables_initializer()])
        with test_util.IsolateTest(), session.Session():
            second_container_variable = resource_variable_ops.ResourceVariable(
                name="same_name",
                initial_value=2)
            if context.in_graph_mode():
                self.evaluate([variables.global_variables_initializer()])
            self.assertEqual(
                2, self.evaluate(second_container_variable.read_value()))
        # Back in the outer context: the first variable is unaffected.
        self.assertEqual(1, self.evaluate(first_container_variable.read_value()))
开发者ID:SylChan,项目名称:tensorflow,代码行数:16,代码来源:test_util_test.py
示例9: _init_from_proto
def _init_from_proto(self, variable_def, import_scope=None):
    """Initializes this variable from a `VariableDef` protocol buffer.

    The handle, initializer and (optional) snapshot named in `variable_def`
    are looked up in the default graph, each name first prefixed with
    `import_scope`.

    Args:
      variable_def: A `variable_pb2.VariableDef` with `is_resource` set.
      import_scope: Optional name scope prepended to every graph-element name
        stored in `variable_def`.

    Raises:
      RuntimeError: If called while executing eagerly; init_from_proto is
        currently only supported in graph mode.
      TypeError: If `variable_def` is not a `variable_pb2.VariableDef`.
      ValueError: If `variable_def` does not describe a resource variable.
    """
    # Explicit raises instead of `assert`: assertions are stripped when Python
    # runs with -O, which would let invalid calls through silently.
    if not context.in_graph_mode():
        raise RuntimeError(
            "Initializing a variable from a VariableDef is only supported in "
            "graph mode.")
    self._in_graph_mode = True
    if not isinstance(variable_def, variable_pb2.VariableDef):
        raise TypeError(
            "Expected a VariableDef proto, got %s." % type(variable_def))
    if not variable_def.is_resource:
        raise ValueError("Trying to restore Variable as ResourceVariable.")

    # Create from variable_def.
    g = ops.get_default_graph()

    def _resolve(proto_name):
        # Finds the graph element for a name recorded in the proto, re-rooted
        # under `import_scope`.
        return g.as_graph_element(
            ops.prepend_name_scope(proto_name, import_scope=import_scope))

    self._handle = _resolve(variable_def.variable_name)
    self._handle_device = self._handle.device
    self._handle_name = self._handle.name
    self._initializer_op = _resolve(variable_def.initializer_name)
    # An empty snapshot_name means no cached value was recorded.
    if variable_def.snapshot_name:
        self._cached_value = _resolve(variable_def.snapshot_name)
    else:
        self._cached_value = None
    if variable_def.HasField("save_slice_info_def"):
        self._save_slice_info = variables.Variable.SaveSliceInfo(
            save_slice_info_def=variable_def.save_slice_info_def)
    else:
        self._save_slice_info = None
    self._caching_device = None
    self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype"))
    self._graph_element = self.value()
    self._constraint = None
开发者ID:1000sprites,项目名称:tensorflow,代码行数:34,代码来源:resource_variable_ops.py
示例10: new_func
def new_func(*args, **kwargs):
    """Deprecation wrapper.

    When in graph mode and `_PRINT_DEPRECATION_WARNINGS` is set, logs a
    warning for each deprecated argument the caller supplied: positionally,
    as extra varargs, as extra kwargs, or by keyword. An argument passed with
    its declared "ok" value is not reported. With `warn_once`, each
    (func, argument) pair is reported at most once. The call is then
    forwarded to `func` unchanged.
    """
    # TODO(apassos) figure out a way to have reasonable performance with
    # deprecation warnings and eager mode.
    if context.in_graph_mode() and _PRINT_DEPRECATION_WARNINGS:
        named_args = tf_inspect.getcallargs(func, *args, **kwargs)

        def _has_ok_value(arg_name, spec):
            # True when the argument was given the value declared as allowed.
            return spec.has_ok_value and _same_value(named_args[arg_name],
                                                     spec.ok_value)

        invalid_args = [
            arg_name
            for arg_name, spec in deprecated_positions.items()
            if spec.position < len(args) and not _has_ok_value(arg_name, spec)
        ]
        if is_varargs_deprecated and len(args) > len(arg_spec.args):
            invalid_args.append(arg_spec.varargs)
        if is_kwargs_deprecated and kwargs:
            invalid_args.append(arg_spec.keywords)
        invalid_args.extend(
            arg_name for arg_name in deprecated_arg_names
            if arg_name in kwargs and
            not _has_ok_value(arg_name, deprecated_positions[arg_name]))

        for arg_name in invalid_args:
            warning_key = (func, arg_name)
            if warning_key in _PRINTED_WARNING:
                continue
            if warn_once:
                _PRINTED_WARNING[warning_key] = True
            when = 'in a future version' if date is None else ('after %s' % date)
            logging.warning(
                'From %s: calling %s (from %s) with %s is deprecated and will '
                'be removed %s.\nInstructions for updating:\n%s',
                _call_location(), decorator_utils.get_qualified_name(func),
                func.__module__, arg_name, when, instructions)
    return func(*args, **kwargs)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:34,代码来源:deprecation.py
示例11: testRandomSeed
def testRandomSeed(self):
    """Checks `random_seed.get_seed` for combinations of graph and op seeds.

    The global (graph) seed is reset to None after every case, even if an
    assertion fails, so later tests do not inherit a seed from this one.
    """
    test_cases = [
        # Each test case is a tuple with input to get_seed:
        # (input_graph_seed, input_op_seed)
        # and output from get_seed:
        # (output_graph_seed, output_op_seed)
        ((None, None), (None, None)),
        ((None, 1), (random_seed.DEFAULT_GRAPH_SEED, 1)),
        ((1, 1), (1, 1)),
        ((0, 0), (0, 2**31 - 1)),  # Avoid nondeterministic (0, 0) output
        ((2**31 - 1, 0), (0, 2**31 - 1)),  # Don't wrap to (0, 0) either
        ((0, 2**31 - 1), (0, 2**31 - 1)),  # Wrapping for the other argument
    ]
    if context.in_graph_mode():
        # 0 will be the default_graph._lastid.
        test_cases.append(((1, None), (1, 0)))
    # When executing eagerly, the operation seed is a random number generated
    # based on the global seed. It's not tested due to possible platform or
    # version differences.
    for tinput, toutput in test_cases:
        random_seed.set_random_seed(tinput[0])
        try:
            g_seed, op_seed = random_seed.get_seed(tinput[1])
            msg = 'test_case = {0}, got {1}, want {2}'.format(tinput,
                                                               (g_seed, op_seed),
                                                               toutput)
            self.assertEqual((g_seed, op_seed), toutput, msg=msg)
        finally:
            # Always clear the global seed, even when the assertion fails.
            random_seed.set_random_seed(None)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:29,代码来源:random_seed_test.py
示例12: __init__
def __init__(self, handle, dtype, handle_device,  # pylint: disable=super-init-not-called
             shape, in_graph_mode, deleter, parent_op):
    """Wraps an existing resource handle without calling the base initializer.

    Args:
      handle: The resource handle. If it is an `ops.EagerTensor`, the stored
        handle name is the empty string; otherwise it is `handle.name`.
      dtype: Stored as `self._dtype`.
      handle_device: Stored as `self._handle_device`.
      shape: Stored as `self._shape`.
      in_graph_mode: Stored as `self._in_graph_mode`.
      deleter: Stored as `self._handle_deleter`.
      parent_op: Stored as `self._parent_op`.

    In graph mode, `self._graph_element` is set to `self.read_value()`;
    otherwise it is None.
    """
    # We do not call super init on purpose.
    self._trainable = False
    self._save_slice_info = None
    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    self._in_graph_mode = in_graph_mode
    self._handle = handle
    self._handle_device = handle_device
    self._shape = shape
    self._initial_value = None
    self._handle_name = (
        "" if isinstance(self._handle, ops.EagerTensor) else self._handle.name)
    self._dtype = dtype
    self._constraint = None
    self._cached_value = None
    self._is_initialized_op = None
    self._initializer_op = None
    self._parent_op = parent_op
    # read_value() runs last among the fields it may depend on, as before.
    self._graph_element = self.read_value() if context.in_graph_mode() else None
    self._handle_deleter = deleter
开发者ID:keithc61,项目名称:tensorflow,代码行数:26,代码来源:resource_variable_ops.py
示例13: testMaskingSingleInput
def testMaskingSingleInput(self):
    """Checks a single-input layer that defines `compute_mask`.

    In graph mode, a Network built from the layer must be callable on a new
    `Input` and on a plain placeholder, both yielding shape [None, 32]. When
    executing eagerly, the output must carry `_keras_mask` equal to ones, and
    the value must be the input multiplied by the incoming mask.
    """

    class MaskedLayer(base_layers.Layer):

        def call(self, inputs, mask=None):
            return inputs if mask is None else inputs * mask

        def compute_mask(self, inputs, mask=None):
            return array_ops.ones_like(inputs)

    if context.in_graph_mode():
        x = base_layers.Input(shape=(32,))
        y = MaskedLayer()(x)  # pylint: disable=not-callable
        network = base_layers.Network(x, y)

        # Test callability on an Input, then on a regular tensor.
        input_factories = (
            lambda: base_layers.Input(shape=(32,)),
            lambda: array_ops.placeholder(dtype='float32', shape=(None, 32)),
        )
        for make_input in input_factories:
            y_2 = network(make_input())
            self.assertEqual(y_2.get_shape().as_list(), [None, 32])
    else:
        a = constant_op.constant([2] * 32)
        mask = constant_op.constant([0, 1] * 16)
        a._keras_mask = mask
        b = MaskedLayer().apply(a)
        self.assertTrue(hasattr(b, '_keras_mask'))
        expected_mask = self.evaluate(array_ops.ones_like(mask))
        actual_mask = self.evaluate(getattr(b, '_keras_mask'))
        self.assertAllEqual(expected_mask, actual_mask)
        self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b))
开发者ID:keveman,项目名称:tensorflow,代码行数:35,代码来源:base_test.py
示例14: _delay_checks
def _delay_checks(self):
    """Context manager for combining checks that depend on tensor evaluations.

    Each call to Session.run has some overhead, and this overhead can easily
    account for most of the time spent in tests that call Session.run (or
    Tensor.eval) many times.

    This context manager lets callers register callback functions together
    with their associated tensors. When the context is exited, all tensors
    from all registrations are evaluated with a single call, and each
    registered callback is then called with the values of its tensors.

    Yields:
      A function `add_check(check, *args, **kwargs)`. `check` is the callback
      to invoke, and `*args` and `**kwargs` are the associated tensors. When
      executing eagerly, `check` runs immediately inside `add_check`. In
      graph mode it is deferred until the context exits.
    """
    checks = []

    def add_check(check, *args, **kwargs):
        if context.in_eager_mode():
            args_val, kwargs_val = self.evaluate([args, kwargs])
            check(*args_val, **kwargs_val)
        else:
            checks.append((check, args, kwargs))

    yield add_check
    # Skip the batched evaluation when nothing was registered. An empty fetch
    # list would still cost an evaluate() call, which this helper exists to
    # minimize.
    if context.in_graph_mode() and checks:
        all_values = self.evaluate([[args, kwargs] for _, args, kwargs in checks])
        for (check, _, _), (args, kwargs) in zip(checks, all_values):
            check(*args, **kwargs)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:33,代码来源:atrous_convolution_test.py
示例15: __init__
def __init__(self, root_checkpointable):
    """Configure saving.

    Args:
      root_checkpointable: The root of the object graph to save/restore. This
        object and all of its dependencies are saved in the checkpoint. When
        restoring, objects are matched and restored starting from this root.
    """
    # Allow passing in a weak reference to avoid reference cycles when
    # `Checkpointable` objects save themselves.
    self._root_checkpointable_ref = root_checkpointable
    # A constant "model" file-prefix tensor exists only in graph mode; it is
    # None when executing eagerly.
    self._file_prefix_placeholder = (
        constant_op.constant("model") if context.in_graph_mode() else None)

    # Op caching for save: nothing is cached until the first save.
    (self._object_graph_feed_tensor,
     self._last_save_object_graph,
     self._last_save_saver) = (None, None, None)

    # Op caching for restore: nothing is cached until the first restore.
    (self._object_graph_restore_tensor,
     self._last_restore_object_graph,
     self._last_restore_checkpoint) = (None, None, None)
开发者ID:hhu-luqi,项目名称:tensorflow,代码行数:25,代码来源:checkpointable_utils.py
示例16: testAgnosticUsage
def testAgnosticUsage(self):
    """Graph/eager agnostic usage.

    Runs three "training continuations", each in a fresh graph. Every pass
    restores the latest checkpoint, trains for `num_training_steps`, and
    saves. It then checks that the global step and the save counter have
    accumulated across passes.
    """
    # Does create garbage when executing eagerly due to ops.Graph() creation.
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    for training_continuation in range(3):
        with ops.Graph().as_default(), self.test_session(
                graph=ops.get_default_graph()):
            network = MyNetwork()
            optimizer = adam.AdamOptimizer(0.001)
            root = checkpointable_utils.Checkpoint(
                optimizer=optimizer, network=network,
                global_step=training_util.get_or_create_global_step())
            # On the first pass no checkpoint has been written yet; presumably
            # latest_checkpoint returns None and the restore only initializes.
            checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
            status = root.restore(save_path=checkpoint_path)
            input_value = constant_op.constant([[3.]])
            train_fn = functools.partial(
                optimizer.minimize,
                functools.partial(network, input_value),
                global_step=root.global_step)
            if context.in_graph_mode():
                # Build the train op once, then make each call evaluate it.
                train_fn = functools.partial(self.evaluate, train_fn())
            status.initialize_or_restore()
            for _ in range(num_training_steps):
                train_fn()
            root.save(file_prefix=checkpoint_prefix)
            # Both counters carry over from earlier passes via the checkpoint.
            self.assertEqual((training_continuation + 1) * num_training_steps,
                             self.evaluate(root.global_step))
            self.assertEqual(training_continuation + 1,
                             self.evaluate(root.save_counter))
开发者ID:neuroradiology,项目名称:tensorflow,代码行数:31,代码来源:checkpointable_utils_test.py
示例17: call
def call(self, inputs, training=None):
    """Applies batch normalization, defaulting `training` to the learning phase.

    In graph mode, when `training` is the learning-phase object itself, the
    output is flagged with `_uses_learning_phase = True`.
    """
    training = K.learning_phase() if training is None else training
    output = super(BatchNormalization, self).call(inputs, training=training)
    depends_on_learning_phase = (
        context.in_graph_mode() and training is K.learning_phase())
    if depends_on_learning_phase:
        output._uses_learning_phase = True  # pylint: disable=protected-access
    return output
开发者ID:QiangCai,项目名称:tensorflow,代码行数:7,代码来源:normalization.py
示例18: _create_non_slot_variable
def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot.

    Variables are cached in `self._non_slot_dict` under `(name, graph)`, where
    `graph` is `colocate_with.graph` in graph mode and None when executing
    eagerly. An existing entry is returned as-is. Otherwise a non-trainable
    variable is created, colocated with `colocate_with`, and cached.
    """
    graph = colocate_with.graph if context.in_graph_mode() else None
    key = (name, graph)
    existing = self._non_slot_dict.get(key)
    if existing is not None:
        return existing

    def _variable_getter(name, shape, dtype, initializer):
        del shape, dtype  # not used, but there for compatibility
        return variable_scope.variable(
            name=name, initial_value=initializer, trainable=False)

    with ops.colocate_with(colocate_with):
        initial_tensor = ops.convert_to_tensor(initial_value)
        created = self.add_variable(
            name=name,
            shape=initial_tensor.get_shape(),
            initializer=initial_tensor,
            getter=_variable_getter)
    self._non_slot_dict[key] = created
    return created
开发者ID:japrogramer,项目名称:tensorflow,代码行数:26,代码来源:checkpointable_test.py
示例19: testInputSpecNdimCheck
def testInputSpecNdimCheck(self):
    """Checks that `InputSpec(ndim=2)` rejects inputs of unknown or wrong rank.

    A fresh layer is created for each call, because in Eager mode input spec
    checks only happen on the first call.
    """

    class CustomerLayer(base_layers.Layer):

        def __init__(self):
            super(CustomerLayer, self).__init__()
            self.input_spec = base_layers.InputSpec(ndim=2)

        def call(self, inputs):
            return inputs

    if context.in_graph_mode():
        # This case needs a placeholder, so it only runs in graph mode.
        unranked_layer = CustomerLayer()
        with self.assertRaisesRegexp(ValueError, r'requires a defined rank'):
            unranked_layer.apply(array_ops.placeholder('int32'))

    wrong_rank_layer = CustomerLayer()
    with self.assertRaisesRegexp(ValueError, r'expected ndim=2'):
        wrong_rank_layer.apply(constant_op.constant([1]))

    # Works: a rank-2 input satisfies the spec.
    CustomerLayer().apply(constant_op.constant([[1], [2]]))
开发者ID:keveman,项目名称:tensorflow,代码行数:25,代码来源:base_test.py
示例20: split
def split(self, value, lengths, name=None):
    """See TensorArray.

    Emits a TensorArraySplitV3 op that splits `value` using `lengths`, and
    returns a new TensorArray that shares this array's handle and whose flow
    is the op's output. In graph mode, with shape inference enabled, if
    `lengths` is a non-empty constant whose entries are all equal, the element
    shape is merged with `[lengths[0]] + value.shape[1:]`.
    """
    with ops.name_scope(name, "TensorArraySplit",
                        [self._handle, value, lengths]):
        value = ops.convert_to_tensor(value, name="value")
        with self._maybe_colocate_with(value):
            lengths_64 = math_ops.to_int64(lengths)
            if self._infer_shape and context.in_graph_mode():
                clengths = tensor_util.constant_value(lengths_64)
                # `clengths.size > 0` guards the numpy reductions: max()/min()
                # of an empty array raise, and clengths[0] would not exist.
                if (value.shape.dims is not None and
                        clengths is not None and
                        clengths.size > 0 and
                        clengths.max() == clengths.min()):
                    self._merge_element_shape(
                        tensor_shape.TensorShape([clengths[0]]).concatenate(
                            value.shape[1:]))
            flow_out = gen_data_flow_ops._tensor_array_split_v3(
                handle=self._handle,
                value=value,
                lengths=lengths_64,
                flow_in=self._flow,
                name=name)
        ta = TensorArray(
            dtype=self._dtype, handle=self._handle, flow=flow_out,
            colocate_with_first_write_call=self._colocate_with_first_write_call)
        ta._infer_shape = self._infer_shape
        ta._element_shape = self._element_shape
        ta._colocate_with = self._colocate_with
        return ta
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:27,代码来源:tensor_array_ops.py
注：本文中的 tensorflow.python.eager.context.in_graph_mode 函数示例由纯净天空整理自 GitHub/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。
请发表评论