• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python dataset_ops.get_legacy_output_types函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.data.ops.dataset_ops.get_legacy_output_types函数的典型用法代码示例。如果您正苦于以下问题：Python get_legacy_output_types函数的具体用法？Python get_legacy_output_types怎么用？Python get_legacy_output_types使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了get_legacy_output_types函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: testNonSequenceNestedStructure

  def testNonSequenceNestedStructure(self):
    """Tracks legacy types/shapes of a non-sequence element through ops.

    Starts from a single length-3 int64 vector and checks the legacy output
    types and shapes after `filter`, `map` (stacking to [2, 3]) and
    `flat_map` (slicing back to [3]), then inspects one produced element.
    """
    components = np.array([1, 2, 3], dtype=np.int64)

    dataset = dataset_ops.Dataset.from_tensors(components)
    self.assertEqual(dtypes.int64,
                     dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual([3], dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.filter(
        lambda x: math_ops.reduce_all(math_ops.equal(x, components)))
    self.assertEqual(dtypes.int64,
                     dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual([3], dataset_ops.get_legacy_output_shapes(dataset))

    dataset = dataset.map(lambda x: array_ops.stack([x, x]))
    self.assertEqual(dtypes.int64,
                     dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual([2, 3], dataset_ops.get_legacy_output_shapes(dataset))

    # Slice each [2, 3] element back into [3]-shaped rows.
    dataset = dataset.flat_map(
        lambda x: dataset_ops.Dataset.from_tensor_slices(x))
    self.assertEqual(dtypes.int64,
                     dataset_ops.get_legacy_output_types(dataset))
    self.assertEqual([3], dataset_ops.get_legacy_output_shapes(dataset))

    # Fetch one element and check both properties on it. Calling `get_next()`
    # once per assertion would pull a different element each time when the
    # callable advances the iterator.
    get_next = self.getNext(dataset)
    element = get_next()
    self.assertEqual(dtypes.int64, element.dtype)
    self.assertEqual([3], element.shape)
开发者ID:aritratony,项目名称:tensorflow,代码行数:28,代码来源:from_tensors_test.py


示例2: testUnbatchScalarDataset

  def testUnbatchScalarDataset(self):
    """Batching then unbatching three int32 columns keeps their types."""
    columns = tuple(math_ops.range(10) for _ in range(3))
    dataset = dataset_ops.Dataset.from_tensor_slices(columns)
    int32_triple = (dtypes.int32,) * 3

    dataset = dataset.batch(2)
    self.assertEqual(int32_triple,
                     dataset_ops.get_legacy_output_types(dataset))

    dataset = dataset.apply(batching.unbatch())
    self.assertEqual(int32_triple,
                     dataset_ops.get_legacy_output_types(dataset))

    expected = [(i, i, i) for i in range(10)]
    self.assertDatasetProduces(dataset, expected)
开发者ID:aritratony,项目名称:tensorflow,代码行数:10,代码来源:unbatch_test.py


示例3: testUnbatchDatasetWithStrings

  def testUnbatchDatasetWithStrings(self):
    """Unbatching keeps a mixed (int32, string, int32) element type."""
    ranges = tuple(math_ops.range(10) for _ in range(3))
    dataset = dataset_ops.Dataset.from_tensor_slices(ranges)
    # Turn the middle component into its string representation.
    dataset = dataset.map(lambda x, y, z: (x, string_ops.as_string(y), z))
    mixed_types = (dtypes.int32, dtypes.string, dtypes.int32)

    dataset = dataset.batch(2)
    self.assertEqual(mixed_types,
                     dataset_ops.get_legacy_output_types(dataset))

    dataset = dataset.apply(batching.unbatch())
    self.assertEqual(mixed_types,
                     dataset_ops.get_legacy_output_types(dataset))

    expected = [(i, compat.as_bytes(str(i)), i) for i in range(10)]
    self.assertDatasetProduces(dataset, expected)
开发者ID:aritratony,项目名称:tensorflow,代码行数:12,代码来源:unbatch_test.py


示例4: testNestedDict

 def testNestedDict(self):
   """Checks legacy types and shapes of a nested-dict element."""
   components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]}
   dataset = dataset_ops.Dataset.from_tensors(components)
   output_types = dataset_ops.get_legacy_output_types(dataset)
   output_shapes = dataset_ops.get_legacy_output_shapes(dataset)

   self.assertEqual(dtypes.int32, output_types["a"]["aa"])
   self.assertEqual(dtypes.float32, output_types["a"]["ab"])
   self.assertEqual(dtypes.int32, output_types["b"])
   self.assertEqual([], output_shapes["a"]["aa"])
   self.assertEqual([2], output_shapes["a"]["ab"])
   self.assertEqual([3], output_shapes["b"])
开发者ID:aritratony,项目名称:tensorflow,代码行数:15,代码来源:from_tensors_test.py


示例5: testUnbatchMultiElementTupleDataset

  def testUnbatchMultiElementTupleDataset(self):
    """Unbatching keeps three nested (int32, string) pairs intact."""
    pairs = tuple(
        (math_ops.range(10 * i, 10 * i + 10), array_ops.fill([10], "hi"))
        for i in range(3))
    dataset = dataset_ops.Dataset.from_tensor_slices(pairs)
    pair_types = (dtypes.int32, dtypes.string)
    expected_types = (pair_types,) * 3

    dataset = dataset.batch(2)
    self.assertAllEqual(expected_types,
                        dataset_ops.get_legacy_output_types(dataset))

    dataset = dataset.apply(batching.unbatch())
    self.assertAllEqual(expected_types,
                        dataset_ops.get_legacy_output_types(dataset))

    # Row i holds (i, "hi"), (10 + i, "hi"), (20 + i, "hi").
    expected = [tuple((10 * j + i, b"hi") for j in range(3))
                for i in range(10)]
    self.assertDatasetProduces(dataset, expected)
开发者ID:aritratony,项目名称:tensorflow,代码行数:15,代码来源:unbatch_test.py


示例6: _apply_fn

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation.

    Re-encodes every element through the dataset's element structure, restores
    the original legacy types/shapes/classes, and wraps the result in
    `_UnbatchDataset`.
    """
    # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
    # are normalized to the rank-1 dense representation, so that the
    # sparse-oblivious unbatching logic will slice them
    # appropriately. This leads to a somewhat inefficient re-encoding step
    # for all SparseTensor components.
    # TODO(mrry): Consider optimizing this in future if it turns out to be
    # a bottleneck.
    def normalize(arg, *rest):
      # `map` passes tuple components as separate positional arguments; pack
      # them back into a tuple so they match the element structure.
      element = (arg,) + rest if rest else arg
      # pylint: disable=protected-access
      return dataset._element_structure._to_batched_tensor_list(element)

    # NOTE(mrry): Our `map()` has lost information about the sparseness
    # of any SparseTensor components, so re-apply the structure of the
    # original dataset.
    return _UnbatchDataset(
        _RestructuredDataset(
            dataset.map(normalize),
            dataset_ops.get_legacy_output_types(dataset),
            dataset_ops.get_legacy_output_shapes(dataset),
            dataset_ops.get_legacy_output_classes(dataset),
            allow_unsafe_cast=True))
开发者ID:kylin9872,项目名称:tensorflow,代码行数:28,代码来源:batching.py


示例7: __init__

  def __init__(self, input_dataset, num_workers):
    """Wraps `input_dataset`, dividing each component's first dim by workers.

    Args:
      input_dataset: The dataset whose elements are rebatched.
      num_workers: Divisor applied to the first dimension of every component
        shape (presumably a positive int — confirm against callers).

    Raises:
      ValueError: If a component shape has unknown rank or rank 0.
      errors.InvalidArgumentError: If a known, non-zero first dimension is not
        divisible by `num_workers`.
    """
    self._input_dataset = input_dataset

    def recalculate_output_shapes(output_shapes):
      """Recalculates the output_shapes after dividing it by num_workers."""
      # Check the rank explicitly so an unknown-rank shape gets this error
      # message instead of whatever `len()` raises for it.
      if output_shapes.ndims is None or output_shapes.ndims < 1:
        raise ValueError("Input shape should have at least one dimension.")
      first_dim = tensor_shape.dimension_value(output_shapes[0])
      # A falsy `first_dim` (unknown or 0) skips the divisibility check.
      if first_dim and first_dim % num_workers != 0:
        raise errors.InvalidArgumentError(
            None, None,
            "First dim of input shape: %d is not divisible by num_workers: %d" %
            (first_dim, num_workers))
      output_dims = list(output_shapes.dims)
      output_dims[0] = output_dims[0] // num_workers
      return tensor_shape.TensorShape(output_dims)

    input_types = dataset_ops.get_legacy_output_types(self._input_dataset)
    input_shapes = dataset_ops.get_legacy_output_shapes(self._input_dataset)
    input_classes = dataset_ops.get_legacy_output_classes(self._input_dataset)
    output_shapes = nest.map_structure(recalculate_output_shapes, input_shapes)

    self._structure = structure.convert_legacy_structure(
        input_types, output_shapes, input_classes)
    variant_tensor = ged_ops.experimental_rebatch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        num_workers=num_workers,
        **dataset_ops.flat_structure(self))
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
开发者ID:aritratony,项目名称:tensorflow,代码行数:29,代码来源:distribute.py


示例8: assertDatasetsEqual

  def assertDatasetsEqual(self, dataset1, dataset2):
    """Checks that datasets are equal. Supports both graph and eager mode.

    The datasets must have mutually compatible structures and produce the same
    number of elements. Flattened components are compared pairwise:
    sparse values with `assertSparseValuesEqual`, `tf.string` components with
    `assertAllEqual`, and everything else with `assertAllClose`.
    """
    self.assertTrue(dataset_ops.get_structure(dataset1).is_compatible_with(
        dataset_ops.get_structure(dataset2)))
    self.assertTrue(dataset_ops.get_structure(dataset2).is_compatible_with(
        dataset_ops.get_structure(dataset1)))
    flattened_types = nest.flatten(
        dataset_ops.get_legacy_output_types(dataset1))

    next1 = self.getNext(dataset1)
    next2 = self.getNext(dataset2)

    while True:
      try:
        op1 = self.evaluate(next1())
      except errors.OutOfRangeError:
        # dataset1 is exhausted, so dataset2 must be exhausted too.
        with self.assertRaises(errors.OutOfRangeError):
          self.evaluate(next2())
        break
      op2 = self.evaluate(next2())

      op1 = nest.flatten(op1)
      op2 = nest.flatten(op2)
      # Use a test assertion, not a bare `assert`. A bare `assert` is stripped
      # under `python -O` and reports no values on failure.
      self.assertEqual(len(op1), len(op2))
      for i, (value1, value2) in enumerate(zip(op1, op2)):
        if sparse_tensor.is_sparse(value1):
          self.assertSparseValuesEqual(value1, value2)
        elif flattened_types[i] == dtypes.string:
          self.assertAllEqual(value1, value2)
        else:
          self.assertAllClose(value1, value2)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:31,代码来源:test_base.py


示例9: _create_or_validate_filenames_dataset

def _create_or_validate_filenames_dataset(filenames):
  """Creates (or validates) a dataset of filenames.

  Args:
    filenames: Either a list or dataset of filenames. If it is a list, it is
      converted to a dataset. If it is a dataset, its type and shape are
      validated.

  Returns:
    A dataset of filenames.

  Raises:
    TypeError: If `filenames` is a dataset whose elements are not `tf.string`,
      or are not scalars.
  """
  if isinstance(filenames, dataset_ops.DatasetV2):
    if dataset_ops.get_legacy_output_types(filenames) != dtypes.string:
      raise TypeError(
          "`filenames` must be a `tf.data.Dataset` of `tf.string` elements.")
    # `TensorShape([])` is the scalar shape (what `tensor_shape.scalar()`
    # returns), spelled without the helper.
    if not dataset_ops.get_legacy_output_shapes(filenames).is_compatible_with(
        tensor_shape.TensorShape([])):
      raise TypeError(
          "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
          "elements.")
  else:
    # Accept a string tensor-like of any rank and flatten it to a vector, so
    # that slicing yields one scalar filename per element.
    filenames = ops.convert_to_tensor(filenames, dtype=dtypes.string)
    filenames = array_ops.reshape(filenames, [-1], name="flat_filenames")
    filenames = dataset_ops.DatasetV2.from_tensor_slices(filenames)

  return filenames
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:25,代码来源:readers.py


示例10: write

  def write(self, dataset, column_families, columns, timestamp=None):
    """Writes a dataset to the table.

    Args:
      dataset: A `tf.data.Dataset` to be written to this table. It must produce
        a list of number-of-columns+1 elements, all of which must be strings.
        The first value will be used as the row key, and subsequent values will
        be used as cell values for the corresponding columns from the
        corresponding column_families and columns entries.
      column_families: A `tf.Tensor` of `tf.string`s corresponding to the
        column family names to store the dataset's elements into.
        NOTE(review): `len()` is applied to this argument, so it must support
        `len()`; a symbolic graph tensor may not.
      columns: A `tf.Tensor` of `tf.string`s corresponding to the column names
        to store the dataset's elements into. Also passed to `len()`.
      timestamp: (Optional.) An int64 timestamp to write all the values at.
        Leave as None to use server-provided timestamps.

    Returns:
      A `tf.Operation` that can be run to perform the write.

    Raises:
      ValueError: If there are unexpected or incompatible types, or if the
        number of columns and column_families does not match the output of
        `dataset`.
    """
    if timestamp is None:
      timestamp = -1  # Bigtable server provided timestamp.
    # Flatten the element types once. They feed both the dtype check and the
    # component-count check below.
    flat_types = nest.flatten(dataset_ops.get_legacy_output_types(dataset))
    if any(tensor_type != dtypes.string for tensor_type in flat_types):
      raise ValueError("Not all elements of the dataset were `tf.string`")
    scalar_shape = tensor_shape.TensorShape([])
    flat_shapes = nest.flatten(dataset_ops.get_legacy_output_shapes(dataset))
    if not all(shape.is_compatible_with(scalar_shape) for shape in flat_shapes):
      raise ValueError("Not all elements of the dataset were scalars")
    if len(column_families) != len(columns):
      raise ValueError("len(column_families) != len(columns)")
    # One extra component is the row key.
    if len(flat_types) != len(columns) + 1:
      raise ValueError("A column name must be specified for every component of "
                       "the dataset elements. (e.g.: len(columns) != "
                       "len(dataset.output_types))")
    return gen_bigtable_ops.dataset_to_bigtable(
        self._resource,
        dataset._variant_tensor,  # pylint: disable=protected-access
        column_families,
        columns,
        timestamp)
开发者ID:jackd,项目名称:tensorflow,代码行数:46,代码来源:bigtable_api.py


示例11: testKinesisDatasetTwoShards

  def testKinesisDatasetTwoShards(self):
    """Writes ten records to a two-shard stream and reads them all back."""
    client = boto3.client("kinesis", region_name="us-east-1")

    # Setup the Kinesis with 2 shards.
    stream_name = "tf_kinesis_test_2"
    client.create_stream(StreamName=stream_name, ShardCount=2)
    # Wait until stream exists, default is 10 * 18 seconds.
    client.get_waiter("stream_exists").wait(StreamName=stream_name)

    # Always delete the stream, even if an assertion below fails.
    try:
      for i in range(10):
        record = "D" + str(i)
        client.put_record(
            StreamName=stream_name, Data=record,
            PartitionKey="TensorFlow" + str(i))
      response = client.describe_stream(StreamName=stream_name)
      shard_id_0 = response["StreamDescription"]["Shards"][0]["ShardId"]
      shard_id_1 = response["StreamDescription"]["Shards"][1]["ShardId"]

      stream = array_ops.placeholder(dtypes.string, shape=[])
      shard = array_ops.placeholder(dtypes.string, shape=[])
      num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
      batch_size = array_ops.placeholder(dtypes.int64, shape=[])

      repeat_dataset = kinesis_dataset_ops.KinesisDataset(
          stream, shard, read_indefinitely=False).repeat(num_epochs)
      batch_dataset = repeat_dataset.batch(batch_size)

      iterator = iterator_ops.Iterator.from_structure(
          dataset_ops.get_legacy_output_types(batch_dataset))
      init_op = iterator.make_initializer(repeat_dataset)
      # Not run below. Building it still checks that `batch_dataset` is
      # compatible with the iterator's structure.
      init_batch_op = iterator.make_initializer(batch_dataset)  # pylint: disable=unused-variable
      get_next = iterator.get_next()

      data = []
      with self.cached_session() as sess:
        # Basic test: read from shard 0, then shard 1, of stream 2.
        for shard_id in (shard_id_0, shard_id_1):
          sess.run(
              init_op, feed_dict={
                  stream: stream_name, shard: shard_id, num_epochs: 1})
          with self.assertRaises(errors.OutOfRangeError):
            # Use range(11) to guarantee the OutOfRangeError.
            for _ in range(11):
              data.append(sess.run(get_next))

      data.sort()
      # `sess.run` returns `tf.string` values as `bytes`, so compare against
      # encoded strings. Plain `str` values never compare equal to `bytes`
      # under Python 3.
      self.assertEqual(
          data, [("D" + str(i)).encode("utf-8") for i in range(10)])
    finally:
      client.delete_stream(StreamName=stream_name)
      # Wait until stream deleted, default is 10 * 18 seconds.
      client.get_waiter("stream_not_exists").wait(StreamName=stream_name)
开发者ID:Albert-Z-Guo,项目名称:tensorflow,代码行数:58,代码来源:kinesis_test.py


示例12: _apply_fn

 def _apply_fn(dataset):
   """Maps `_check_shape` over `dataset` and re-declares merged shapes.

   `_check_shape`, `expected_shapes` and `_merge_output_shapes` come from an
   enclosing scope. The result keeps the input's legacy types and classes. Its
   shapes are the merge of the input's legacy shapes with `expected_shapes`.
   """
   merged_shapes = _merge_output_shapes(
       dataset_ops.get_legacy_output_shapes(dataset), expected_shapes)
   checked = dataset.map(_check_shape)
   legacy_types = dataset_ops.get_legacy_output_types(dataset)
   legacy_classes = dataset_ops.get_legacy_output_classes(dataset)
   # pylint: disable=protected-access
   return batching._RestructuredDataset(
       checked,
       legacy_types,
       output_shapes=merged_shapes,
       output_classes=legacy_classes)
开发者ID:Albert-Z-Guo,项目名称:tensorflow,代码行数:9,代码来源:batching.py


示例13: testFromTensorSlicesWithDict

  def testFromTensorSlicesWithDict(self):
    """Slicing a dict of lists yields one dict element per row."""
    components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
    dataset = dataset_ops.Dataset.from_tensor_slices(components)
    get_next = self.getNext(dataset)

    output_types = dataset_ops.get_legacy_output_types(dataset)
    output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    self.assertEqual(dtypes.int32, output_types["foo"])
    self.assertEqual(dtypes.float32, output_types["bar"])
    self.assertEqual((), output_shapes["foo"])
    self.assertEqual((1,), output_shapes["bar"])

    for foo_value, bar_value in zip(components["foo"], components["bar"]):
      row = self.evaluate(get_next())
      self.assertEqual(foo_value, row["foo"])
      self.assertEqual(bar_value, row["bar"])
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next())
开发者ID:aritratony,项目名称:tensorflow,代码行数:18,代码来源:from_tensor_slices_test.py


示例14: batch_init_fn

 def batch_init_fn(_):
   """Returns an empty `SparseTensor` to seed the batch.

   `padded_shape` and `dataset` come from an enclosing scope. The dense shape
   is `[0] + padded_shape`. Indices have one column per dimension of that dense
   shape.
   """
   index_width = array_ops.size(padded_shape) + 1
   indices_shape = array_ops.concat([[0], [index_width]], 0)
   empty_indices = gen_array_ops.empty(indices_shape, dtype=dtypes.int64)
   value_dtype = dataset_ops.get_legacy_output_types(dataset)
   empty_values = constant_op.constant([], shape=[0], dtype=value_dtype)
   zero_leading_dim = np.array([0], dtype=np.int64)
   dense_shape = array_ops.concat([zero_leading_dim, padded_shape], 0)
   return sparse_tensor.SparseTensor(
       indices=empty_indices, values=empty_values, dense_shape=dense_shape)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:9,代码来源:batching.py


示例15: testIteratorStringHandle

  def testIteratorStringHandle(self):
    """Drives two one-shot iterators through one feedable string handle."""
    dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
    dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40])

    iterator_3 = dataset_ops.make_one_shot_iterator(dataset_3)
    iterator_4 = dataset_ops.make_one_shot_iterator(dataset_4)

    handle_placeholder = array_ops.placeholder(dtypes.string, shape=[])
    feedable_iterator = iterator_ops.Iterator.from_string_handle(
        handle_placeholder,
        dataset_ops.get_legacy_output_types(dataset_3),
        dataset_ops.get_legacy_output_shapes(dataset_3))
    next_element = feedable_iterator.get_next()

    for source in (dataset_3, dataset_4):
      self.assertTrue(dataset_ops.get_structure(source).is_compatible_with(
          dataset_ops.get_structure(feedable_iterator)))

    with self.cached_session() as sess:
      handle_3 = sess.run(iterator_3.string_handle())
      handle_4 = sess.run(iterator_4.string_handle())

      # Alternate between the two handles. Each underlying iterator resumes
      # from its own position.
      schedule = [(10, handle_4), (1, handle_3), (20, handle_4),
                  (2, handle_3), (30, handle_4), (3, handle_3),
                  (40, handle_4)]
      for expected, handle in schedule:
        self.assertEqual(
            expected,
            sess.run(next_element, feed_dict={handle_placeholder: handle}))

      # Both iterators are now exhausted.
      for handle in (handle_3, handle_4):
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(next_element, feed_dict={handle_placeholder: handle})
开发者ID:kylin9872,项目名称:tensorflow,代码行数:56,代码来源:iterator_test.py


示例16: __init__

  def __init__(self, filenames, compression_type=None, buffer_size=None,
               num_parallel_reads=None):
    """Creates a `TFRecordDataset` to read one or more TFRecord files.

    NOTE: The `num_parallel_reads` argument can be used to improve performance
    when reading from a remote filesystem.

    Args:
      filenames: A `tf.string` tensor or `tf.data.Dataset` containing one or
        more filenames.
      compression_type: (Optional.) A `tf.string` scalar evaluating to one of
        `""` (no compression), `"ZLIB"`, or `"GZIP"`.
      buffer_size: (Optional.) A `tf.int64` scalar representing the number of
        bytes in the read buffer. 0 means no buffering.
      num_parallel_reads: (Optional.) A `tf.int64` scalar representing the
        number of files to read in parallel. Defaults to reading files
        sequentially.

    Raises:
      TypeError: If any argument does not have the expected type.
      ValueError: If any argument does not have the expected shape.
    """
    if not isinstance(filenames, dataset_ops.DatasetV2):
      # Flatten a string tensor-like of any rank into scalar filename elements.
      name_tensor = ops.convert_to_tensor(filenames, dtype=dtypes.string)
      flat_names = array_ops.reshape(name_tensor, [-1], name="flat_filenames")
      filenames = dataset_ops.DatasetV2.from_tensor_slices(flat_names)
    else:
      element_type = dataset_ops.get_legacy_output_types(filenames)
      if element_type != dtypes.string:
        raise TypeError(
            "`filenames` must be a `tf.data.Dataset` of `tf.string` elements.")
      element_shape = dataset_ops.get_legacy_output_shapes(filenames)
      if not element_shape.is_compatible_with(tensor_shape.scalar()):
        raise ValueError(
            "`filenames` must be a `tf.data.Dataset` of scalar `tf.string` "
            "elements.")

    self._filenames = filenames
    self._compression_type = compression_type
    self._buffer_size = buffer_size
    self._num_parallel_reads = num_parallel_reads

    def read_one_file(filename):
      return _TFRecordDataset(filename, compression_type, buffer_size)

    if num_parallel_reads is not None:
      self._impl = ParallelInterleaveDataset(
          filenames, read_one_file, cycle_length=num_parallel_reads,
          block_length=1, sloppy=False, buffer_output_elements=None,
          prefetch_input_elements=None)
    else:
      # Sequential reading: one file after another.
      self._impl = filenames.flat_map(read_one_file)
    super(TFRecordDatasetV2, self).__init__(
        self._impl._variant_tensor)  # pylint: disable=protected-access
开发者ID:kylin9872,项目名称:tensorflow,代码行数:53,代码来源:readers.py


示例17: __init__

  def __init__(self, input_dataset, batch_size, row_shape):
    """See `Dataset.dense_to_sparse_batch()` for more details.

    Requires single-component input elements (a plain `DType`), and raises
    `TypeError` otherwise. The resulting element is a sparse structure of the
    input dtype with shape `[None] + row_shape`.
    """
    element_type = dataset_ops.get_legacy_output_types(input_dataset)
    if not isinstance(element_type, dtypes.DType):
      raise TypeError("DenseToSparseDataset requires an input whose elements "
                      "have a single component, whereas the input has %r." %
                      element_type)
    self._input_dataset = input_dataset
    self._batch_size = batch_size
    self._row_shape = row_shape
    batched_shape = tensor_shape.vector(None).concatenate(self._row_shape)
    self._structure = structure.SparseTensorStructure(
        element_type, batched_shape)

    variant_tensor = ged_ops.experimental_dense_to_sparse_batch_dataset(
        self._input_dataset._variant_tensor,  # pylint: disable=protected-access
        self._batch_size,
        row_shape=convert.partial_shape_to_tensor(self._row_shape),
        **dataset_ops.flat_structure(self))
    super(_DenseToSparseBatchDataset, self).__init__(input_dataset,
                                                     variant_tensor)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:21,代码来源:batching.py


示例18: _next_func

    def _next_func(string_handle):
      """Fetches the next element from the iterator behind `string_handle`.

      The iterator is rebuilt on `self._source_device_string` using this
      dataset's legacy types, shapes and classes.

      Args:
        string_handle: An iterator string handle created by _init_func
      Returns:
        The elements generated from `input_dataset`, flattened to a tensor
        list via the element structure.
      """
      with ops.device(self._source_device_string):
        source_iterator = iterator_ops.Iterator.from_string_handle(
            string_handle,
            dataset_ops.get_legacy_output_types(self),
            dataset_ops.get_legacy_output_shapes(self),
            dataset_ops.get_legacy_output_classes(self))
      next_element = source_iterator.get_next()
      # pylint: disable=protected-access
      return self._element_structure._to_tensor_list(next_element)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:15,代码来源:prefetching_ops.py


示例19: __init__

  def __init__(self, selector_input, data_inputs):
    """Builds the dataset from a selector and a sequence of data inputs.

    Args:
      selector_input: The selector dataset, stored as `self._selector_input`.
      data_inputs: An iterable of datasets. It is materialized into a list, so
        any iterable (including a generator) is accepted.

    Raises:
      ValueError: If `data_inputs` is empty.
      TypeError: If the inputs do not all share the same legacy output types
        and classes.
    """
    self._selector_input = selector_input
    self._data_inputs = list(data_inputs)
    # Use the materialized list from here on. The raw `data_inputs` may be a
    # one-shot iterable that `list()` has already consumed.
    if not self._data_inputs:
      raise ValueError("`data_inputs` must contain at least one dataset.")

    first_input = self._data_inputs[0]
    first_output_types = dataset_ops.get_legacy_output_types(first_input)
    first_output_classes = dataset_ops.get_legacy_output_classes(first_input)

    for data_input in self._data_inputs[1:]:
      if (dataset_ops.get_legacy_output_types(data_input) != first_output_types
          or dataset_ops.get_legacy_output_classes(data_input)
          != first_output_classes):
        raise TypeError("All datasets must have the same type and class.")

    # Shapes may differ. Keep the most specific shape compatible with every
    # input, component by component.
    output_shapes = dataset_ops.get_legacy_output_shapes(first_input)
    for data_input in self._data_inputs[1:]:
      output_shapes = nest.pack_sequence_as(output_shapes, [
          ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
              nest.flatten(output_shapes),
              nest.flatten(dataset_ops.get_legacy_output_shapes(data_input)))
      ])

    self._structure = structure.convert_legacy_structure(
        first_output_types, output_shapes, first_output_classes)
    super(_DirectedInterleaveDataset, self).__init__()
开发者ID:aritratony,项目名称:tensorflow,代码行数:24,代码来源:interleave_ops.py


示例20: testEnumerate

  def testEnumerate(self):
    """`enumerate(start)` pairs each element with an int64 counter."""
    components = (["a", "b"], [1, 2], [37.0, 38])
    start = constant_op.constant(20, dtype=dtypes.int64)
    sliced = dataset_ops.Dataset.from_tensor_slices(components)
    dataset = sliced.enumerate(start)

    output_types = dataset_ops.get_legacy_output_types(dataset)
    self.assertEqual(dtypes.int64, output_types[0])
    output_shapes = dataset_ops.get_legacy_output_shapes(dataset)
    self.assertEqual((), output_shapes[0])
    scalar_shapes = [tensor_shape.TensorShape([])] * 3
    self.assertEqual(scalar_shapes, list(output_shapes[1]))

    expected = [(20, (b"a", 1, 37.0)), (21, (b"b", 2, 38.0))]
    self.assertDatasetProduces(dataset, expected)
开发者ID:aritratony,项目名称:tensorflow,代码行数:16,代码来源:enumerate_test.py



注：本文中的tensorflow.python.data.ops.dataset_ops.get_legacy_output_types函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap