本文整理汇总了Python中tensorflow.python.util.nest.assert_same_structure函数的典型用法代码示例。如果您正苦于以下问题：Python assert_same_structure函数的具体用法？Python assert_same_structure怎么用？Python assert_same_structure使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了assert_same_structure函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _concrete_function_callable_with
def _concrete_function_callable_with(function, inputs, allow_conversion):
    """Returns whether concrete `function` can be called with `inputs`.

    Args:
      function: Concrete function whose
        `graph.structured_input_signature` is the structure `inputs` is
        matched against.
      inputs: Candidate (possibly nested) call arguments.
      allow_conversion: If true, every argument lined up with a `TensorSpec`
        is first passed through `_try_convert_to_tensor_spec`.

    Returns:
      True when `inputs` lines up with the signature leaf by leaf, False
      otherwise (including when flattening or repacking raises `TypeError`
      or `ValueError`).
    """
    signature = function.graph.structured_input_signature
    try:
        flat_inputs = nest.flatten_up_to(signature, inputs)
        # Verify that no input elements were dropped during flattening.
        round_trip = nest.pack_sequence_as(signature, flat_inputs)
        # TODO(b/129422719): Namedtuple subclasses re-created through
        # saved_model.load don't compare equal in type to the original in
        # assert_same_structure. Fix that and we can take out
        # check_types=False here.
        nest.assert_same_structure(inputs, round_trip, check_types=False)
    except (TypeError, ValueError):
        return False

    for arg, expected in zip(flat_inputs, nest.flatten(signature)):
        if not isinstance(expected, tensor_spec.TensorSpec):
            # Leaves that are not TensorSpecs are compared by value.
            if arg != expected:
                return False
            continue
        if allow_conversion:
            arg = _try_convert_to_tensor_spec(arg, dtype_hint=expected.dtype)
        if not (_is_tensor(arg) or isinstance(arg, tensor_spec.TensorSpec)):
            return False
        if arg.dtype != expected.dtype:
            return False
        if not expected.shape.is_compatible_with(arg.shape):
            return False
    return True
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:32,代码来源:function_deserialization.py
示例2: wrapped_body
def wrapped_body(loop_counter, *args):
    """Runs the user loop body and increments the loop counter.

    Args:
      loop_counter: Loop counter which needs to be incremented in the body.
      *args: Loop variables for this iteration.

    Returns:
      A list holding `loop_counter + 1` followed by the body outputs after
      they have been passed through `_tensor_array_to_flow`.
    """
    # Capture the tensors already captured in cond_graph so that they appear
    # in the same order in body_graph.external_captures.
    for captured in cond_graph.external_captures:
        ops.get_default_graph().capture(captured)

    # `args` should already have the same structure as `orig_loop_vars`, but
    # there is currently no nest.zip. `_pack_sequence_as` flattens both,
    # converts the flows in `args` to TensorArrays and packs the result into
    # the structure of `orig_loop_vars`.
    body_args = _pack_sequence_as(orig_loop_vars, args)
    outputs = body(*body_args)
    if not nest.is_sequence(outputs):
        outputs = [outputs]

    # Top-level tuples are converted to lists before comparing the body's
    # input and output structure, to stay compatible with legacy while_loop.
    nest.assert_same_structure(list(outputs), list(orig_loop_vars))

    flow_outputs = _tensor_array_to_flow(outputs)

    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    return [loop_counter + 1] + list(flow_outputs)
开发者ID:ziky90,项目名称:tensorflow,代码行数:32,代码来源:while_v2.py
示例3: __init__
def __init__(self, initial_state, mask=None, name="trainable_initial_state"):
    """Builds the module that exposes a trainable initial state.

    The given initial state provides the starting values for the trainable
    variables held by the module. An optional mask marks which parts of the
    initial state should be learnable.

    Args:
      initial_state: tensor or arbitrarily nested iterables of tensors.
      mask: optional boolean mask with the same nested structure as
        `initial_state`.
      name: module name.

    Raises:
      TypeError: if any flattened element of `mask` is not a `bool`.
    """
    super(TrainableInitialState, self).__init__(name=name)

    # DeprecationWarning has been ignored by default since Python 2.7, so
    # switch it on before emitting the warning.
    warnings.simplefilter("always", DeprecationWarning)
    warnings.warn(
        "Use the trainable flag in initial_state instead.",
        DeprecationWarning,
        stacklevel=2)

    if mask is not None:
        mask_leaves = nest.flatten(mask)
        if not all(isinstance(leaf, bool) for leaf in mask_leaves):
            raise TypeError("Mask should be None or a list of boolean values.")
        # The mask has to line up with the state it is masking.
        nest.assert_same_structure(initial_state, mask)

    self._mask = mask
    self._initial_state = initial_state
开发者ID:TianjiPang,项目名称:sonnet,代码行数:32,代码来源:rnn_core.py
示例4: _check_same_outputs
def _check_same_outputs(true_graph, false_graph):
    """Raises an error if true_graph and false_graph have different outputs.

    Args:
      true_graph: Graph of the true branch; `structured_outputs` and
        `outputs` are read from it.
      false_graph: Graph of the false branch, compared against `true_graph`.

    Raises:
      TypeError: If the structured outputs differ in structure, or if a pair
        of corresponding outputs has different dtypes.
    """
    template = (
        "true_fn and false_fn arguments to tf.cond must have the same number, "
        "type, and overall structure of return values.\n"
        "\n"
        "true_fn output: %s\n"
        "false_fn output: %s\n"
        "\n"
        "Error details:\n"
        "%s")

    def _fail(detail):
        # Every mismatch is reported through the same user-facing message.
        raise TypeError(template % (true_graph.structured_outputs,
                                    false_graph.structured_outputs,
                                    detail))

    try:
        nest.assert_same_structure(
            true_graph.structured_outputs,
            false_graph.structured_outputs,
            expand_composites=True)
    except (ValueError, TypeError) as exc:
        _fail(str(exc))

    assert len(true_graph.outputs) == len(false_graph.outputs)
    for true_out, false_out in zip(true_graph.outputs, false_graph.outputs):
        if true_out.dtype != false_out.dtype:
            _fail("%s and %s have different types" % (true_out, false_out))
开发者ID:terrytangyuan,项目名称:tensorflow,代码行数:26,代码来源:cond_v2.py
示例5: wrapped_body
def wrapped_body(loop_counter, *args):
    """Runs the user loop body and increments the loop counter.

    Args:
      loop_counter: Loop counter which needs to be incremented in the body.
      *args: List of args
        args[:len_orig_loop_vars] - Args for the original loop body.
        args[len_orig_loop_vars:] - External captures of cond. These get
          passed through as is.

    Returns:
      A list holding `loop_counter + 1`, then the body outputs after
      `_tensor_array_to_flow`, then the trailing pass-through args.
    """
    loop_var_args = args[:len_orig_loop_vars]

    # `args` should already have the same structure as `orig_loop_vars`, but
    # there is currently no nest.zip. `_pack_sequence_as` flattens both,
    # converts the flows in `args` to TensorArrays and packs the result into
    # the structure of `orig_loop_vars`.
    body_args = _pack_sequence_as(orig_loop_vars, loop_var_args)
    outputs = body(*body_args)
    if not nest.is_sequence(outputs):
        outputs = [outputs]

    # Top-level tuples are converted to lists before comparing the body's
    # input and output structure, to stay compatible with legacy while_loop.
    nest.assert_same_structure(list(outputs), list(orig_loop_vars))

    flow_outputs = _tensor_array_to_flow(outputs)

    # Return the external_captures of cond_graph as is, i.e., treat them as
    # loop invariants.
    # TODO(srbs): Update lowering code to create _Enter nodes with
    # is_constant=True for inputs that are directly passed to outputs.
    passthrough = args[len_orig_loop_vars:]
    return [loop_counter + 1] + list(flow_outputs) + list(passthrough)
开发者ID:aeverall,项目名称:tensorflow,代码行数:34,代码来源:while_v2.py
示例6: testMapStructure
def testMapStructure(self):
    # Exercises nest.map_structure on matching structures and on scalars,
    # then checks the errors raised for a non-callable and for mismatched
    # structures.
    def _add(x, y):
        return x + y

    def _ignore(x, y):
        return None

    structure1 = (((1, 2), 3), 4, (5, 6))
    structure2 = (((7, 8), 9), 10, (11, 12))

    structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
    nest.assert_same_structure(structure1, structure1_plus1)
    self.assertAllEqual([2, 3, 4, 5, 6, 7], nest.flatten(structure1_plus1))

    structure1_plus_structure2 = nest.map_structure(
        _add, structure1, structure2)
    self.assertEqual(
        (((8, 10), 12), 14, (16, 18)),
        structure1_plus_structure2)

    # Scalars are treated as single-leaf structures.
    self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
    self.assertEqual(7, nest.map_structure(_add, 3, 4))

    with self.assertRaisesRegexp(TypeError, "callable"):
        nest.map_structure("bad", structure1_plus1)

    mismatched_cases = (
        (ValueError, "same nested structure", 3, (3,)),
        (TypeError, "same sequence type", ((3, 4), 5), [(3, 4), 5]),
        (ValueError, "same nested structure", ((3, 4), 5), (3, (4, 5))),
    )
    for error_type, pattern, left, right in mismatched_cases:
        with self.assertRaisesRegexp(error_type, pattern):
            nest.map_structure(_ignore, left, right)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:29,代码来源:nest_test.py
示例7: compute
def compute(i, a_flat, tas):
    """The loop body of scan.

    Args:
      i: the loop counter.
      a_flat: the accumulator value(s), flattened.
      tas: the output accumulator TensorArray(s), flattened.

    Returns:
      A tuple `(next_i, flat_a_out, tas)`: the counter stepped by one (down
      when `reverse` is set, up otherwise), the flattened output of `fn`,
      and the TensorArrays after writing that output at index `i`.

    Raises:
      TypeError or ValueError: propagated from `nest.assert_same_structure`
        when the output of `fn` does not match the structure of `initializer`
        (or of `elems` when `initializer` is None).
    """
    current_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
    accumulator = output_pack(a_flat)
    a_out = fn(accumulator, current_elems)

    # Without an initializer the accumulator takes the structure of `elems`.
    if initializer is None:
        expected_structure = elems
    else:
        expected_structure = initializer
    nest.assert_same_structure(expected_structure, a_out)

    flat_a_out = output_flatten(a_out)
    tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
    next_i = i - 1 if reverse else i + 1
    return (next_i, flat_a_out, tas)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:28,代码来源:functional_ops.py
示例8: body
def body(time, elements_finished, current_input, emit_ta, state, loop_state):
    """Internal while loop body for raw_rnn.

    Args:
      time: time scalar.
      elements_finished: batch-size vector.
      current_input: possibly nested tuple of input tensors.
      emit_ta: possibly nested tuple of output TensorArrays.
      state: possibly nested tuple of state tensors.
      loop_state: possibly nested tuple of loop state tensors.

    Returns:
      Tuple having the same size as Args but with updated values.
    """
    next_output, cell_state = cell(current_input, state)
    nest.assert_same_structure(state, cell_state)
    nest.assert_same_structure(cell.output_size, next_output)

    next_time = time + 1
    loop_fn_result = loop_fn(next_time, next_output, cell_state, loop_state)
    (next_finished, next_input, next_state, emit_output,
     next_loop_state) = loop_fn_result

    structure_pairs = (
        (state, next_state),
        (current_input, next_input),
        (emit_ta, emit_output),
    )
    for expected, actual in structure_pairs:
        nest.assert_same_structure(expected, actual)

    # If loop_fn returns None for next_loop_state, just reuse the
    # previous one.
    if next_loop_state is not None:
        loop_state = next_loop_state

    def _copy_some_through(current, candidate):
        """Copy some tensors through via array_ops.where."""

        def _select(current_i, candidate_i):
            # Binding the pair as arguments keeps the closure from seeing a
            # later loop iteration's values.
            return _on_device(
                lambda: array_ops.where(
                    elements_finished, current_i, candidate_i),
                device=candidate_i.op.device)

        leaf_pairs = zip(nest.flatten(current), nest.flatten(candidate))
        selected = [_select(cur, cand) for (cur, cand) in leaf_pairs]
        return nest.pack_sequence_as(structure=current, flat_sequence=selected)

    emit_output = _copy_some_through(zero_emit, emit_output)
    next_state = _copy_some_through(state, next_state)

    emit_output_flat = nest.flatten(emit_output)
    emit_ta_flat = nest.flatten(emit_ta)

    elements_finished = math_ops.logical_or(elements_finished, next_finished)

    written = [
        ta.write(time, emit)
        for (ta, emit) in zip(emit_ta_flat, emit_output_flat)
    ]
    emit_ta = nest.pack_sequence_as(
        structure=emit_structure, flat_sequence=written)

    return (next_time, elements_finished, next_input, emit_ta, next_state,
            loop_state)
开发者ID:ygoverdhan,项目名称:tensorflow,代码行数:59,代码来源:rnn.py
示例9: _is_flat
def _is_flat(sequence):
    """Returns whether `sequence` has the same structure as its flattening.

    Args:
      sequence: Structure to test.

    Returns:
      True if `nest.assert_same_structure` accepts `nest.flatten(sequence)`
      against `sequence`; False if it raises `ValueError` or `TypeError`.
    """
    flattened = nest.flatten(sequence)
    try:
        nest.assert_same_structure(flattened, sequence)
    except (ValueError, TypeError):
        return False
    return True
开发者ID:rmlarsen,项目名称:tensorflow,代码行数:9,代码来源:save.py
示例10: testNestAssertSameStructureCompositeMismatch
def testNestAssertSameStructureCompositeMismatch(self, s1, s2,
                                                 error=ValueError):
    # s1 and s2 have the same structure if expand_composites=False; but
    # different structures if expand_composites=True.
    unexpanded_checks = (
        nest.assert_same_structure,
        nest.assert_shallow_structure,
    )
    for check in unexpanded_checks:
        check(s1, s2, expand_composites=False)

    with self.assertRaises(error):  # pylint: disable=g-error-prone-assert-raises
        nest.assert_same_structure(s1, s2, expand_composites=True)
开发者ID:aritratony,项目名称:tensorflow,代码行数:10,代码来源:composite_tensor_test.py
示例11: insert
def insert(self, keys, values):
"""Inserts `values` under `keys` into each of the nested hash tables.

Args:
keys: Keys forwarded unchanged to every table's `insert` call.
values: Structure of values; it must have the same nested structure as
`self._hash_tables`, one value per table.

Returns:
The result of `control_flow_ops.group` over all per-table insert ops.
"""
# NOTE(review): indentation was lost in this listing, so it is unclear whether
# the final `return` sits inside the `control_dependencies` scope — confirm
# against the original file before reformatting.
# Fail early if `values` does not line up with the table structure.
nest.assert_same_structure(self._hash_tables, values)
# Avoid race conditions by requiring that all inputs are computed before any
# inserts happen (an issue if one key's update relies on another's value).
values_flat = [array_ops.identity(value) for value in nest.flatten(values)]
with ops.control_dependencies(values_flat):
# Pair each flattened table with the identity-wrapped value at the same index.
insert_ops = [hash_table.insert(keys, value)
for hash_table, value
in zip(nest.flatten(self._hash_tables),
values_flat)]
return control_flow_ops.group(*insert_ops)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:11,代码来源:math_utils.py
示例12: _assert_correct_outputs
def _assert_correct_outputs(self, initial_state_):
    # Checks that `initial_state_` has the structure of the decoder cell's
    # state size and of the encoder's final state, and that its leaves equal
    # the evaluated encoder final state.
    nest.assert_same_structure(initial_state_, self.decoder_cell.state_size)
    nest.assert_same_structure(
        initial_state_, self.encoder_outputs.final_state)

    encoder_state_flat = nest.flatten(self.encoder_outputs.final_state)
    with self.test_session() as sess:
        encoder_values = sess.run(encoder_state_flat)

    decoder_values = nest.flatten(initial_state_)
    for decoder_value, encoder_value in zip(decoder_values, encoder_values):
        np.testing.assert_array_equal(decoder_value, encoder_value)
开发者ID:AbhinavJain13,项目名称:seq2seq,代码行数:11,代码来源:bridges_test.py
示例13: run_and_report
def run_and_report(self, s1, s2, name):
    """Times `nest.assert_same_structure(s1, s2)` and reports the result.

    Args:
      s1: First structure passed to `assert_same_structure`.
      s2: Second structure passed to `assert_same_structure`.
      name: Name under which the benchmark is reported.
    """
    burn_iter = 100
    test_iter = 30000

    # Untimed warm-up calls before the measured loop.
    for _ in xrange(burn_iter):
        nest.assert_same_structure(s1, s2)

    start = time.time()
    for _ in xrange(test_iter):
        nest.assert_same_structure(s1, s2)
    elapsed = time.time() - start

    self.report_benchmark(
        iters=test_iter, wall_time=elapsed / test_iter, name=name)
开发者ID:Huoxubeiyin,项目名称:tensorflow,代码行数:13,代码来源:nest_test.py
示例14: _maybe_copy_some_through
def _maybe_copy_some_through():
    """Run RNN step. Pass through either no or some past state."""
    new_output, new_state = call_cell()
    nest.assert_same_structure(state, new_state)

    flat_new_state = nest.flatten(new_state)
    flat_new_output = nest.flatten(new_output)

    def _return_everything():
        # if t < min_seq_len: calculate and return everything
        return flat_new_output + flat_new_state

    def _copy_partially():
        # else copy some of it through
        return _copy_some_through(flat_new_output, flat_new_state)

    return control_flow_ops.cond(
        time < min_sequence_length, _return_everything, _copy_partially)
开发者ID:giancds,项目名称:attentive_lm,代码行数:13,代码来源:rnn.py
示例15: body
def body(time, elements_finished, current_input,
         emit_ta, state, loop_state):
    """Internal while loop body for raw_rnn.

    Args:
      time: time scalar.
      elements_finished: batch-size vector.
      current_input: possibly nested tuple of input tensors.
      emit_ta: possibly nested tuple of output TensorArrays.
      state: possibly nested tuple of state tensors.
      loop_state: possibly nested tuple of loop state tensors.

    Returns:
      Tuple having the same size as Args but with updated values.
    """
    next_output, cell_state = cell(current_input, state)
    nest.assert_same_structure(state, cell_state)
    nest.assert_same_structure(cell.output_size, next_output)

    next_time = time + 1
    loop_fn_result = loop_fn(next_time, next_output, cell_state, loop_state)
    (next_finished, next_input, next_state, emit_output,
     next_loop_state) = loop_fn_result

    structure_pairs = (
        (state, next_state),
        (current_input, next_input),
        (emit_ta, emit_output),
    )
    for expected, actual in structure_pairs:
        nest.assert_same_structure(expected, actual)

    # If loop_fn returns None for next_loop_state, just reuse the
    # previous one.
    if next_loop_state is not None:
        loop_state = next_loop_state

    def _copy_some_through(current, candidate):
        """Copy some tensors through via array_ops.where."""

        def copy_fn(cur_i, cand_i):
            def _where():
                return array_ops.where(elements_finished, cur_i, cand_i)

            return _on_device(_where, device=cand_i.op.device)

        return nest.map_structure(copy_fn, current, candidate)

    emit_output = _copy_some_through(zero_emit, emit_output)
    next_state = _copy_some_through(state, next_state)

    def _write_emit(ta, emit):
        # Each emitted value is written at index `time` of its TensorArray.
        return ta.write(time, emit)

    emit_ta = nest.map_structure(_write_emit, emit_ta, emit_output)

    elements_finished = math_ops.logical_or(elements_finished, next_finished)

    return (next_time, elements_finished, next_input,
            emit_ta, next_state, loop_state)
开发者ID:jzuern,项目名称:tensorflow,代码行数:51,代码来源:rnn.py
示例16: check_mutation
def check_mutation(n1, n2):
    """Check if two list of arguments are exactly the same.

    Args:
      n1: First (possibly nested) structure of arguments.
      n2: Second structure, compared against `n1`.

    Raises:
      ValueError: If the two structures differ, or if any pair of
        corresponding flattened elements is not the very same object.
    """
    errmsg = ("Function to be traced should not modify structure of input "
              "arguments. Check if your function has list and dictionary "
              "operations that alter input arguments, "
              "such as `list.pop`, `list.append`")
    try:
        nest.assert_same_structure(n1, n2)
    except (ValueError, TypeError):
        # assert_same_structure reports a changed sequence type as TypeError
        # rather than ValueError; both mean the arguments were modified, so
        # both are surfaced with the same explanatory message.
        raise ValueError(errmsg)

    # Same structure is not enough: every leaf must be the identical object.
    for arg1, arg2 in zip(nest.flatten(n1), nest.flatten(n2)):
        if arg1 is not arg2:
            raise ValueError(errmsg)
开发者ID:rmlarsen,项目名称:tensorflow,代码行数:14,代码来源:func_graph.py
示例17: testInitialStateComputation
def testInitialStateComputation(self, tuple_state, mask):
    # Builds a TrainableInitialState from a nested or single-tensor state and
    # checks the returned state before and after the trainable variables are
    # overwritten with ones.
    if tuple_state:
        initial_state = (tf.fill([BATCH_SIZE, 6], 2),
                         (tf.fill([BATCH_SIZE, 7], 3),
                          tf.fill([BATCH_SIZE, 8], 4)))
    else:
        initial_state = tf.fill([BATCH_SIZE, 9], 10)

    trainable_state_module = snt.TrainableInitialState(
        initial_state, mask=mask)
    trainable_state = trainable_state_module()
    flat_trainable_state = nest.flatten(trainable_state)
    nest.assert_same_structure(initial_state, trainable_state)
    flat_initial_state = nest.flatten(initial_state)

    # With no mask given, every leaf is expected to be trainable.
    flat_mask = (nest.flatten(mask) if mask is not None
                 else (True,) * len(flat_initial_state))

    self.evaluate(tf.global_variables_initializer())

    # Check all variables are initialized correctly and return a state that
    # has the same as it is provided.
    for trainable_leaf, initial_leaf in zip(flat_trainable_state,
                                            flat_initial_state):
        trainable_value = self.evaluate(trainable_leaf)
        initial_value = self.evaluate(initial_leaf)
        self.assertAllEqual(trainable_value, initial_value)

    # Change the value of all the trainable variables to ones.
    for variable in tf.trainable_variables():
        self.evaluate(tf.assign(variable, tf.ones_like(variable)))

    # In eager mode to re-evaluate the module we must re-connect it.
    trainable_state = trainable_state_module()
    flat_trainable_state = nest.flatten(trainable_state)

    # Check that the values of the initial_states have changed if and only if
    # they are trainable.
    leaf_triples = zip(flat_trainable_state, flat_initial_state, flat_mask)
    for trainable_leaf, initial_leaf, is_trainable in leaf_triples:
        trainable_state_value = self.evaluate(trainable_leaf)
        initial_state_value = self.evaluate(initial_leaf)
        expected_value = (np.ones_like(initial_state_value) if is_trainable
                          else initial_state_value)
        self.assertAllEqual(trainable_state_value, expected_value)
开发者ID:ccchang0111,项目名称:sonnet,代码行数:48,代码来源:rnn_core_test.py
示例18: test_convert_to_generator_like
def test_convert_to_generator_like(self, input_fn, inputs):
# Builds data with `input_fn`, wraps it with `convert_to_generator_like`, and
# checks the reported step count and the structure of the generated output.
# NOTE(review): indentation was lost in this listing; it is unclear whether the
# final structure assertion runs on every loop iteration or once after the
# loop — confirm against the original file before reformatting.
expected_batches = 5
data = input_fn(self, inputs, expected_batches)
# Dataset and Iterator not supported in Legacy Graph mode.
if (not context.executing_eagerly() and
isinstance(data, (dataset_ops.DatasetV2, iterator_ops.Iterator))):
return
generator, steps = training_generator.convert_to_generator_like(
data, batch_size=2, steps_per_epoch=expected_batches)
# The returned step count must match the requested steps_per_epoch.
self.assertEqual(steps, expected_batches)
# Pull `expected_batches` items from the generator.
for _ in range(expected_batches):
outputs = next(generator)
nest.assert_same_structure(outputs, inputs)
开发者ID:aeverall,项目名称:tensorflow,代码行数:16,代码来源:training_generator_test.py
示例19: testNestAssertSameStructure
def testNestAssertSameStructure(self):
    # Two SparseTensors are accepted as the same structure with and without
    # composite expansion. A SparseTensor and a TestCompositeTensor are
    # accepted only when composites are not expanded.
    st1 = sparse_tensor.SparseTensor([[0]], [0], [100])
    st2 = sparse_tensor.SparseTensor([[0, 3]], ['x'], [100, 100])
    test = TestCompositeTensor(st1.indices, st1.values, st1.dense_shape)

    for expand in (False, True):
        nest.assert_same_structure(st1, st2, expand_composites=expand)

    nest.assert_same_structure(st1, test, expand_composites=False)
    with self.assertRaises(TypeError):
        nest.assert_same_structure(st1, test, expand_composites=True)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:9,代码来源:composite_tensor_test.py
示例20: testMapStructure
def testMapStructure(self):
    # Exercises nest.map_structure on matching structures and on scalars,
    # then checks the errors raised for a non-callable, for mismatched
    # structures (with and without check_types) and for unknown keyword
    # arguments.
    def _add(x, y):
        return x + y

    def _ignore(x, y):
        return None

    structure1 = (((1, 2), 3), 4, (5, 6))
    structure2 = (((7, 8), 9), 10, (11, 12))

    structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1)
    nest.assert_same_structure(structure1, structure1_plus1)
    self.assertAllEqual([2, 3, 4, 5, 6, 7], nest.flatten(structure1_plus1))

    structure1_plus_structure2 = nest.map_structure(
        _add, structure1, structure2)
    self.assertEqual(
        (((8, 10), 12), 14, (16, 18)),
        structure1_plus_structure2)

    # Scalars are treated as single-leaf structures.
    self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
    self.assertEqual(7, nest.map_structure(_add, 3, 4))

    with self.assertRaisesRegexp(TypeError, "callable"):
        nest.map_structure("bad", structure1_plus1)

    mismatched_cases = (
        (ValueError, "same nested structure", 3, (3,)),
        (TypeError, "same sequence type", ((3, 4), 5), [(3, 4), 5]),
        (ValueError, "same nested structure", ((3, 4), 5), (3, (4, 5))),
    )
    for error_type, pattern, left, right in mismatched_cases:
        with self.assertRaisesRegexp(error_type, pattern):
            nest.map_structure(_ignore, left, right)

    # The same shape built from lists only fails while types are checked.
    structure1_list = [[[1, 2], 3], 4, [5, 6]]
    with self.assertRaisesRegexp(TypeError, "same sequence type"):
        nest.map_structure(_ignore, structure1, structure1_list)

    nest.map_structure(_ignore, structure1, structure1_list,
                       check_types=False)

    with self.assertRaisesRegexp(ValueError, "same nested structure"):
        nest.map_structure(_ignore, ((3, 4), 5), (3, (4, 5)),
                           check_types=False)

    # Any keyword other than check_types is rejected.
    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
        nest.map_structure(lambda x: None, structure1, foo="a")

    with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
        nest.map_structure(
            lambda x: None, structure1, check_types=False, foo="a")
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:46,代码来源:nest_test.py
注：本文中的tensorflow.python.util.nest.assert_same_structure函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论