本文整理汇总了Python中tensorflow.python.data.ops.dataset_ops.make_initializable_iterator函数的典型用法代码示例。如果您正苦于以下问题：Python make_initializable_iterator函数的具体用法？Python make_initializable_iterator怎么用？Python make_initializable_iterator使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了make_initializable_iterator函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testMultipleDatasetWithPrefixes
def testMultipleDatasetWithPrefixes(self, dataset_transformation):
  """Checks that two datasets sharing one aggregator report prefixed stats.

  Both pipelines record latency under the tag "record_latency" and are then
  attached to the same `StatsAggregator` with different prefixes. The summary
  is therefore expected to hold separate "dataset1_record_latency" and
  "dataset2_record_latency" counts, each advancing once per element.

  Args:
    dataset_transformation: callable `(dataset, aggregator, prefix=...)` that
      attaches `aggregator` to `dataset`. Presumably supplied by a
      parameterized-test decorator outside this view; confirm.
  """
  aggregator = stats_aggregator.StatsAggregator()
  dataset = dataset_ops.Dataset.range(100).apply(
      stats_ops.latency_stats("record_latency"))
  dataset = dataset_transformation(dataset, aggregator, prefix="dataset1")
  dataset2 = dataset_ops.Dataset.range(100).apply(
      stats_ops.latency_stats("record_latency"))
  dataset2 = dataset_transformation(dataset2, aggregator, prefix="dataset2")
  iterator_0 = dataset_ops.make_initializable_iterator(dataset)
  iterator_1 = dataset_ops.make_initializable_iterator(dataset2)
  # Both ranges yield i at step i, so the summed element is 2 * i.
  next_element = iterator_0.get_next() + iterator_1.get_next()
  summary_t = aggregator.get_summary()
  # `test_session` is deprecated; `cached_session` is its replacement. The
  # session object is not bound because evaluation goes through
  # `self.evaluate`.
  with self.cached_session():
    self.evaluate([iterator_0.initializer, iterator_1.initializer])
    for i in range(100):
      self.assertEqual(i * 2, self.evaluate(next_element))
      self._assertSummaryHasCount(
          self.evaluate(summary_t), "dataset1_record_latency", float(i + 1))
      self._assertSummaryHasCount(
          self.evaluate(summary_t), "dataset2_record_latency", float(i + 1))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
    self._assertSummaryHasCount(
        self.evaluate(summary_t), "dataset1_record_latency", 100.0)
    self._assertSummaryHasCount(
        self.evaluate(summary_t), "dataset2_record_latency", 100.0)
开发者ID:aeverall,项目名称:tensorflow,代码行数:27,代码来源:stats_dataset_ops_test.py
示例2: testPrefetchToDeviceWithReInit
def testPrefetchToDeviceWithReInit(self):
  """Checks that a prefetch-to-device iterator can be re-initialized.

  The host `range(10)` pipeline is prefetched onto "/cpu:1". After 5
  elements are read, the iterator is re-initialized. A full pass of all 10
  elements, followed by end-of-sequence, is then expected.
  """
  host_dataset = dataset_ops.Dataset.range(10)
  device_dataset = host_dataset.apply(
      prefetching_ops.prefetch_to_device("/cpu:1"))
  with ops.device("/cpu:1"):
    iterator = dataset_ops.make_initializable_iterator(device_dataset)
    next_element = iterator.get_next()

  # Prefetching onto a device must keep the element structure compatible.
  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(device_dataset)))
  self.assertTrue(dataset_ops.get_structure(host_dataset).is_compatible_with(
      dataset_ops.get_structure(iterator)))
  self.assertEqual(dtypes.int64, next_element.dtype)
  self.assertEqual([], next_element.shape)

  # Two CPU devices are configured so that "/cpu:1" exists.
  worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
  # `test_session` is deprecated; `cached_session` accepts the same config.
  with self.cached_session(config=worker_config):
    self.evaluate(iterator.initializer)
    # Consume only part of the input before re-initializing.
    for i in range(5):
      self.assertEqual(i, self.evaluate(next_element))
    self.evaluate(iterator.initializer)
    for i in range(10):
      self.assertEqual(i, self.evaluate(next_element))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:27,代码来源:prefetch_to_device_test.py
示例3: test_sequence_file_dataset
def test_sequence_file_dataset(self):
  """Reads `testdata/string.seq` through SequenceFileDataset, twice over.

  The fixture was written with `org.apache.hadoop.io.Text` for both key and
  value. It holds 25 records of the form:
    key   = NNN
    value = VALUENNN
  where NNN is the 1-based, zero-padded record number (001 .. 025).
  """
  seq_path = os.path.join(resource_loader.get_data_files_path(),
                          "testdata", "string.seq")
  filenames = constant_op.constant([seq_path], dtypes.string)
  num_repeats = 2
  num_records = 25
  iterator = dataset_ops.make_initializable_iterator(
      hadoop_dataset_ops.SequenceFileDataset(filenames).repeat(num_repeats))
  init_op = iterator.initializer
  get_next = iterator.get_next()
  with self.cached_session() as sess:
    sess.run(init_op)
    # Each repetition replays the same records in file order.
    for _ in range(num_repeats):
      for record_no in range(1, num_records + 1):
        expected_key = b"%03d" % record_no
        expected_value = b"VALUE%03d" % record_no
        self.assertEqual((expected_key, expected_value), sess.run(get_next))
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
开发者ID:Albert-Z-Guo,项目名称:tensorflow,代码行数:30,代码来源:hadoop_test.py
示例4: testSkipEagerSplitPipelineFailsWithPlacementError
def testSkipEagerSplitPipelineFailsWithPlacementError(self):
  """Checks that a pipeline reading variables on two devices fails.

  Two `map` stages read resource variables created under "/cpu:0" and
  "/cpu:1". Initializing the iterator is expected to succeed. Fetching the
  first element is expected to raise `FailedPreconditionError`, with a
  message about reading resource variable "Variable".

  NOTE(review): the "SkipEager" prefix presumably marks this as a
  graph-mode-only test (it builds an explicit `session.Session`); confirm
  against the test harness.
  """
  with session.Session(
      target="",
      # Two CPU devices so that both "/cpu:0" and "/cpu:1" exist.
      config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
    dataset = dataset_ops.Dataset.from_tensors(0)
    # Define a pipeline that attempts to use variables on two
    # different devices.
    #
    # Initialize the variables before creating the iterator, to avoid the
    # placement algorithm overriding the DT_RESOURCE colocation constraints.
    with ops.device("/cpu:0"):
      var_0 = resource_variable_ops.ResourceVariable(initial_value=0)
    dataset = dataset.map(lambda x: x + var_0.read_value())
    sess.run(var_0.initializer)
    with ops.device("/cpu:1"):
      var_1 = resource_variable_ops.ResourceVariable(initial_value=0)
    dataset = dataset.map(lambda x: x + var_1.read_value())
    sess.run(var_1.initializer)
    iterator = dataset_ops.make_initializable_iterator(dataset)
    # Initialization runs outside the assertion and is expected to succeed;
    # the error is expected only when an element is requested.
    sess.run(iterator.initializer)
    with self.assertRaisesRegexp(
        errors.FailedPreconditionError,
        "Error while reading resource variable Variable"):
      sess.run(iterator.get_next())
开发者ID:aeverall,项目名称:tensorflow,代码行数:29,代码来源:from_tensors_test.py
示例5: getNext
def getNext(self, dataset, requires_initialization=False):
  """Builds a zero-argument callable that yields elements of `dataset`.

  Usable under both eager and graph execution:

  ```python
  dataset = ...
  get_next = self.getNext(dataset)
  result = self.evaluate(get_next())
  ```

  Args:
    dataset: the dataset whose elements will be returned.
    requires_initialization: graph mode only. When True, an initializable
      iterator is built and its initializer is evaluated immediately (e.g.
      when the dataset contains stateful nodes). Otherwise a one-shot
      iterator is used. Not consulted under eager execution. Defaults to
      False.

  Returns:
    A callable taking no arguments. In eager mode it is the eager iterator's
    `_next_internal` method. In graph mode every call returns the same
    `get_next` tensor.
  """
  if context.executing_eagerly():
    return dataset.__iter__()._next_internal  # pylint: disable=protected-access

  if requires_initialization:
    graph_iterator = dataset_ops.make_initializable_iterator(dataset)
    self.evaluate(graph_iterator.initializer)
  else:
    graph_iterator = dataset_ops.make_one_shot_iterator(dataset)
  next_tensor = graph_iterator.get_next()
  return lambda: next_tensor
开发者ID:aeverall,项目名称:tensorflow,代码行数:30,代码来源:test_base.py
示例6: test_assert_element_shape
def test_assert_element_shape(self):
  """Asserts fully-known element shapes and checks that all 5 elements pass.

  The asserted shapes equal the static ones, so the static output shapes are
  expected to be unchanged after `assert_element_shape`.
  """

  def create_dataset(_):
    return (array_ops.ones(2, dtype=dtypes.float32),
            array_ops.zeros((3, 4), dtype=dtypes.int32))

  dataset = dataset_ops.Dataset.range(5).map(create_dataset)
  expected_shapes = (tensor_shape.TensorShape(2),
                     tensor_shape.TensorShape((3, 4)))
  self.assertEqual(expected_shapes,
                   dataset_ops.get_legacy_output_shapes(dataset))

  checked = dataset.apply(batching.assert_element_shape(expected_shapes))
  self.assertEqual(expected_shapes,
                   dataset_ops.get_legacy_output_shapes(checked))

  iterator = dataset_ops.make_initializable_iterator(checked)
  init_op, get_next = iterator.initializer, iterator.get_next()
  with self.cached_session() as sess:
    sess.run(init_op)
    for _ in range(5):
      sess.run(get_next)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
开发者ID:Albert-Z-Guo,项目名称:tensorflow,代码行数:25,代码来源:assert_element_shape_test.py
示例7: testPrefetchBufferUtilization
def testPrefetchBufferUtilization(self, dataset_transformation):
  """Checks the `Prefetch::*` stats recorded for a prefetching pipeline.

  Element i is a vector of i copies of i, and the prefetch buffer size is -1
  (presumably tf.data's AUTOTUNE sentinel; confirm). After each element, the
  summary is expected to contain:
    * a buffer-utilization count equal to the number of elements so far;
    * capacity and size entries;
    * a utilization range within [0, 1].

  Args:
    dataset_transformation: callable `(dataset, aggregator)` that attaches
      `aggregator` to `dataset`. Presumably supplied by a parameterized-test
      decorator outside this view.
  """
  aggregator = stats_aggregator.StatsAggregator()
  dataset = dataset_ops.Dataset.range(100).map(
      lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(-1)
  dataset = dataset_transformation(dataset, aggregator)
  iterator = dataset_ops.make_initializable_iterator(dataset)
  next_element = iterator.get_next()
  summary_t = aggregator.get_summary()
  utilization_tag = "Prefetch::buffer_utilization"

  with self.cached_session():
    self.evaluate(iterator.initializer)
    for step in range(100):
      expected = np.array([step] * step, dtype=np.int64)
      self.assertAllEqual(expected, self.evaluate(next_element))
      summary_str = self.evaluate(summary_t)
      self._assertSummaryHasCount(summary_str, utilization_tag,
                                  float(step + 1))
      self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
      self._assertSummaryContains(summary_str, "Prefetch::buffer_size")
      self._assertSummaryHasRange(summary_str, utilization_tag, 0, 1)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
    self._assertSummaryHasCount(
        self.evaluate(summary_t), utilization_tag, 100)
开发者ID:aeverall,项目名称:tensorflow,代码行数:26,代码来源:stats_dataset_ops_test.py
示例8: _testNumThreadsHelper
def _testNumThreadsHelper(self, num_threads, override_threadpool_fn):
  """Runs a parallel `py_func` map and checks how many threads executed it.

  Each input is replaced by the identifier of the Python thread that ran the
  `py_func`, and `unique()` then drops repeats. The iterator therefore yields
  one value per distinct thread.

  Args:
    num_threads: upper bound on the number of distinct thread ids; a falsy
      value skips the bound check.
    override_threadpool_fn: callable taking and returning a dataset.
      Presumably it routes execution onto a custom thread pool; confirm at
      call sites.
  """

  def get_thread_id(_):
    # Python creates a dummy thread object to represent the current
    # thread when called from an "alien" thread (such as a
    # `PrivateThreadPool` thread in this case). It does not include
    # the TensorFlow-given display name, but it has a unique
    # identifier that maps one-to-one with the underlying OS thread.
    return np.array(threading.current_thread().ident).astype(np.int64)

  dataset = dataset_ops.Dataset.range(1000).map(
      lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
      num_parallel_calls=32).apply(unique.unique())
  dataset = override_threadpool_fn(dataset)
  iterator = dataset_ops.make_initializable_iterator(dataset)
  next_element = iterator.get_next()

  self.evaluate(iterator.initializer)
  thread_ids = []
  # Drain the iterator; end of input is signalled by OutOfRangeError.
  while True:
    try:
      thread_ids.append(self.evaluate(next_element))
    except errors.OutOfRangeError:
      break

  self.assertLen(thread_ids, len(set(thread_ids)))
  self.assertNotEmpty(thread_ids)
  if num_threads:
    # NOTE(mrry): We don't control the thread pool scheduling, and
    # so cannot guarantee that all of the threads in the pool will
    # perform work.
    self.assertLessEqual(len(thread_ids), num_threads)
开发者ID:aeverall,项目名称:tensorflow,代码行数:32,代码来源:override_threadpool_test.py
示例9: testSlideDatasetInvalid
def testSlideDatasetInvalid(self, count, window_size, window_shift,
                            window_stride):
  """Checks that an invalid sliding-window configuration fails at init.

  All four arguments are fed through scalar int64 placeholders. Running the
  iterator initializer with them is expected to raise `InvalidArgumentError`.
  The argument values presumably come from a parameterized-test decorator
  outside this view.
  """
  count_t, window_size_t, window_shift_t, window_stride_t = [
      array_ops.placeholder(dtypes.int64, shape=[]) for _ in range(4)
  ]
  windowed = dataset_ops.Dataset.range(10).map(lambda x: x).repeat(
      count_t).apply(
          sliding.sliding_window_batch(
              window_size=window_size_t,
              window_shift=window_shift_t,
              window_stride=window_stride_t))
  init_op = dataset_ops.make_initializable_iterator(windowed).initializer
  feed = {
      count_t: count,
      window_size_t: window_size,
      window_shift_t: window_shift,
      window_stride_t: window_stride,
  }
  with self.cached_session() as sess:
    with self.assertRaises(errors.InvalidArgumentError):
      sess.run(init_op, feed_dict=feed)
开发者ID:AndreasGocht,项目名称:tensorflow,代码行数:25,代码来源:slide_dataset_op_test.py
示例10: testTFRecordDatasetFromDataset
def testTFRecordDatasetFromDataset(self):
  """Streams TFRecord files whose paths are supplied as a dataset.

  Writes `_NUM_FILES` TFRecord files of `_NUM_ENTRIES` records each. A
  `StreamingFilesDataset` is built over a `Dataset` of their paths, and the
  test checks that the set of values read equals the set of records written.
  It reads `4 * len(all_contents)` values, which presumably relies on the
  streaming dataset repeating its input (otherwise OutOfRangeError would be
  raised); confirm.
  """
  filenames = []
  all_contents = []
  for i in range(_NUM_FILES):
    filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i)
    filenames.append(filename)
    # The context manager closes the writer even if a write raises, so the
    # file handle is not leaked.
    with python_io.TFRecordWriter(filename) as writer:
      for j in range(_NUM_ENTRIES):
        record = compat.as_bytes('Record %d of file %d' % (j, i))
        writer.write(record)
        all_contents.append(record)

  filename_dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
  dataset = datasets.StreamingFilesDataset(filename_dataset,
                                           filetype='tfrecord')
  with ops.device(self._worker_device):
    iterator = dataset_ops.make_initializable_iterator(dataset)
  self._sess.run(iterator.initializer)
  get_next = iterator.get_next()

  retrieved_values = []
  for _ in range(4 * len(all_contents)):
    retrieved_values.append(compat.as_bytes(self._sess.run(get_next)))
  self.assertEqual(set(all_contents), set(retrieved_values))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:27,代码来源:datasets_test.py
示例11: testArbitraryReaderFunc
def testArbitraryReaderFunc(self):
  """Streams fixed-length-record files through a user-supplied reader.

  `filetype` is given a callable instead of a format name. The callable wraps
  each matched file in a `FixedLengthRecordDataset` whose record size is the
  byte length of one formatted record.
  """

  def MakeRecord(i, j):
    return compat.as_bytes('%04d-%04d' % (i, j))

  # '%04d-%04d' yields a fixed width while i, j < 10000, so one sample
  # record gives the size of every record.
  record_bytes = len(MakeRecord(10, 200))
  all_contents = []
  for i in range(_NUM_FILES):
    path = os.path.join(self.get_temp_dir(), 'fixed_length.%d' % i)
    records = [MakeRecord(i, j) for j in range(_NUM_ENTRIES)]
    with open(path, 'wb') as f:
      for record in records:
        f.write(record)
    all_contents.extend(records)

  def FixedLengthFile(filename):
    return readers.FixedLengthRecordDataset(filename, record_bytes)

  dataset = datasets.StreamingFilesDataset(
      os.path.join(self.get_temp_dir(), 'fixed_length*'),
      filetype=FixedLengthFile)
  with ops.device(self._worker_device):
    iterator = dataset_ops.make_initializable_iterator(dataset)
  self._sess.run(iterator.initializer)
  get_next = iterator.get_next()

  retrieved_values = [
      compat.as_bytes(self._sess.run(get_next))
      for _ in range(4 * len(all_contents))
  ]
  self.assertEqual(set(all_contents), set(retrieved_values))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:33,代码来源:datasets_test.py
示例12: _benchmarkRangeHelper
def _benchmarkRangeHelper(self, modeling_enabled):
  """Benchmarks iterating `Dataset.range` with autotune modeling on or off.

  Args:
    modeling_enabled: value assigned to `options.experimental_autotune`. It
      also selects the element count (10M when truthy, 50M otherwise) and
      the reported benchmark name.
  """
  if modeling_enabled:
    num_elements = 10000000
  else:
    num_elements = 50000000
  # Skip and take perform the whole traversal in C++, so the timing reflects
  # minimal per-element overhead rather than Python invocation costs.
  dataset = dataset_ops.Dataset.range(num_elements).skip(
      num_elements - 1).take(1)
  options = dataset_ops.Options()
  options.experimental_autotune = modeling_enabled
  options.experimental_optimization.apply_default_optimizations = False
  iterator = dataset_ops.make_initializable_iterator(
      dataset.with_options(options))
  next_element = iterator.get_next()

  with session.Session() as sess:
    # Warm-up pass so session caches are populated before timing.
    sess.run(iterator.initializer)
    sess.run(next_element)

    # Timed pass.
    sess.run(iterator.initializer)
    start = time.time()
    sess.run(next_element)
    elapsed = time.time() - start

    self.report_benchmark(
        iters=num_elements,
        wall_time=elapsed / num_elements,
        name="modeling_%s" % ("on" if modeling_enabled else "off"))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:31,代码来源:range_benchmark.py
示例13: test_assert_wrong_partial_element_shape_on_unknown_shape_dataset
def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
  """Checks that a wrong partial shape asserted on unknown shapes fails.

  The elements come from a `py_func`, so their static shapes are unknown.
  At runtime the second component is (3, 4), but it is asserted as
  (None, 10). Initialization is expected to succeed, and fetching the first
  element is expected to raise `InvalidArgumentError`.
  """

  def create_unknown_shape_dataset(x):
    return script_ops.py_func(
        lambda _: (  # pylint: disable=g-long-lambda
            np.ones(2, dtype=np.float32),
            np.zeros((3, 4), dtype=np.int32)),
        [x],
        [dtypes.float32, dtypes.int32])

  dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
  unknown_shapes = (tensor_shape.TensorShape(None),) * 2
  self.assertEqual(unknown_shapes, dataset.output_shapes)

  wrong_shapes = (tensor_shape.TensorShape(2),
                  tensor_shape.TensorShape((None, 10)))
  checked = dataset.apply(batching.assert_element_shape(wrong_shapes))
  iterator = dataset_ops.make_initializable_iterator(checked)
  init_op, get_next = iterator.initializer, iterator.get_next()
  with self.cached_session() as sess:
    sess.run(init_op)
    with self.assertRaises(errors.InvalidArgumentError):
      sess.run(get_next)
开发者ID:xxg1413,项目名称:Tensorflow,代码行数:25,代码来源:assert_element_shape_test.py
示例14: test_assert_partial_element_shape
def test_assert_partial_element_shape(self):
  """Checks that asserting partial shapes keeps the fully-known static ones.

  The asserted shapes are (unknown, (None, 4)). Merged with the actual (2,)
  and (3, 4) shapes, the result is expected to report the actual shapes, and
  all 5 elements are expected to be produced.
  """

  def create_dataset(_):
    return (array_ops.ones(2, dtype=dtypes.float32),
            array_ops.zeros((3, 4), dtype=dtypes.int32))

  dataset = dataset_ops.Dataset.range(5).map(create_dataset)
  partial_expected_shape = (
      tensor_shape.TensorShape(None),  # Unknown shape
      tensor_shape.TensorShape((None, 4)))  # Partial shape
  merged = dataset.apply(
      batching.assert_element_shape(partial_expected_shape))
  # Partial shapes are merged with actual shapes:
  actual_shapes = (tensor_shape.TensorShape(2),
                   tensor_shape.TensorShape((3, 4)))
  self.assertEqual(actual_shapes, merged.output_shapes)

  iterator = dataset_ops.make_initializable_iterator(merged)
  init_op, get_next = iterator.initializer, iterator.get_next()
  with self.cached_session() as sess:
    sess.run(init_op)
    for _ in range(5):
      sess.run(get_next)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
开发者ID:xxg1413,项目名称:Tensorflow,代码行数:26,代码来源:assert_element_shape_test.py
示例15: testMapAndBatchSparse
def testMapAndBatchSparse(self, numa_aware):
  """Checks `map_and_batch` over a function that returns sparse values.

  Input i maps to a 1-entry sparse value holding i. Each batch of 5 is
  expected to be a [5, 1] sparse tensor holding 5 consecutive inputs.

  Args:
    numa_aware: when truthy, `experimental_numa_aware` is enabled in the
      dataset options.
  """

  def _sparse(i):
    return sparse_tensor.SparseTensorValue(
        indices=[[0]], values=(i * [1]), dense_shape=[1])

  dataset = dataset_ops.Dataset.range(10).apply(
      batching.map_and_batch(_sparse, 5))
  if numa_aware:
    options = dataset_ops.Options()
    options.experimental_numa_aware = True
    dataset = dataset.with_options(options)
  iterator = dataset_ops.make_initializable_iterator(dataset)
  init_op = iterator.initializer
  get_next = iterator.get_next()

  expected_indices = [[row, 0] for row in range(5)]
  with self.cached_session():
    self.evaluate(init_op)
    for batch in range(2):
      actual = self.evaluate(get_next)
      first = batch * 5
      expected = sparse_tensor.SparseTensorValue(
          indices=expected_indices,
          values=list(range(first, first + 5)),
          dense_shape=[5, 1])
      self.assertTrue(sparse_tensor.is_sparse(actual))
      self.assertSparseValuesEqual(actual, expected)
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(get_next)
开发者ID:aeverall,项目名称:tensorflow,代码行数:29,代码来源:map_and_batch_test.py
示例16: testSlideSparse
def testSlideSparse(self):
  """Checks sliding windows (size 5, shift 3) over 1-entry sparse elements."""

  def _sparse(i):
    return sparse_tensor.SparseTensorValue(
        indices=[[0]], values=(i * [1]), dense_shape=[1])

  num_elements, window_size, window_shift = 10, 5, 3
  iterator = dataset_ops.make_initializable_iterator(
      dataset_ops.Dataset.range(num_elements).map(_sparse).apply(
          sliding.sliding_window_batch(
              window_size=window_size, window_shift=window_shift)))
  init_op = iterator.initializer
  get_next = iterator.get_next()

  expected_indices = [[row, 0] for row in range(window_size)]
  with self.cached_session() as sess:
    sess.run(init_op)
    # Number of windows that fit entirely inside the input.
    num_batches = (num_elements - window_size) // window_shift + 1
    for batch in range(num_batches):
      actual = sess.run(get_next)
      start = batch * window_shift
      expected = sparse_tensor.SparseTensorValue(
          indices=expected_indices,
          values=list(range(start, start + window_size)),
          dense_shape=[window_size, 1])
      self.assertTrue(sparse_tensor.is_sparse(actual))
      self.assertSparseValuesEqual(actual, expected)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
开发者ID:AndreasGocht,项目名称:tensorflow,代码行数:25,代码来源:slide_dataset_op_test.py
示例17: testMapAndBatchShapeMismatch
def testMapAndBatchShapeMismatch(self, numa_aware):
  """Checks that `map_and_batch` rejects elements of mismatched shapes.

  The generator yields three 1-element vectors followed by a 1x3 nested
  list. The map function is the identity, so the first batch of 4 mixes
  elements with different element counts. Fetching it is expected to raise
  `InvalidArgumentError` mentioning "number of elements does not match".

  Args:
    numa_aware: when truthy, `experimental_numa_aware` is enabled in the
      dataset options.
  """

  def generator():
    yield [1]
    yield [2]
    yield [3]
    yield [[4, 5, 6]]

  dataset = dataset_ops.Dataset.from_generator(
      generator, output_types=dtypes.int32)
  batch_size = 4
  dataset = dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
  if numa_aware:
    options = dataset_ops.Options()
    options.experimental_numa_aware = True
    dataset = dataset.with_options(options)
  iterator = dataset_ops.make_initializable_iterator(dataset)
  init_op = iterator.initializer
  get_next = iterator.get_next()
  # The session object is not bound; evaluation goes through self.evaluate.
  with self.cached_session():
    self.evaluate(init_op)
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 "number of elements does not match"):
      self.evaluate(get_next)
开发者ID:aeverall,项目名称:tensorflow,代码行数:26,代码来源:map_and_batch_test.py
示例18: testSlideSparseWithDifferentDenseShapes
def testSlideSparseWithDifferentDenseShapes(self):
  """Checks sliding windows over sparse elements of varying dense shape.

  Element i is a 1-D sparse value with dense shape [i] whose i entries all
  equal i. Each window of 5 elements (shift 3) is expected to batch into a
  [5, i * 3 + 4] sparse tensor, the dense width of the window's largest
  element.
  """

  def _sparse(i):
    return sparse_tensor.SparseTensorValue(
        indices=array_ops.expand_dims(
            math_ops.range(i, dtype=dtypes.int64), 1),
        # `math_ops.cast` replaces the deprecated `math_ops.to_int32`.
        values=array_ops.fill([math_ops.cast(i, dtypes.int32)], i),
        dense_shape=[i])

  iterator = dataset_ops.make_initializable_iterator(
      dataset_ops.Dataset.range(10).map(_sparse).apply(
          sliding.sliding_window_batch(window_size=5, window_shift=3)))
  init_op = iterator.initializer
  get_next = iterator.get_next()
  with self.cached_session() as sess:
    sess.run(init_op)
    num_batches = (10 - 5) // 3 + 1
    for i in range(num_batches):
      actual = sess.run(get_next)
      expected_indices = []
      expected_values = []
      # Row j of window i holds element i * 3 + j, which has that many
      # entries, all equal to i * 3 + j.
      for j in range(5):
        for k in range(i * 3 + j):
          expected_indices.append([j, k])
          expected_values.append(i * 3 + j)
      expected = sparse_tensor.SparseTensorValue(
          indices=expected_indices,
          values=expected_values,
          dense_shape=[5, i * 3 + 5 - 1])
      self.assertTrue(sparse_tensor.is_sparse(actual))
      self.assertSparseValuesEqual(actual, expected)
    with self.assertRaises(errors.OutOfRangeError):
      sess.run(get_next)
开发者ID:AndreasGocht,项目名称:tensorflow,代码行数:34,代码来源:slide_dataset_op_test.py
示例19: testFilteredElementsStats
def testFilteredElementsStats(self, dataset_transformation):
  """Checks the `Filter::*` scalar stats reported for a filter transform.

  `range(101)` is filtered down to its 34 multiples of 3 (0, 3, ..., 99).
  After kept element i, 2 * i inputs have been dropped and i + 1 kept. At
  the end, 67 inputs have been dropped and 34 kept.

  Args:
    dataset_transformation: callable `(dataset, aggregator)` that attaches
      `aggregator` to `dataset`. Presumably supplied by a parameterized-test
      decorator outside this view.
  """
  aggregator = stats_aggregator.StatsAggregator()
  dataset = dataset_ops.Dataset.range(101).filter(
      lambda x: math_ops.equal(math_ops.mod(x, 3), 0))
  dataset = dataset_transformation(dataset, aggregator)
  iterator = dataset_ops.make_initializable_iterator(dataset)
  next_element = iterator.get_next()
  summary_t = aggregator.get_summary()
  # `test_session` is deprecated; `cached_session` is its replacement.
  with self.cached_session():
    self.evaluate(iterator.initializer)
    for i in range(34):
      self.assertEqual(i * 3, self.evaluate(next_element))
      # No input has been dropped when element 0 is produced, so the
      # dropped-elements stat is only checked from i == 1 on. Compare by
      # value: `is not 0` tests int identity and warns on Python 3.8+.
      if i != 0:
        self._assertSummaryHasScalarValue(
            self.evaluate(summary_t), "Filter::dropped_elements",
            float(i * 2))
      self._assertSummaryHasScalarValue(
          self.evaluate(summary_t), "Filter::filtered_elements", float(i + 1))
    with self.assertRaises(errors.OutOfRangeError):
      self.evaluate(next_element)
    self._assertSummaryHasScalarValue(
        self.evaluate(summary_t), "Filter::dropped_elements", 67.0)
    self._assertSummaryHasScalarValue(
        self.evaluate(summary_t), "Filter::filtered_elements", 34.0)
开发者ID:aeverall,项目名称:tensorflow,代码行数:25,代码来源:stats_dataset_ops_test.py
示例20: benchmarkOldUnbatchImplementation
def benchmarkOldUnbatchImplementation(self):
  """Benchmarks unbatching via `batch` followed by `flat_map`.

  An endless stream of the string "element" is batched with a fed batch
  size, then flattened back to single elements with
  `flat_map(from_tensor_slices)`. Skipping `elems_per_trial` elements gives
  the per-element cost. For each batch size, the median of 5 trials is
  reported as "unfused_batch_size_<n>".
  """
  batch_sizes = [1, 2, 5, 10, 20, 50]
  elems_per_trial = 10000
  with ops.Graph().as_default():
    dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
    # The batch size is fed when the iterator is initialized, so one graph
    # serves every configuration.
    batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
    dataset = dataset.batch(batch_size_placeholder)
    dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
    dataset = dataset.skip(elems_per_trial)
    options = dataset_ops.Options()
    options.experimental_optimization.apply_default_optimizations = False
    dataset = dataset.with_options(options)
    iterator = dataset_ops.make_initializable_iterator(dataset)
    next_element = iterator.get_next()

    with session.Session() as sess:
      for batch_size in batch_sizes:
        deltas = []
        for _ in range(5):
          sess.run(
              iterator.initializer,
              feed_dict={batch_size_placeholder: batch_size})
          start = time.time()
          # Running only the op avoids fetching the element; the `skip`
          # above traverses `elems_per_trial` elements first.
          sess.run(next_element.op)
          end = time.time()
          deltas.append((end - start) / elems_per_trial)

        median_wall_time = np.median(deltas)
        # Report the real per-trial element count rather than a duplicated
        # literal that could drift from `elems_per_trial`.
        self.report_benchmark(
            iters=elems_per_trial,
            wall_time=median_wall_time,
            name="unfused_batch_size_%d" % batch_size)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:33,代码来源:unbatch_benchmark.py
注：本文中的tensorflow.python.data.ops.dataset_ops.make_initializable_iterator函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论