本文整理汇总了 Python 中 tensorflow.python.data.util.nest.flatten 函数的典型用法代码示例。如果您想了解 Python flatten 函数的具体用法、调用方式或使用示例，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了 flatten 函数的 20 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。
示例1: testFlattenAndPack
def testFlattenAndPack(self):
  """Checks `nest.flatten` / `nest.pack_sequence_as` round trips.

  Covers nested tuples, namedtuples, scalar and numpy-array leaves, and the
  error paths of `pack_sequence_as` for invalid or mismatched inputs.
  """
  structure = ((3, 4), 5, (6, 7, (9, 10), 8))
  flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
  self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
  self.assertEqual(
      nest.pack_sequence_as(structure, flat), (("a", "b"), "c",
                                               ("d", "e", ("f", "g"), "h")))
  point = collections.namedtuple("Point", ["x", "y"])
  structure = (point(x=4, y=2), ((point(x=1, y=0),),))
  flat = [4, 2, 1, 0]
  self.assertEqual(nest.flatten(structure), flat)
  restructured_from_flat = nest.pack_sequence_as(structure, flat)
  self.assertEqual(restructured_from_flat, structure)
  self.assertEqual(restructured_from_flat[0].x, 4)
  self.assertEqual(restructured_from_flat[0].y, 2)
  self.assertEqual(restructured_from_flat[1][0][0].x, 1)
  self.assertEqual(restructured_from_flat[1][0][0].y, 0)
  # Ints, numpy arrays and strings are each treated as a single leaf.
  self.assertEqual([5], nest.flatten(5))
  self.assertEqual([np.array([5])], nest.flatten(np.array([5])))
  self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
  self.assertEqual(
      np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))
  # `assertRaisesRegexp` is a deprecated alias that was removed in
  # Python 3.12; `assertRaisesRegex` is the supported spelling.
  with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
    nest.pack_sequence_as("scalar", [4, 5])
  with self.assertRaisesRegex(TypeError, "flat_sequence"):
    nest.pack_sequence_as([4, 5], "bad_sequence")
  with self.assertRaises(ValueError):
    nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
开发者ID:abidrahmank,项目名称:tensorflow,代码行数:33,代码来源:nest_test.py
示例2: _check_shape
def _check_shape(*elements):
  """Applies `with_shape` to every flattened component of `elements`.

  Components are paired positionally with the flattened `expected_shapes`
  (a name from the enclosing scope, not visible here). `zip` stops at the
  shorter list. The checked tensors are repacked into the structure of
  `elements`.
  """
  checked_tensors = []
  for expected_shape, component in zip(nest.flatten(expected_shapes),
                                       nest.flatten(elements)):
    checked_tensors.append(with_shape(expected_shape, component))
  return nest.pack_sequence_as(elements, checked_tensors)
开发者ID:bikong2,项目名称:tensorflow,代码行数:7,代码来源:batching.py
示例3: from_value
def from_value(value):
  """Wraps `value` in an `Optional`.

  Args:
    value: A nested structure of `tf.Tensor` and/or `tf.SparseTensor` objects.

  Returns:
    An `Optional` holding `value`.
  """
  # TODO(b/110122868): Consolidate this destructuring logic with the
  # similar code in `Dataset.from_tensors()`.
  with ops.name_scope("optional") as scope:
    with ops.name_scope("value"):
      # Sparse leaves become `SparseTensor`s; every other leaf is converted
      # to a tensor named after its position in the flattened structure.
      normalized = []
      for index, component in enumerate(nest.flatten(value)):
        if sparse_tensor_lib.is_sparse(component):
          normalized.append(
              sparse_tensor_lib.SparseTensor.from_value(component))
        else:
          normalized.append(
              ops.convert_to_tensor(component, name="component_%d" % index))
      value = nest.pack_sequence_as(value, normalized)
    encoded_value = nest.flatten(sparse.serialize_sparse_tensors(value))
    output_classes = sparse.get_classes(value)
    flat_value = nest.flatten(value)
    output_shapes = nest.pack_sequence_as(
        value, [component.get_shape() for component in flat_value])
    output_types = nest.pack_sequence_as(
        value, [component.dtype for component in flat_value])
    return _OptionalImpl(
        gen_dataset_ops.optional_from_value(encoded_value, name=scope),
        output_shapes, output_types, output_classes)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:30,代码来源:optional_ops.py
示例4: assertShapesEqual
def assertShapesEqual(self, a, b):
  """Asserts that two nested structures of shapes are equal component-wise.

  Args:
    a: A nested structure of shape objects exposing `ndims` and `as_list()`
      (presumably `tf.TensorShape`).
    b: A nested structure of the same kind, compared against `a`.
  """
  flat_a = nest.flatten(a)
  flat_b = nest.flatten(b)
  # `zip` silently truncates; without this check, extra components in either
  # structure would let the assertion pass.
  self.assertEqual(len(flat_a), len(flat_b))
  for shape_a, shape_b in zip(flat_a, flat_b):
    self.assertEqual(shape_a.ndims, shape_b.ndims)
    if shape_a.ndims is None:
      # Both shapes have unknown rank, so there are no dimensions to compare.
      continue
    # Ranks are equal, so comparing the dimension lists is equivalent to the
    # per-dimension comparison.
    self.assertEqual(shape_a.as_list(), shape_b.as_list())
开发者ID:abidrahmank,项目名称:tensorflow,代码行数:7,代码来源:sparse_test.py
示例5: _apply_fn
def _apply_fn(dataset):
  """Maps `dataset` to a dataset that keeps only batches of `batch_size`.

  `batch_size` comes from the enclosing scope (not visible here). Elements
  whose leading dimension differs from it are filtered out, and the static
  leading dimension of every output shape is set from its constant value.
  """
  tensor_batch_size = ops.convert_to_tensor(
      batch_size, dtype=dtypes.int64, name="batch_size")
  # View the dataset as a flat tuple so the predicate receives positional
  # components.
  flat_types = tuple(nest.flatten(dataset.output_types))
  flat_classes = tuple(nest.flatten(dataset.output_classes))
  flattened = _RestructuredDataset(
      dataset, flat_types, output_classes=flat_classes)

  def _predicate(*xs):
    """Return `True` if this element is a full batch."""
    # The leading dimension of the first flattened component is used as the
    # dynamic batch size of the whole element.
    leading_dim = array_ops.shape(xs[0], out_type=dtypes.int64)[0]
    return math_ops.equal(leading_dim, tensor_batch_size)

  filtered = flattened.filter(_predicate)
  static_batch_size = tensor_util.constant_value(tensor_batch_size)
  batch_dim = tensor_shape.vector(static_batch_size)
  known_shapes = nest.map_structure(
      lambda shape: shape.merge_with(batch_dim.concatenate(shape[1:])),
      dataset.output_shapes)
  return _RestructuredDataset(
      filtered,
      dataset.output_types,
      known_shapes,
      output_classes=dataset.output_classes)
开发者ID:Kongsea,项目名称:tensorflow,代码行数:35,代码来源:batching.py
示例6: __init__
def __init__(self, dataset):
  """Builds an eager iterator over `dataset`.

  For example:
  ```python
  dataset = tf.contrib.data.Dataset.range(4)
  for x in Iterator(dataset):
    print(x)
  ```

  Args:
    dataset: A `tf.contrib.data.Dataset` object.

  Raises:
    RuntimeError: If eager execution is not enabled.
  """
  eager = context.in_eager_mode()
  if not eager:
    raise RuntimeError(
        "{} objects only make sense when eager execution is enabled".format(
            type(self)))
  variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
  types, shapes = dataset.output_types, dataset.output_shapes
  flat_types = nest.flatten(types)
  flat_shapes = nest.flatten(shapes)
  self._output_types = types
  self._flat_output_types = flat_types
  self._flat_output_shapes = flat_shapes
  resource = gen_dataset_ops.iterator(
      container="",
      shared_name=_iterator_shared_name(),
      output_types=flat_types,
      output_shapes=flat_shapes)
  self._resource = resource
  # Bind the dataset to the freshly created iterator resource.
  gen_dataset_ops.make_iterator(variant, resource)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:31,代码来源:datasets.py
示例7: materialize
def materialize(self, shared_name=None, container=None):
  """Materialize creates a MaterializedIndexedDataset.

  IndexedDatasets can be combined through operations such as TBD. Therefore,
  they are only materialized when absolutely required.

  Args:
    shared_name: a string for the shared name to use for the resource.
      `None` is treated as "".
    container: a string for the container to store the resource.
      `None` is treated as "".

  Returns:
    A MaterializedIndexedDataset.
  """
  if container is None:
    container = ""
  if shared_name is None:
    shared_name = ""
  # Shapes must be densified with `sparse.as_dense_shapes`. The previous
  # `as_dense_types(self.output_shapes, ...)` applied the *types* helper to
  # shapes, which does not yield a shape for sparse components.
  materialized_resource = (
      ged_ops.experimental_materialized_index_dataset_handle(
          container=container,
          shared_name=shared_name,
          output_types=nest.flatten(
              sparse.as_dense_types(self.output_types, self.output_classes)),
          output_shapes=nest.flatten(
              sparse.as_dense_shapes(self.output_shapes,
                                     self.output_classes))))
  # Place the materialize op with the resource it fills.
  with ops.colocate_with(materialized_resource):
    materializer = ged_ops.experimental_indexed_dataset_materialize(
        self._as_variant_tensor(), materialized_resource)
  return MaterializedIndexedDataset(materialized_resource, materializer,
                                    self.output_classes, self.output_types,
                                    self.output_shapes)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:33,代码来源:indexed_dataset_ops.py
示例8: make_initializer
def make_initializer(self, dataset, name=None):
  """Creates a `tf.Operation` that (re)initializes this iterator on `dataset`.

  Args:
    dataset: A `Dataset` whose element structure is compatible with this
      iterator.
    name: (Optional.) A name for the created operation.

  Returns:
    A `tf.Operation` which, when run, initializes this iterator on `dataset`.

  Raises:
    TypeError: If the element structure of `dataset` is incompatible with
      this iterator's.
  """
  with ops.name_scope(name, "make_initializer") as name:
    nest.assert_same_structure(self._output_types, dataset.output_types)
    nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
    # Every flattened dtype must match exactly.
    dtype_pairs = zip(nest.flatten(self._output_types),
                      nest.flatten(dataset.output_types))
    if any(iterator_dtype != dataset_dtype
           for iterator_dtype, dataset_dtype in dtype_pairs):
      raise TypeError(
          "Expected output types %r but got dataset with output types %r." %
          (self._output_types, dataset.output_types))
    # Shapes only need to pass `is_compatible_with`, not be identical.
    shape_pairs = zip(nest.flatten(self._output_shapes),
                      nest.flatten(dataset.output_shapes))
    if not all(iterator_shape.is_compatible_with(dataset_shape)
               for iterator_shape, dataset_shape in shape_pairs):
      raise TypeError("Expected output shapes compatible with %r but got "
                      "dataset with output shapes %r." %
                      (self._output_shapes, dataset.output_shapes))
    with ops.colocate_with(self._iterator_resource):
      variant = dataset._as_variant_tensor()  # pylint: disable=protected-access
      return gen_dataset_ops.make_iterator(
          variant, self._iterator_resource, name=name)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:34,代码来源:iterator.py
示例9: testSerializeDeserialize
def testSerializeDeserialize(self):
  """Round-trips nested `SparseTensor` structures through (de)serialization."""
  test_cases = (
      (),
      sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),
      sparse_tensor.SparseTensor(
          indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
      sparse_tensor.SparseTensor(
          indices=[[0, 0], [3, 4]], values=[1, -1], dense_shape=[4, 5]),
      # The trailing comma makes this a 1-tuple. Without it the parentheses
      # are only grouping, and the case merely repeated the bare 1x1 tensor.
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]),),
      (sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1]), ()),
      ((), sparse_tensor.SparseTensor(
          indices=[[0, 0]], values=[1], dense_shape=[1, 1])),
  )
  for expected in test_cases:
    classes = sparse.get_classes(expected)
    # Deserialization is given fully unknown shapes and int32 dtypes.
    shapes = nest.map_structure(lambda _: tensor_shape.TensorShape(None),
                                classes)
    types = nest.map_structure(lambda _: dtypes.int32, classes)
    actual = sparse.deserialize_sparse_tensors(
        sparse.serialize_sparse_tensors(expected), types, shapes, classes)
    nest.assert_same_structure(expected, actual)
    for a, e in zip(nest.flatten(actual), nest.flatten(expected)):
      self.assertSparseValuesEqual(a, e)
开发者ID:abidrahmank,项目名称:tensorflow,代码行数:27,代码来源:sparse_test.py
示例10: get_next_as_optional
def get_next_as_optional(iterator):
  """Wraps the next element of `iterator` in an `Optional`.

  The returned `Optional` holds no value once `iterator` is exhausted.

  Args:
    iterator: A `tf.data.Iterator` object.

  Returns:
    An `Optional` that holds the iterator's next element when one exists, and
    no value otherwise.
  """
  # pylint: disable=protected-access
  output_types = iterator.output_types
  output_shapes = iterator.output_shapes
  output_classes = iterator.output_classes
  # Sparse components are described to the op through their dense
  # (`as_dense_types` / `as_dense_shapes`) counterparts.
  optional_variant = gen_dataset_ops.iterator_get_next_as_optional(
      iterator._iterator_resource,
      output_types=nest.flatten(
          sparse.as_dense_types(output_types, output_classes)),
      output_shapes=nest.flatten(
          sparse.as_dense_shapes(output_shapes, output_classes)))
  element_structure = structure.Structure._from_legacy_structure(
      output_types, output_shapes, output_classes)
  return optional_ops._OptionalImpl(optional_variant, element_structure)
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:26,代码来源:iterator_ops.py
示例11: get_next
def get_next(self, name=None):
  """Fetches the next element as a nested structure of tensors.

  Args:
    name: (Optional.) A name for the created operation.

  Returns:
    A nested structure of `tf.Tensor` objects.
  """
  # Warn on every call once the call count exceeds the threshold.
  self._get_next_call_count += 1
  if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD:
    warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE)
  dense_types = sparse.as_dense_types(self._output_types,
                                      self._output_classes)
  dense_shapes = sparse.as_dense_shapes(self._output_shapes,
                                        self._output_classes)
  flat_components = gen_dataset_ops.iterator_get_next(
      self._iterator_resource,
      output_types=nest.flatten(dense_types),
      output_shapes=nest.flatten(dense_shapes),
      name=name)
  structured = nest.pack_sequence_as(self._output_types, flat_components)
  # Restore `SparseTensor` components from their dense encoding.
  return sparse.deserialize_sparse_tensors(
      structured, self._output_types, self._output_shapes,
      self._output_classes)
开发者ID:modkzs,项目名称:tensorflow,代码行数:27,代码来源:iterator_ops.py
示例12: testRoundTripConversion
def testRoundTripConversion(self, value_fn):
  """Checks `_to_tensor_list` then `_from_tensor_list` reproduces the value."""
  value = value_fn()
  value_structure = structure.Structure.from_value(value)

  def _stack_if_tensor_array(v):
    # TensorArrays are compared through their stacked contents.
    return v.stack() if isinstance(v, tensor_array_ops.TensorArray) else v

  before = self.evaluate(_stack_if_tensor_array(value))
  round_tripped = value_structure._from_tensor_list(
      value_structure._to_tensor_list(value))
  after = self.evaluate(_stack_if_tensor_array(round_tripped))
  ragged_types = (ragged_tensor.RaggedTensor,
                  ragged_tensor_value.RaggedTensorValue)
  for b, a in zip(nest.flatten(before), nest.flatten(after)):
    if isinstance(b, sparse_tensor.SparseTensorValue):
      for expected_part, actual_part in ((b.indices, a.indices),
                                         (b.values, a.values),
                                         (b.dense_shape, a.dense_shape)):
        self.assertAllEqual(expected_part, actual_part)
    elif isinstance(b, ragged_types):
      self.assertRaggedEqual(b, a)
    else:
      self.assertAllEqual(b, a)
开发者ID:aritratony,项目名称:tensorflow,代码行数:27,代码来源:structure_test.py
示例13: get_next
def get_next(self, name=None):
  """See `tf.data.Iterator.get_next`.

  Returns one element per buffering resource, packed into the structure of
  `self._devices`.
  """
  self._get_next_call_count += 1
  if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
    warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)

  def _next_from(buffer_resource):
    # Dequeue one flat element from a single buffer, restore its nested
    # (possibly sparse) structure, and re-attach static dense shapes.
    flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
        buffer_resource,
        output_types=data_nest.flatten(sparse.as_dense_types(
            self.output_types, self.output_classes)), name=name)
    ret = sparse.deserialize_sparse_tensors(
        data_nest.pack_sequence_as(self.output_types, flat_ret),
        self.output_types, self.output_shapes, self.output_classes)
    for tensor, shape in zip(
        data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
      if isinstance(tensor, ops.Tensor):
        tensor.set_shape(shape)
    return ret

  # TODO(priyag): This will fail if the input size (typically number of
  # batches) is not divisible by number of devices.
  # How do we handle that more gracefully / let the user know?
  flat_result = [_next_from(buffer_resource)
                 for buffer_resource in self._buffering_resources]
  return nest.pack_sequence_as(self._devices, flat_result)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:27,代码来源:prefetching_ops_v2.py
示例14: assertDatasetsEqual
def assertDatasetsEqual(self, dataset1, dataset2):
  """Checks that datasets are equal. Supports both graph and eager mode.

  The two element structures must be mutually compatible. Both datasets must
  yield the same number of elements, and the flattened components must
  match. String components are compared exactly, sparse components via
  `assertSparseValuesEqual`, and all others with `assertAllClose`.
  """
  self.assertTrue(dataset_ops.get_structure(dataset1).is_compatible_with(
      dataset_ops.get_structure(dataset2)))
  self.assertTrue(dataset_ops.get_structure(dataset2).is_compatible_with(
      dataset_ops.get_structure(dataset1)))
  flattened_types = nest.flatten(
      dataset_ops.get_legacy_output_types(dataset1))
  next1 = self.getNext(dataset1)
  next2 = self.getNext(dataset2)
  while True:
    try:
      op1 = self.evaluate(next1())
    except errors.OutOfRangeError:
      # dataset1 is exhausted; dataset2 must be exhausted at the same point.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next2())
      break
    op2 = self.evaluate(next2())
    op1 = nest.flatten(op1)
    op2 = nest.flatten(op2)
    # A bare `assert` is stripped under `python -O` and reports no values on
    # failure; use the TestCase assertion instead.
    self.assertEqual(len(op1), len(op2))
    for i, (component1, component2) in enumerate(zip(op1, op2)):
      if sparse_tensor.is_sparse(component1):
        self.assertSparseValuesEqual(component1, component2)
      elif flattened_types[i] == dtypes.string:
        self.assertAllEqual(component1, component2)
      else:
        self.assertAllClose(component1, component2)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:31,代码来源:test_base.py
示例15: assertDatasetsEqual
def assertDatasetsEqual(self, dataset1, dataset2):
  """Checks that datasets are equal. Supports both graph and eager mode.

  Output types and classes must be equal. Both datasets must yield the same
  number of elements, and the flattened components must match exactly:
  sparse components via `assertSparseValuesEqual`, all others via
  `assertAllEqual`.
  """
  self.assertEqual(dataset1.output_types, dataset2.output_types)
  self.assertEqual(dataset1.output_classes, dataset2.output_classes)
  next1 = self.getNext(dataset1)
  next2 = self.getNext(dataset2)
  sparse_types = (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)
  while True:
    try:
      op1 = self.evaluate(next1())
    except errors.OutOfRangeError:
      # dataset1 is exhausted; dataset2 must be exhausted at the same point.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next2())
      break
    op2 = self.evaluate(next2())
    op1 = nest.flatten(op1)
    op2 = nest.flatten(op2)
    # A bare `assert` is stripped under `python -O` and reports no values on
    # failure; use the TestCase assertion instead.
    self.assertEqual(len(op1), len(op2))
    for component1, component2 in zip(op1, op2):
      if isinstance(component1, sparse_types):
        self.assertSparseValuesEqual(component1, component2)
      else:
        self.assertAllEqual(component1, component2)
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:26,代码来源:test_base.py
示例16: testToBatchedTensorList
def testToBatchedTensorList(self, value_fn, element_0_fn):
  """Checks the batched tensor list and that unbatching element 0 recovers it."""
  batched_value = value_fn()
  batched_structure = structure.Structure.from_value(batched_value)
  tensor_list = batched_structure._to_batched_tensor_list(batched_value)
  # The batch dimension is 2 for all of the test cases.
  # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
  # tensors in which we store sparse tensors.
  for t in tensor_list:
    if t.dtype == dtypes.variant:
      continue
    self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))
  # Slicing index 0 out of every batched tensor and decoding it with the
  # unbatched structure must reproduce `element_0_fn()`.
  expected_element_0 = self.evaluate(element_0_fn())
  element_structure = batched_structure._unbatch()
  first_slices = [t[0] for t in tensor_list]
  actual_element_0 = element_structure._from_tensor_list(first_slices)
  pairs = zip(nest.flatten(expected_element_0), nest.flatten(actual_element_0))
  for expected, actual in pairs:
    if sparse_tensor.is_sparse(expected):
      self.assertSparseValuesEqual(expected, actual)
    elif ragged_tensor.is_ragged(expected):
      self.assertRaggedEqual(expected, actual)
    else:
      self.assertAllEqual(expected, actual)
开发者ID:aritratony,项目名称:tensorflow,代码行数:27,代码来源:structure_test.py
示例17: _as_variant_tensor
def _as_variant_tensor(self):
  """Returns an `ignore_errors_dataset` variant wrapping the input dataset."""
  input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  # The op receives flat, densified shapes and types.
  dense_shapes = sparse.as_dense_shapes(self.output_shapes, self.output_classes)
  dense_types = sparse.as_dense_types(self.output_types, self.output_classes)
  return gen_dataset_ops.ignore_errors_dataset(
      input_variant,
      output_shapes=nest.flatten(dense_shapes),
      output_types=nest.flatten(dense_types))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:7,代码来源:error_ops.py
示例18: testIndefiniteRepeatShapeInference
def testIndefiniteRepeatShapeInference(self):
  """With `num_epochs=None`, dense components keep a leading dim of 32."""
  dataset = self.make_batch_feature(
      filenames=self.test_filenames[0], num_epochs=None, batch_size=32)
  flat_shapes = nest.flatten(dataset.output_shapes)
  flat_classes = nest.flatten(dataset.output_classes)
  # Only `ops.Tensor` components are checked; other classes are skipped.
  dense_shapes = [shape for shape, clazz in zip(flat_shapes, flat_classes)
                  if issubclass(clazz, ops.Tensor)]
  for shape in dense_shapes:
    self.assertEqual(32, shape[0])
开发者ID:mrlittlepig,项目名称:tensorflow,代码行数:7,代码来源:reader_dataset_ops_test.py
示例19: _as_variant_tensor
def _as_variant_tensor(self):
  """Applies `self._op_function` to the input dataset's variant and tag."""
  input_variant = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
  flat_types = nest.flatten(
      sparse.as_dense_types(self.output_types, self.output_classes))
  flat_shapes = nest.flatten(
      sparse.as_dense_shapes(self.output_shapes, self.output_classes))
  return self._op_function(
      input_variant,
      self._tag,
      output_types=flat_types,
      output_shapes=flat_shapes)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:8,代码来源:stats_ops.py
示例20: output_shapes
def output_shapes(self):
  """Returns the component-wise most specific shapes compatible with all inputs.

  Starts from the first entry of `self._data_inputs` and folds in each later
  input's `output_shapes` with `most_specific_compatible_shape`.
  """
  merged = self._data_inputs[0].output_shapes
  for other_input in self._data_inputs[1:]:
    merged_flat = nest.flatten(merged)
    other_flat = nest.flatten(other_input.output_shapes)
    combined = []
    for merged_shape, other_shape in zip(merged_flat, other_flat):
      combined.append(merged_shape.most_specific_compatible_shape(other_shape))
    merged = nest.pack_sequence_as(merged, combined)
  return merged
开发者ID:AnishShah,项目名称:tensorflow,代码行数:8,代码来源:interleave_ops.py
注：本文中的 tensorflow.python.data.util.nest.flatten 函数示例由纯净天空整理自 GitHub/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有；传播和使用请参考对应项目的 License；未经允许，请勿转载。
请发表评论