本文整理汇总了Python中tensorflow.python.data.util.nest.map_structure函数的典型用法代码示例。如果您正苦于以下问题：Python map_structure函数的具体用法？Python map_structure怎么用？Python map_structure使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了map_structure函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testSerializeDeserialize
def testSerializeDeserialize(self):
  """Round-trips sparse structures through serialization.

  Every case is passed to `sparse.serialize_sparse_tensors`, restored with
  `sparse.deserialize_sparse_tensors`, and then compared with the input, first
  structurally and then component by component.
  """

  def _sparse(indices, values, dense_shape):
    # Shorthand for building the `SparseTensor` fixtures below.
    return sparse_tensor.SparseTensor(
        indices=indices, values=values, dense_shape=dense_shape)

  test_cases = (
      (),
      _sparse([[0, 0]], [1], [1, 1]),
      _sparse([[3, 4]], [-1], [4, 5]),
      _sparse([[0, 0], [3, 4]], [1, -1], [4, 5]),
      # NOTE(review): the extra parentheses do not create a 1-tuple, so this
      # case is a bare `SparseTensor` again -- confirm that this is intended.
      (_sparse([[0, 0]], [1], [1, 1])),
      (_sparse([[0, 0]], [1], [1, 1]), ()),
      ((), _sparse([[0, 0]], [1], [1, 1])),
  )

  for expected in test_cases:
    classes = sparse.get_classes(expected)
    # Every component is described as int32 with a fully unknown shape.
    types = nest.map_structure(lambda _: dtypes.int32, classes)
    shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), classes)

    serialized = sparse.serialize_sparse_tensors(expected)
    actual = sparse.deserialize_sparse_tensors(
        serialized, types, shapes, sparse.get_classes(expected))

    nest.assert_same_structure(expected, actual)
    flat_pairs = zip(nest.flatten(actual), nest.flatten(expected))
    for actual_component, expected_component in flat_pairs:
      self.assertSparseValuesEqual(actual_component, expected_component)
开发者ID:abidrahmank,项目名称:tensorflow,代码行数:27,代码来源:sparse_test.py
示例2: from_string_handle
def from_string_handle(string_handle, output_types, output_shapes=None):
  """Creates a new, uninitialized `Iterator` based on the given handle.

  This lets you define a "feedable" iterator: you choose between concrete
  iterators by feeding a value in a `tf.Session.run` call. In that case,
  `string_handle` would be a `tf.placeholder`, and you would feed it with the
  value of `tf.data.Iterator.string_handle` in each step.

  For example, if you had two iterators that marked the current position in
  a training dataset and a test dataset, you could choose which to use in
  each step as follows:

  ```python
  train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  train_iterator_handle = sess.run(train_iterator.string_handle())

  test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  test_iterator_handle = sess.run(test_iterator.string_handle())

  handle = tf.placeholder(tf.string, shape=[])
  iterator = tf.data.Iterator.from_string_handle(
      handle, train_iterator.output_types)

  next_element = iterator.get_next()
  loss = f(next_element)

  train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
  test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
  ```

  Args:
    string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
      to a handle produced by the `Iterator.string_handle()` method.
    output_types: A nested structure of `tf.DType` (or `tf.data.SparseType`)
      objects corresponding to each `tf.Tensor` (or `tf.SparseTensor`)
      component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
      corresponding to each component of an element of this dataset. If
      omitted, each component will have an unconstrained shape.

  Returns:
    An `Iterator`.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)

  if output_shapes is not None:
    # Normalize user-provided shapes, using `output_types` as the template.
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
  else:
    # No shapes given: every component gets a fully unknown shape.
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  nest.assert_same_structure(output_types, output_shapes)

  handle_tensor = ops.convert_to_tensor(string_handle, dtype=dtypes.string)
  flat_types = nest.flatten(sparse.unwrap_sparse_types(output_types))
  flat_shapes = nest.flatten(output_shapes)
  resource = gen_dataset_ops.iterator_from_string_handle(
      handle_tensor, output_types=flat_types, output_shapes=flat_shapes)
  return Iterator(resource, None, output_types, output_shapes)
开发者ID:SylChan,项目名称:tensorflow,代码行数:57,代码来源:iterator_ops.py
示例3: _apply_fn
def _apply_fn(dataset):
  """Function from `Dataset` to `Dataset` that applies the transformation.

  The dataset is flattened, elements whose first component does not have
  exactly `batch_size` rows are filtered out, and the original structure is
  then restored with the leading dimension of every shape merged with the
  (possibly statically known) batch size.

  NOTE(review): `batch_size` is a free variable here; presumably it is bound
  by an enclosing function that is not shown -- confirm.

  Args:
    dataset: The input `Dataset`.

  Returns:
    A `_RestructuredDataset` with the same types and classes as `dataset`.
  """
  tensor_batch_size = ops.convert_to_tensor(
      batch_size, dtype=dtypes.int64, name="batch_size")

  flat_types = tuple(nest.flatten(dataset.output_types))
  flat_classes = tuple(nest.flatten(dataset.output_classes))
  flattened = _RestructuredDataset(
      dataset, flat_types, output_classes=flat_classes)

  def _predicate(*xs):
    """Return `True` if this element is a full batch."""
    # The dynamic batch size is read from the first flattened component.
    leading_shape = array_ops.shape(xs[0], out_type=dtypes.int64)
    return math_ops.equal(leading_shape[0], tensor_batch_size)

  filtered = flattened.filter(_predicate)

  # `None` when the batch size cannot be determined statically.
  static_batch_size = tensor_util.constant_value(tensor_batch_size)

  def _with_batch_dimension(shape):
    batched = tensor_shape.vector(static_batch_size).concatenate(shape[1:])
    return shape.merge_with(batched)

  known_shapes = nest.map_structure(_with_batch_dimension,
                                    dataset.output_shapes)
  return _RestructuredDataset(
      filtered,
      dataset.output_types,
      known_shapes,
      output_classes=dataset.output_classes)
开发者ID:Kongsea,项目名称:tensorflow,代码行数:35,代码来源:batching.py
示例4: __init__
def __init__(self, input_dataset, num_workers):
  """Creates a dataset whose leading dimension is divided by `num_workers`.

  Args:
    input_dataset: The dataset to rebatch.
    num_workers: The number by which the first dimension of every component
      shape is divided.

  Raises:
    ValueError: If a component shape has no dimensions.
    errors.InvalidArgumentError: If a known, non-zero first dimension is not
      divisible by `num_workers`.
  """
  self._input_dataset = input_dataset

  def recalculate_output_shapes(output_shapes):
    """Recalculates the output_shapes after dividing it by num_workers."""
    if len(output_shapes) < 1:
      raise ValueError("Input shape should have at least one dimension.")

    first_dim = tensor_shape.dimension_value(output_shapes[0])
    # An unknown (or zero) first dimension skips the divisibility check.
    if first_dim and first_dim % num_workers != 0:
      raise errors.InvalidArgumentError(
          None, None,
          "First dim of input shape: %d is not divisible by num_workers: %d" %
          (output_shapes[0], num_workers))

    new_dims = list(output_shapes.dims)
    new_dims[0] = new_dims[0] // num_workers
    return tensor_shape.TensorShape(new_dims)

  input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
  input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
  input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)

  rebatched_shapes = nest.map_structure(recalculate_output_shapes,
                                        input_shapes)
  self._structure = structure.convert_legacy_structure(
      input_types, rebatched_shapes, input_classes)

  variant_tensor = ged_ops.experimental_rebatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      num_workers=num_workers,
      **dataset_ops.flat_structure(self))
  super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
开发者ID:aritratony,项目名称:tensorflow,代码行数:29,代码来源:distribute.py
示例5: __init__
def __init__(self, dataset, output_types, output_shapes=None):
  """Creates a new dataset with the given output types and shapes.

  The given `dataset` must have a structure that is convertible:

  * `dataset.output_types` must be the same as `output_types` modulo nesting.
  * Each shape in `dataset.output_shapes` must be compatible with each shape
    in `output_shapes` (if given).

  Note: This helper permits "unsafe casts" for shapes, equivalent to using
  `tf.Tensor.set_shape()` where domain-specific knowledge is available.

  Args:
    dataset: A `Dataset` object.
    output_types: A nested structure of `tf.DType` objects.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
      If omitted, the shapes will be inherited from `dataset`.

  Raises:
    ValueError: If either `output_types` or `output_shapes` is not compatible
      with the structure of `dataset`.
  """
  super(_RestructuredDataset, self).__init__()
  self._dataset = dataset

  # The flattened type lists must match exactly; only the nesting may differ.
  output_types = nest.map_structure(dtypes.as_dtype, output_types)
  if nest.flatten(dataset.output_types) != nest.flatten(output_types):
    raise ValueError(
        "Dataset with output types %r cannot be restructured to have output "
        "types %r" % (dataset.output_types, output_types))
  self._output_types = output_types

  if output_shapes is None:
    # Inherit shapes from the original `dataset`, re-nested like the types.
    inherited_shapes = nest.flatten(dataset.output_shapes)
    self._output_shapes = nest.pack_sequence_as(output_types,
                                                inherited_shapes)
    return

  # Validate that the requested shapes are compatible with the original ones.
  nest.assert_same_structure(output_types, output_shapes)
  shape_pairs = zip(nest.flatten(dataset.output_shapes),
                    nest.flatten_up_to(output_types, output_shapes))
  if not all(old.is_compatible_with(new) for old, new in shape_pairs):
    raise ValueError(
        "Dataset with output shapes %r cannot be restructured to have "
        "incompatible output shapes %r" % (dataset.output_shapes,
                                           output_shapes))
  self._output_shapes = nest.map_structure_up_to(
      output_types, tensor_shape.as_shape, output_shapes)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:55,代码来源:batching.py
示例6: __init__
def __init__(self, variant_tensor, output_shapes, output_types,
             output_classes):
  """Stores a variant tensor together with its validated structure.

  Args:
    variant_tensor: The value stored as `self._variant_tensor`.
    output_shapes: A nested structure convertible to `tf.TensorShape` objects.
    output_types: A nested structure convertible to `tf.DType` objects.
    output_classes: A nested structure with the same nesting as
      `output_types`.
  """
  # TODO(b/110122868): Consolidate the structure validation logic with the
  # similar logic in `Iterator.from_structure()` and
  # `Dataset.from_generator()`.
  normalized_types = nest.map_structure(dtypes.as_dtype, output_types)
  normalized_shapes = nest.map_structure_up_to(
      normalized_types, tensor_shape.as_shape, output_shapes)

  # Shapes and classes must both mirror the nesting of the types.
  nest.assert_same_structure(normalized_types, normalized_shapes)
  nest.assert_same_structure(normalized_types, output_classes)

  self._variant_tensor = variant_tensor
  self._output_types = normalized_types
  self._output_shapes = normalized_shapes
  self._output_classes = output_classes
开发者ID:AnishShah,项目名称:tensorflow,代码行数:14,代码来源:optional_ops.py
示例7: __init__
def __init__(self, driver_name, data_source_name, query, output_types):
  """Creates a `SqlDataset`.

  `SqlDataset` allows a user to read data from the result set of a SQL query.
  For example:

  ```python
  tf.enable_eager_execution()

  dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3",
                                            "SELECT name, age FROM people",
                                            (tf.string, tf.int32))
  # Prints the rows of the result set of the above query.
  for element in dataset:
    print(element)
  ```

  Args:
    driver_name: A 0-D `tf.string` tensor containing the database type.
      Currently, the only supported value is 'sqlite'.
    data_source_name: A 0-D `tf.string` tensor containing a connection string
      to connect to the database.
    query: A 0-D `tf.string` tensor containing the SQL query to execute.
    output_types: A tuple of `tf.DType` objects representing the types of the
      columns returned by `query`.
  """

  def _as_string_tensor(value, name):
    return ops.convert_to_tensor(value, dtype=dtypes.string, name=name)

  def _scalar_structure(dtype):
    # Each column is described as a tensor of the given dtype with shape [].
    return structure.TensorStructure(dtype, [])

  self._driver_name = _as_string_tensor(driver_name, "driver_name")
  self._data_source_name = _as_string_tensor(data_source_name,
                                             "data_source_name")
  self._query = _as_string_tensor(query, "query")

  column_structures = nest.map_structure(_scalar_structure, output_types)
  self._structure = structure.NestedStructure(column_structures)

  variant_tensor = gen_experimental_dataset_ops.experimental_sql_dataset(
      self._driver_name, self._data_source_name, self._query,
      **dataset_ops.flat_structure(self))
  super(SqlDatasetV2, self).__init__(variant_tensor)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:39,代码来源:readers.py
示例8: __init__
def __init__(self, input_dataset):
  """See `unbatch()` for more details.

  Args:
    input_dataset: The dataset whose elements are to be unbatched.

  Raises:
    ValueError: If any component is a scalar, or if the leading dimensions
      of the components cannot be merged into one batch size.
  """
  flat_shapes = nest.flatten(input_dataset.output_shapes)

  # First pass: a scalar component has no batch dimension to strip.
  for shape in flat_shapes:
    if shape.ndims == 0:
      raise ValueError("Cannot unbatch an input with scalar components.")

  # Second pass: all leading dimensions must agree with one another.
  merged_batch_dim = tensor_shape.Dimension(None)
  for shape in flat_shapes:
    try:
      merged_batch_dim = merged_batch_dim.merge_with(shape[0])
    except ValueError:
      raise ValueError("Cannot unbatch an input whose components have "
                       "different batch sizes.")

  def _drop_leading_dimension(shape):
    return shape[1:]

  self._input_dataset = input_dataset
  self._structure = structure.convert_legacy_structure(
      input_dataset.output_types,
      nest.map_structure(_drop_leading_dimension,
                         input_dataset.output_shapes),
      input_dataset.output_classes)

  variant_tensor = ged_ops.experimental_unbatch_dataset(
      self._input_dataset._variant_tensor,  # pylint: disable=protected-access
      **dataset_ops.flat_structure(self))
  super(_UnbatchDataset, self).__init__(input_dataset, variant_tensor)
开发者ID:ziky90,项目名称:tensorflow,代码行数:23,代码来源:batching.py
示例9: __init__
def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
  """Initialize `PrependFromQueueAndPaddedBatchDataset`.

  Args:
    input_dataset: The dataset to batch; sparse components are rejected.
    batch_size: Value converted to a `tf.int64` tensor named "batch_size".
    padded_shapes: Nested structure of shapes to pad to, or `None` to derive
      them from `input_dataset.output_shapes`.
    padding_values: Nested structure of padding values, or `None` to use
      `dataset_ops._default_padding(input_dataset)`.

  Raises:
    TypeError: If `input_dataset` has any sparse component.
  """
  super(_PrependFromQueueAndPaddedBatchDataset, self).__init__()
  if sparse.any_sparse(input_dataset.output_classes):
    raise TypeError(
        "Batching of padded sparse tensors is not currently supported")

  self._input_dataset = input_dataset
  self._batch_size = ops.convert_to_tensor(
      batch_size, dtype=dtypes.int64, name="batch_size")

  # pylint: disable=protected-access
  if padded_shapes is not None:
    self._padded_shapes = nest.map_structure_up_to(
        input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor,
        padded_shapes)
  else:
    self._padded_shapes = nest.map_structure(
        dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes)

  if padding_values is None:
    padding_values = dataset_ops._default_padding(input_dataset)
  self._padding_values = nest.map_structure_up_to(
      input_dataset.output_shapes, dataset_ops._padding_value_to_tensor,
      padding_values, input_dataset.output_types)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:23,代码来源:tensor_queue_dataset.py
示例10: output_shapes
def output_shapes(self):
  """Returns a scalar `TensorShape` for each entry of `self._output_types`."""

  def _scalar_shape(unused_dtype):
    return tensor_shape.TensorShape([])

  return nest.map_structure(_scalar_shape, self._output_types)
开发者ID:jfreedman0,项目名称:tensorflow,代码行数:3,代码来源:readers.py
示例11: output_classes
def output_classes(self):
  """Returns `ops.Tensor` for each entry of `self._output_types`."""

  def _tensor_class(unused_dtype):
    return ops.Tensor

  return nest.map_structure(_tensor_class, self._output_types)
开发者ID:jfreedman0,项目名称:tensorflow,代码行数:2,代码来源:readers.py
示例12: testMapStructure
def testMapStructure(self):
  """Exercises `nest.map_structure` on valid and invalid inputs."""

  def _increment(x):
    return x + 1

  def _add(x, y):
    return x + y

  def _ignore_one(x):
    return None

  def _ignore_two(x, y):
    return None

  structure1 = (((1, 2), 3), 4, (5, 6))
  structure2 = (((7, 8), 9), 10, (11, 12))

  # Mapping over a single structure keeps its nesting.
  structure1_plus1 = nest.map_structure(_increment, structure1)
  nest.assert_same_structure(structure1, structure1_plus1)
  self.assertAllEqual(
      [2, 3, 4, 5, 6, 7],
      nest.flatten(structure1_plus1))

  # Mapping over two structures combines them element-wise.
  structure1_plus_structure2 = nest.map_structure(
      _add, structure1, structure2)
  self.assertEqual(
      (((8, 10), 12), 14, (16, 18)),
      structure1_plus_structure2)

  # Non-nested inputs are treated as single leaves.
  self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4))
  self.assertEqual(7, nest.map_structure(_add, 3, 4))

  # The first argument must be callable.
  with self.assertRaisesRegexp(TypeError, "callable"):
    nest.map_structure("bad", structure1_plus1)

  # Structures with mismatched nesting or sequence types are rejected.
  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(_ignore_two, 3, (3,))

  with self.assertRaisesRegexp(TypeError, "same sequence type"):
    nest.map_structure(_ignore_two, ((3, 4), 5), {"a": (3, 4), "b": 5})

  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(_ignore_two, ((3, 4), 5), (3, (4, 5)))

  with self.assertRaisesRegexp(ValueError, "same nested structure"):
    nest.map_structure(_ignore_two, ((3, 4), 5), (3, (4, 5)),
                       check_types=False)

  # Unknown keyword arguments are rejected.
  with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    nest.map_structure(_ignore_one, structure1, foo="a")

  with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"):
    nest.map_structure(_ignore_one, structure1, check_types=False, foo="a")
开发者ID:abidrahmank,项目名称:tensorflow,代码行数:39,代码来源:nest_test.py
示例13: _unbatch
def _unbatch(self):
  """Returns a `NestedStructure` built by calling `_unbatch()` on each leaf."""

  def _unbatch_component(component):
    return component._unbatch()

  unbatched = nest.map_structure(_unbatch_component, self._nested_structure)
  return NestedStructure(unbatched)
开发者ID:aritratony,项目名称:tensorflow,代码行数:3,代码来源:structure.py
示例14: output_shapes
def output_shapes(self):
  """Returns the input dataset's shapes with the first dimension removed."""

  def _drop_leading_dimension(shape):
    return shape[1:]

  input_shapes = self._input_dataset.output_shapes
  return nest.map_structure(_drop_leading_dimension, input_shapes)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:3,代码来源:batching.py
示例15: output_shapes
def output_shapes(self):
  """Returns a pair: the queue component's shape and the batched shapes."""
  # First output is a variant representing the Queue
  queue_shape = tensor_shape.vector(None)
  batched_shapes = nest.map_structure(self._as_batch_shape,
                                      self._padded_shapes)
  return (queue_shape, batched_shapes)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:4,代码来源:tensor_queue_dataset.py
示例16: from_structure
def from_structure(output_types,
                   output_shapes=None,
                   shared_name=None,
                   output_classes=None):
  """Creates a new, uninitialized `Iterator` with the given structure.

  This iterator-constructing method can be used to create an iterator that
  is reusable with many different datasets.

  The returned iterator is not bound to a particular dataset, and it has
  no `initializer`. To initialize the iterator, run the operation returned by
  `Iterator.make_initializer(dataset)`.

  The following is an example:

  ```python
  iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

  dataset_range = Dataset.range(10)
  range_initializer = iterator.make_initializer(dataset_range)

  dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
  evens_initializer = iterator.make_initializer(dataset_evens)

  # Define a model based on the iterator; in this example, the model_fn
  # is expected to take scalar tf.int64 Tensors as input (see
  # the definition of 'iterator' above).
  prediction, loss = model_fn(iterator.get_next())

  # Train for `num_epochs`, where for each epoch, we first iterate over
  # dataset_range, and then iterate over dataset_evens.
  for _ in range(num_epochs):
    # Initialize the iterator to `dataset_range`
    sess.run(range_initializer)
    while True:
      try:
        pred, loss_val = sess.run([prediction, loss])
      except tf.errors.OutOfRangeError:
        break

    # Initialize the iterator to `dataset_evens`
    sess.run(evens_initializer)
    while True:
      try:
        pred, loss_val = sess.run([prediction, loss])
      except tf.errors.OutOfRangeError:
        break
  ```

  Args:
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
      corresponding to each component of an element of this dataset. If
      omitted, each component will have an unconstrained shape.
    shared_name: (Optional.) If non-empty, this iterator will be shared under
      the given name across multiple sessions that share the same devices
      (e.g. when using a remote server).
    output_classes: (Optional.) A nested structure of Python `type` objects
      corresponding to each component of an element of this iterator. If
      omitted, each component is assumed to be of type `tf.Tensor`.

  Returns:
    An `Iterator`.

  Raises:
    TypeError: If the structures of `output_shapes` and `output_types` are
      not the same.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)

  if output_shapes is not None:
    # Normalize user-provided shapes, using `output_types` as the template.
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
  else:
    # No shapes given: every component gets a fully unknown shape.
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)

  if output_classes is None:
    output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
  nest.assert_same_structure(output_types, output_shapes)

  if shared_name is None:
    shared_name = ""

  # The op takes flat lists, with sparse components mapped to their dense
  # representation by the `sparse` helpers.
  flat_types = nest.flatten(
      sparse.as_dense_types(output_types, output_classes))
  flat_shapes = nest.flatten(
      sparse.as_dense_shapes(output_shapes, output_classes))
  resource = gen_dataset_ops.iterator(
      container="",
      shared_name=shared_name,
      output_types=flat_types,
      output_shapes=flat_shapes)
  return Iterator(resource, None, output_types, output_shapes,
                  output_classes)
开发者ID:modkzs,项目名称:tensorflow,代码行数:90,代码来源:iterator_ops.py
示例17: __init__
def __init__(self,
             dataset,
             output_types,
             output_shapes=None,
             output_classes=None,
             allow_unsafe_cast=False):
  """Creates a new dataset with the given output types and shapes.

  The given `dataset` must have a structure that is convertible:

  * `dataset.output_types` must be the same as `output_types` modulo nesting.
  * Each shape in `dataset.output_shapes` must be compatible with each shape
    in `output_shapes` (if given).

  Note: This helper permits "unsafe casts" for shapes, equivalent to using
  `tf.Tensor.set_shape()` where domain-specific knowledge is available.

  Args:
    dataset: A `Dataset` object.
    output_types: A nested structure of `tf.DType` objects.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
      If omitted, the shapes will be inherited from `dataset`.
    output_classes: (Optional.) A nested structure of class types.
      If omitted, the class types will be inherited from `dataset`.
    allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
      reported output types and shapes of the restructured dataset, e.g. to
      switch a sparse tensor represented as `tf.variant` to its user-visible
      type and shape.

  Raises:
    ValueError: If either `output_types` or `output_shapes` is not compatible
      with the structure of `dataset`.
  """
  self._input_dataset = dataset
  validate = not allow_unsafe_cast

  input_types = dataset_ops.get_legacy_output_types(dataset)
  if validate:
    # The flattened type lists must match exactly; only nesting may differ.
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if nest.flatten(input_types) != nest.flatten(output_types):
      raise ValueError(
          "Dataset with output types %r cannot be restructured to have "
          "output types %r" %
          (dataset_ops.get_legacy_output_types(dataset), output_types))

  input_shapes = dataset_ops.get_legacy_output_shapes(dataset)
  if output_shapes is None:
    # Inherit shapes from the original `dataset`.
    output_shapes = nest.pack_sequence_as(
        output_types, nest.flatten(input_shapes))
  else:
    if validate:
      # Every requested shape must be compatible with the original one.
      nest.assert_same_structure(output_types, output_shapes)
      shape_pairs = zip(nest.flatten(input_shapes),
                        nest.flatten_up_to(output_types, output_shapes))
      if not all(old.is_compatible_with(new) for old, new in shape_pairs):
        raise ValueError(
            "Dataset with output shapes %r cannot be restructured to have "
            "incompatible output shapes %r" % (input_shapes,
                                               output_shapes))
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)

  input_classes = dataset_ops.get_legacy_output_classes(dataset)
  if output_classes is None:
    # Inherit class types from the original `dataset`.
    output_classes = nest.pack_sequence_as(
        output_types, nest.flatten(input_classes))

  self._structure = structure.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  # The input's variant tensor is reused as-is; only the structure differs.
  variant_tensor = self._input_dataset._variant_tensor  # pylint: disable=protected-access
  super(_RestructuredDataset, self).__init__(dataset, variant_tensor)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:78,代码来源:batching.py
示例18: from_string_handle
def from_string_handle(string_handle,
                       output_types,
                       output_shapes=None,
                       output_classes=None):
  """Creates a new, uninitialized `Iterator` based on the given handle.

  This lets you define a "feedable" iterator: you choose between concrete
  iterators by feeding a value in a `tf.Session.run` call. In that case,
  `string_handle` would be a `tf.placeholder`, and you would feed it with the
  value of `tf.data.Iterator.string_handle` in each step.

  For example, if you had two iterators that marked the current position in
  a training dataset and a test dataset, you could choose which to use in
  each step as follows:

  ```python
  train_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  train_iterator_handle = sess.run(train_iterator.string_handle())

  test_iterator = tf.data.Dataset(...).make_one_shot_iterator()
  test_iterator_handle = sess.run(test_iterator.string_handle())

  handle = tf.placeholder(tf.string, shape=[])
  iterator = tf.data.Iterator.from_string_handle(
      handle, train_iterator.output_types)

  next_element = iterator.get_next()
  loss = f(next_element)

  train_loss = sess.run(loss, feed_dict={handle: train_iterator_handle})
  test_loss = sess.run(loss, feed_dict={handle: test_iterator_handle})
  ```

  Args:
    string_handle: A scalar `tf.Tensor` of type `tf.string` that evaluates
      to a handle produced by the `Iterator.string_handle()` method.
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element of this dataset.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects
      corresponding to each component of an element of this dataset. If
      omitted, each component will have an unconstrained shape.
    output_classes: (Optional.) A nested structure of Python `type` objects
      corresponding to each component of an element of this iterator. If
      omitted, each component is assumed to be of type `tf.Tensor`.

  Returns:
    An `Iterator`.
  """
  output_types = nest.map_structure(dtypes.as_dtype, output_types)

  if output_shapes is not None:
    # Normalize user-provided shapes, using `output_types` as the template.
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)
  else:
    # No shapes given: every component gets a fully unknown shape.
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)

  if output_classes is None:
    output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
  nest.assert_same_structure(output_types, output_shapes)

  output_structure = structure_lib.convert_legacy_structure(
      output_types, output_shapes, output_classes)
  string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string)

  # pylint: disable=protected-access
  def _make_v2_resource():
    # Shared by both device branches below; called inside the device scope.
    return gen_dataset_ops.iterator_from_string_handle_v2(
        string_handle,
        output_types=output_structure._flat_types,
        output_shapes=output_structure._flat_shapes)

  if not compat.forward_compatible(2018, 8, 3):
    # Outside the forward-compatibility window, use the original op.
    iterator_resource = gen_dataset_ops.iterator_from_string_handle(
        string_handle,
        output_types=output_structure._flat_types,
        output_shapes=output_structure._flat_shapes)
  elif _device_stack_is_empty():
    # No device has been requested by the caller: pin the op to the CPU.
    with ops.device("/cpu:0"):
      iterator_resource = _make_v2_resource()
  else:
    iterator_resource = _make_v2_resource()
  # pylint: enable=protected-access

  return Iterator(iterator_resource, None, output_types, output_shapes,
                  output_classes)
开发者ID:perfmjs,项目名称:tensorflow,代码行数:82,代码来源:iterator_ops.py
示例19: _batch
def _batch(self, batch_size):
  """Returns a `NestedStructure` built by calling `_batch()` on each leaf.

  Args:
    batch_size: Forwarded unchanged to every component's `_batch()` method.
  """

  def _batch_component(component):
    return component._batch(batch_size)

  batched = nest.map_structure(_batch_component, self._nested_structure)
  return NestedStructure(batched)
开发者ID:aritratony,项目名称:tensorflow,代码行数:3,代码来源:structure.py
示例20: from_generator
def from_generator(generator, output_types, output_shapes=None):
"""Creates a `Dataset` whose elements are generated by `generator`.
The `generator` argument must be a callable object that returns
an object that support the `iter()` protocol (e.g. a generator function).
The elements generated by `generator` must be compatible with the given
`output_types` and (optional) `output_shapes` arguments.
For example:
```python
import itertools
def gen():
for i in itertools.count(1):
yield (i, [1] * i)
ds = Dataset.from_generator(
gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
value = ds.make_one_shot_iterator().get_next()
sess.run(value) # (1, array([1]))
sess.run(value) # (2, array([1, 1]))
```
Args:
generator: A callable object that takes no arguments and returns an
object that supports the `iter()` protocol.
output_types: A nested structure of `tf.DType` objects corresponding to
each component of an element yielded by `generator`.
output_shapes: (Optional.) A nested structure of `tf.TensorShape`
objects corresponding to each component of an element yielded by
`generator`.
Returns:
A `Dataset`.
"""
if not callable(generator):
raise TypeError("`generator` must be callable.")
if output_shapes is None:
output_shapes = nest.map_structure(
lambda _: tensor_shape.TensorShape(None), output_types)
else:
output_shapes = nest.map_structure_up_to(
output_types, tensor_shape.as_shape, output_shapes)
flattened_types = nest.flatten(output_types)
flattened_shapes = nest.flatten(output_shapes)
generator_state = dataset_ops.Dataset._GeneratorState(generator)
def get_iterator_id_map_fn(unused_dummy):
"""Creates a unique `iterator_id` for each pass over the dataset.
The "iterator_id" disambiguates between multiple concurrently
existing iterators.
Args:
unused_dummy: Ignored value.
Returns:
A `tf.int64` tensor whose value uniquely identifies an iterator in
`generator_state`.
"""
return script_ops.py_func(
generator_state.get_next_id, [], dtypes.int64, stateful=True)
def generator_map_fn(iterator_id_t):
"""Generates the next element from iterator with ID `iterator_id_t`.
We map this function across an infinite repetition of the
`iterator_id_t`, and raise `StopIteration` to terminate the iteration.
Args:
iterator_id_t: A `tf.int64` tensor whose value uniquely identifies
the iterator in `generator_state` from which to generate an element.
Returns:
A nested structure of tensors representing an element from the iterator.
"""
def generator_py_func(iterator_id):
"""A `py_func` that will be called to invoke the iterator."""
try:
values = next(generator_state.get_iterator(iterator_id))
except StopIteration:
generator_state.iterator_completed(iterator_id)
raise StopIteration("Iteration finished.")
# Use the same _convert function from the py_func() implementation to
# convert the returned values to arrays early, so that we can inspect
# their values.
# pylint: disable=protected-access
ret_arrays = [
script_ops.FuncRegistry._convert(ret, dtype=dtype.as_numpy_dtype)
for ret, dtype in zip(nest.flatten_up_to(output_types, values),
flattened_types)
]
# pylint: enable=protected-access
#.........这里部分代码省略.........
开发者ID:Mazecreator,项目名称:tensorflow,代码行数:101,代码来源:dataset_ops.py
注：本文中的tensorflow.python.data.util.nest.map_structure函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论