本文整理汇总了Python中tensorflow.python.distribute.values.select_replica函数的典型用法代码示例。如果您正苦于以下问题:Python select_replica函数的具体用法?Python select_replica怎么用?Python select_replica使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了select_replica函数的16个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _test_input_fn_iterator
def _test_input_fn_iterator(self,
                            iterator,
                            devices,
                            expected_values,
                            sess=None,
                            test_reinitialize=True):
  """Checks that `iterator` yields `expected_values`, one slice per device.

  Runs one full pass over the iterator, verifies it raises
  `OutOfRangeError` when exhausted, and (optionally) verifies that
  re-initializing allows a second identical pass.
  """
  num_replicas = len(devices)

  def evaluate(tensors):
    # Graph mode uses the provided session; eager falls back to
    # the test-case evaluate helper.
    return sess.run(tensors) if sess else self.evaluate(tensors)

  def next_per_replica():
    element = iterator.get_next()
    return [values.select_replica(r, element) for r in range(num_replicas)]

  def check_one_pass():
    for expected in expected_values:
      self.assertEqual(expected, evaluate(next_per_replica()))

  evaluate(iterator.initialize())
  check_one_pass()
  # Once exhausted, the next fetch must signal end-of-input.
  with self.assertRaises(errors.OutOfRangeError):
    evaluate(next_per_replica())
  # After re-initializing the iterator, should be able to iterate again.
  if test_reinitialize:
    evaluate(iterator.initialize())
    check_one_pass()
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:30,代码来源:strategy_test_lib.py
示例2: _test_input_fn_iterator
def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn,
                            expected_values):
  """Verifies `input_fn` iteration under a multi-worker strategy.

  Builds the distribution strategy for the given task, iterates the
  input-fn iterator twice (with a re-initialization in between), and
  checks that every per-replica fetch matches `expected_values`.
  """
  distribution, master_target, config = self._get_test_object(
      task_type, task_id, num_gpus)
  devices = distribution.extended.worker_devices

  with ops.Graph().as_default(), \
       self.cached_session(config=config, target=master_target) as sess:
    iterator = distribution.make_input_fn_iterator(input_fn)

    def fetch_all_replicas():
      element = iterator.get_next()
      return sess.run([values.select_replica(r, element)
                       for r in range(len(devices))])

    sess.run(iterator.initialize())
    for expected in expected_values:
      self.assertEqual(expected, fetch_all_replicas())
    # Exhausted: the next fetch must raise OutOfRangeError.
    with self.assertRaises(errors.OutOfRangeError):
      fetch_all_replicas()
    # After re-initializing the iterator, should be able to iterate again.
    sess.run(iterator.initialize())
    for expected in expected_values:
      self.assertEqual(expected, fetch_all_replicas())
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:31,代码来源:collective_all_reduce_strategy_test.py
示例3: testWrapClass
def testWrapClass(self):
  """regroup() with a wrap class, then undo the merge via select_replica()."""
  # Normally a mirrored value would be the same across devices, but
  # for a test it is convenient to be able to tell the values apart.
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  merged = values.regroup(device_map,
                          (_nested_value("1"), _nested_value("2")),
                          values.Mirrored)

  self.assertIsInstance(merged, tuple)
  self.assertEqual(3, len(merged))
  self._is_per_replica(merged[0], ["a1", "a2"], values.Mirrored)
  self._is_per_replica(merged[2], ["h1", "h2"], values.Mirrored)

  inner = merged[1]
  self.assertIsInstance(inner, list)
  self.assertEqual(3, len(inner))
  self._is_per_replica(inner[0], ["b1", "b2"], values.Mirrored)
  self._is_per_replica(inner[2], ["g1", "g2"], values.Mirrored)

  mapping = inner[1]
  self.assertIsInstance(mapping, dict)
  self.assertEqual(set(["c", "e"]), set(mapping.keys()))
  self._is_per_replica(mapping["c"], ["d1", "d2"], values.Mirrored)
  self._is_per_replica(mapping["e"], ["f1", "f2"], values.Mirrored)

  # Also test that we can undo the merge using select_replica()
  for replica_id, tag in ((0, "1"), (1, "2")):
    self.assertEqual(_nested_value(tag),
                     values.select_replica(replica_id, merged))
  # Values are marked as mirrored, so select_device_mirrored() is allowed.
  for device_id, tag in ((0, "1"), (1, "2")):
    self.assertEqual(
        _nested_value(tag),
        values.select_device_mirrored(_device_str(device_id), merged))
开发者ID:kylin9872,项目名称:tensorflow,代码行数:32,代码来源:values_test.py
示例4: testNested
def testNested(self):
  """regroup() of nested structures and its select_replica() inverse."""
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  merged = values.regroup(device_map,
                          (_nested_value("1"), _nested_value("2")))

  self.assertIsInstance(merged, tuple)
  self.assertEqual(3, len(merged))
  self._is_per_replica(merged[0], ["a1", "a2"])
  self._is_per_replica(merged[2], ["h1", "h2"])

  inner = merged[1]
  self.assertIsInstance(inner, list)
  self.assertEqual(3, len(inner))
  self._is_per_replica(inner[0], ["b1", "b2"])
  self._is_per_replica(inner[2], ["g1", "g2"])

  mapping = inner[1]
  self.assertIsInstance(mapping, dict)
  self.assertEqual(set(["c", "e"]), set(mapping.keys()))
  self._is_per_replica(mapping["c"], ["d1", "d2"])
  self._is_per_replica(mapping["e"], ["f1", "f2"])

  # Also test that we can undo the merge using select_replica()
  for replica_id, tag in ((0, "1"), (1, "2")):
    self.assertEqual(_nested_value(tag),
                     values.select_replica(replica_id, merged))
  # select_device_mirrored() should fail due to non-mirrored values
  for device_id in (0, 1):
    with self.assertRaises(TypeError):
      values.select_device_mirrored(_device_str(device_id), merged)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:29,代码来源:values_test.py
示例5: _tpu_run
def _tpu_run(strategy, fn, args, kwargs):
  """Common implementation of TPUStrategy.experimental_run_v2.

  Args:
    strategy: the TPUStrategy whose replicas run `fn`.
    fn: user function replicated across all replicas.
    args: positional arguments for `fn`; may contain per-replica values
      (each replica gets its slice via `values.select_replica`).
    kwargs: keyword arguments for `fn`, or None.

  Returns:
    Per-replica outputs of `fn` regrouped into distributed values.

  Raises:
    NotImplementedError: if called eagerly outside a tf.function.
  """
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(strategy, replica_id_in_sync_group=replica_id):
      # Mutates `result` via closure so the packing below can recover the
      # structure of the user function's output after tracing.
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(strategy.num_replicas_in_sync):
    # Each replica receives its id plus its slice of args/kwargs.
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         values.select_replica(i, args),
         values.select_replica(i, kwargs)])

  # Construct and pass `maximum_shapes` so that we could support dynamic
  # shapes using dynamic padder.
  # NOTE(review): shapes are taken from replica 0 only — assumes all
  # replicas share the same input structure; confirm against callers.
  if replicate_inputs:
    maximum_shapes = []
    flattened_list = nest.flatten(replicate_inputs[0])
    for input_tensor in flattened_list:
      maximum_shapes.append(input_tensor.get_shape())
    maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                           maximum_shapes)
  else:
    maximum_shapes = None

  with strategy.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs,
                                      maximum_shapes=maximum_shapes)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_output))
      for replica_output in replicate_outputs
  ]

  device_map = strategy.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
开发者ID:aritratony,项目名称:tensorflow,代码行数:56,代码来源:tpu_strategy.py
示例6: testNamedTupleEstimatorSpec
def testNamedTupleEstimatorSpec(self):
  """regroup() preserves the EstimatorSpec namedtuple type and fields."""
  with context.graph_mode(), ops.Graph().as_default():
    devices = []
    created_estimator_specs = []
    for device_id in range(3):
      devices.append(_device_str(device_id))
      created_estimator_specs.append(
          model_fn_lib.EstimatorSpec(
              mode=model_fn_lib.ModeKeys.TRAIN,
              loss=constant_op.constant(device_id / 2),
              train_op=array_ops.identity(constant_op.constant(device_id))))

    device_map = values.ReplicaDeviceMap(devices)
    merged = values.regroup(device_map, created_estimator_specs)

    # The merged value must still be an EstimatorSpec, not a plain tuple.
    self.assertIsInstance(merged, model_fn_lib.EstimatorSpec)
    self.assertEqual(model_fn_lib.ModeKeys.TRAIN, merged.mode)
    for device_id, spec in enumerate(created_estimator_specs):
      d = _device_str(device_id)
      self.assertEqual(spec.loss, merged.loss.get(d))
      self.assertEqual(spec.train_op, merged.train_op.get(d))
      # Scaffold is populated by `EstimatorSpec.__new__`.
      self.assertEqual(spec.scaffold, merged.scaffold.get(d))
      # Also test that we can undo the merge using select_replica()
      self.assertEqual(spec, values.select_replica(device_id, merged))
开发者ID:kylin9872,项目名称:tensorflow,代码行数:33,代码来源:values_test.py
示例7: _test_iterator
def _test_iterator(self, sess, iterator, devices, expected_values):
  """Checks device placement and yielded values of `iterator`."""
  next_element = iterator.get_next()

  # Each replica's slice must live on its corresponding device.
  for r, device in enumerate(devices):
    per_replica = values.select_replica(r, next_element)
    # The per-replica value here can be a tuple.
    for element in nest.flatten(per_replica):
      self.assertTrue(element.device in device)

  # The fetch list is built once; `next_element` is a fixed graph tensor,
  # so re-running the same fetches advances the iterator each time.
  fetches = [values.select_replica(r, next_element)
             for r in range(len(devices))]
  for expected_value in expected_values:
    self.assertEqual(expected_value, sess.run(fetches))
  # Once exhausted the iterator must raise OutOfRangeError.
  with self.assertRaises(errors.OutOfRangeError):
    sess.run(fetches)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:16,代码来源:input_lib_test.py
示例8: experimental_run_v2
def experimental_run_v2(self, fn, args=(), kwargs=None):
  """See base class.

  Replicates `fn` across all TPU replicas via `tpu.replicate`, feeding
  each replica its id and its slice of `args`/`kwargs`, then regroups the
  per-replica outputs into distributed values.
  """
  # Replication is only supported while tracing a TF function; plain
  # eager dispatch is explicitly rejected.
  if context.executing_eagerly() and not ops.inside_function():
    raise NotImplementedError(
        "Eager mode not supported in TPUStrategy outside TF functions.")

  if kwargs is None:
    kwargs = {}

  # Used to re-structure flattened output tensors from `tpu.replicate()`
  # into a structured format.
  result = [[]]

  def replicated_fn(replica_id, replica_args, replica_kwargs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      # Mutates `result` via closure; the packing below relies on the
      # structure captured here during tracing.
      result[0] = fn(*replica_args, **replica_kwargs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(self.num_replicas_in_sync):
    # Each replica receives its id plus its slice of args/kwargs.
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         values.select_replica(i, args),
         values.select_replica(i, kwargs)])

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Remove all no ops that may have been added during 'tpu.replicate()'
  if isinstance(result[0], list):
    result[0] = [
        output for output in result[0] if tensor_util.is_tensor(output)
    ]

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_output))
      for replica_output in replicate_outputs
  ]

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
开发者ID:perfmjs,项目名称:tensorflow,代码行数:43,代码来源:tpu_strategy.py
示例9: testSameId
def testSameId(self):
  """regroup() passes through a value that is identical across replicas."""
  foo = object()
  device_map = values.ReplicaDeviceMap((_device_str(0), _device_str(1)))
  merged = values.regroup(device_map, (("a", foo), ("b", foo)))

  self.assertIsInstance(merged, tuple)
  self.assertEqual(2, len(merged))
  self._is_per_replica(merged[0], ["a", "b"])
  # The exact same object on every replica is kept unwrapped.
  self.assertIs(foo, merged[1])

  # Test select_replica(), should undo the merge done by regroup().
  for replica_id, expected_first in ((0, "a"), (1, "b")):
    selected = values.select_replica(replica_id, merged)
    self.assertIsInstance(selected, tuple)
    self.assertEqual(2, len(selected))
    self.assertEqual(expected_first, selected[0])
    self.assertIs(foo, selected[1])
开发者ID:kylin9872,项目名称:tensorflow,代码行数:20,代码来源:values_test.py
示例10: _test_iterator
def _test_iterator(self, input_type, dataset_fn, worker_device_pairs,
                   expected_values, sess=None, split_batch_by=None):
  """Builds a distributed input iterator and checks two full passes.

  The iterator is constructed either from an input_fn or directly from a
  dataset depending on `input_type`, iterated until exhaustion, checked
  for OutOfRangeError, then re-initialized and iterated again.
  """
  devices = nest.flatten([ds for _, ds in worker_device_pairs])
  device_map = values.ReplicaDeviceMap(devices)
  input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

  if input_type == "input_fn":
    input_contexts = [
        distribute_lib.InputContext() for _ in worker_device_pairs]
    input_fn = lambda _: dataset_fn()
    iterator = input_lib.InputFunctionIterator(
        input_fn, input_workers, input_contexts)
  else:
    iterator = input_lib.DatasetIterator(
        dataset_fn(), input_workers, split_batch_by)

  def evaluate(tensors):
    return sess.run(tensors) if sess else self.evaluate(tensors)

  def fetch_next():
    element = iterator.get_next()
    return evaluate(
        [values.select_replica(r, element) for r in range(len(devices))])

  evaluate(control_flow_ops.group(iterator.initialize()))
  for expected_value in expected_values:
    self.assertAllEqual(expected_value, fetch_next())
  with self.assertRaises(errors.OutOfRangeError):
    fetch_next()
  # After re-initializing the iterator, should be able to iterate again.
  evaluate(control_flow_ops.group(iterator.initialize()))
  for expected_value in expected_values:
    self.assertAllEqual(expected_value, fetch_next())
开发者ID:AndreasGocht,项目名称:tensorflow,代码行数:39,代码来源:input_lib_test.py
示例11: _test_iterator
def _test_iterator(self,
                   input_type,
                   dataset_fn,
                   worker_device_pairs,
                   expected_values,
                   sess=None,
                   split_batch_by=None,
                   enable_get_next_as_optional=False):
  """Iterates twice and compares elements one-by-one with expected values."""
  devices = nest.flatten([ds for _, ds in worker_device_pairs])
  iterator = self._create_iterator(
      input_type, dataset_fn, worker_device_pairs, devices, split_batch_by,
      enable_get_next_as_optional)

  def evaluate(tensors):
    return sess.run(tensors) if sess else self.evaluate(tensors)

  def fetch_next():
    element = iterator.get_next()
    return evaluate(
        [values.select_replica(r, element) for r in range(len(devices))])

  def check_one_pass():
    for expected_value in expected_values:
      computed_value = fetch_next()
      self.assertEqual(len(expected_value), len(computed_value))
      for expected, computed in zip(expected_value, computed_value):
        self.assertAllEqual(expected, computed)

  evaluate(control_flow_ops.group(iterator.initialize()))
  check_one_pass()
  # Once exhausted, the next fetch must signal end-of-input.
  with self.assertRaises(errors.OutOfRangeError):
    fetch_next()
  # After re-initializing the iterator, should be able to iterate again.
  evaluate(control_flow_ops.group(iterator.initialize()))
  check_one_pass()
开发者ID:VonChenPlus,项目名称:tensorflow,代码行数:39,代码来源:input_lib_test.py
示例12: testOneDevice
def testOneDevice(self):
  """With a single replica, regroup()/select_replica() act as identity."""
  single_map = values.ReplicaDeviceMap((_device_str(0),))
  regrouped = values.regroup(single_map, (_nested_value("1"),))
  # On one device regroup() and select_replica() are basically identity.
  self.assertEqual(_nested_value("1"), regrouped)
  self.assertEqual(_nested_value("1"),
                   values.select_replica(0, regrouped))

  # The one exception has to do with MirroredVariables.
  d = "/device:CPU:0"
  with ops.device(d):
    v = variable_scope.get_variable(
        name="v", initializer=1., use_resource=True)
  device_map = values.ReplicaDeviceMap((d,))
  mirrored = values.MirroredVariable(None, device_map, (v,),
                                     variable_scope.VariableAggregation.SUM)
  # regroup() returns the existing MirroredVariable rather than wrapping.
  self.assertIs(mirrored, values.regroup(device_map, (v,)))
开发者ID:kylin9872,项目名称:tensorflow,代码行数:18,代码来源:values_test.py
示例13: experimental_run
def experimental_run(self, fn, input_iterator=None):
  """See base class.

  Replicates `fn` across all TPU replicas, optionally feeding each
  replica its slice of the next element from `input_iterator`.
  """
  if context.executing_eagerly():
    raise NotImplementedError("Eager mode not supported in TPUStrategy.")

  if self.extended._disable_training_loop_on_host:  # pylint: disable=protected-access
    raise NotImplementedError(
        "`experimental_run` is not compatible with "
        "`_disable_training_loop_on_host=True`")

  if input_iterator is None:
    inputs = []
  else:
    inputs = input_iterator.get_next()

  # Holds the (structured) output of `fn`; filled in by `replicated_fn`
  # via closure and used below to re-pack the flattened replicate outputs.
  result = [None]

  def replicated_fn(replica_id, inputs):
    """Wraps user function to provide replica ID and `Tensor` inputs."""
    with _TPUReplicaContext(self, replica_id_in_sync_group=replica_id):
      if input_iterator is None:
        result[0] = fn()
      else:
        result[0] = fn(inputs)
    return result[0]

  replicate_inputs = []  # By replica.
  for i in range(self.num_replicas_in_sync):
    # Each replica receives its id plus its slice of the iterator element.
    replicate_inputs.append(
        [constant_op.constant(i, dtype=dtypes.int32),
         values.select_replica(i, inputs)])

  with self.scope():
    replicate_outputs = tpu.replicate(replicated_fn, replicate_inputs)

  # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
  replicate_outputs = [
      nest.pack_sequence_as(result[0], nest.flatten(replica_outputs))
      for replica_outputs in replicate_outputs]

  device_map = self.extended._device_map  # pylint: disable=protected-access
  return values.regroup(device_map, replicate_outputs)
开发者ID:AndreasGocht,项目名称:tensorflow,代码行数:41,代码来源:tpu_strategy.py
示例14: rewrite_fn
def rewrite_fn(*args):
  """The rewritten step fn running on TPU.

  Closure: reads `multi_worker_iterator`, `run_fn`, and `self` from the
  enclosing scope (not visible in this excerpt).
  """
  # `args` is ignored — inputs come from the multi-worker iterator below.
  del args
  per_replica_inputs = multi_worker_iterator.get_next()
  replicate_inputs = []
  for replica_id in range(self._num_replicas_in_sync):
    # Bind `replica_id` per iteration; the lambda is consumed immediately
    # by map_structure, hence the cell-var suppression.
    select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
    replicate_inputs.append((nest.map_structure(
        select_replica, per_replica_inputs),))

  replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

  # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
  # will flatten it in this case. If run_fn has no tensor outputs,
  # tpu.replicate returns a list of no_ops, we will keep the output as it
  # is.
  if isinstance(replicate_outputs[0], list):
    replicate_outputs = nest.flatten(replicate_outputs)

  return replicate_outputs
开发者ID:jackd,项目名称:tensorflow,代码行数:21,代码来源:tpu_strategy.py
示例15: _call_for_each_replica
def _call_for_each_replica(distribution, device_map, fn, args, kwargs):
"""Run `fn` in separate threads, once per replica/worker device.
Args:
distribution: the DistributionStrategy object.
device_map: the DeviceMap with the devices to run `fn` on.
fn: function to run (will be run once per replica, each in its own thread).
args: positional arguments for `fn`
kwargs: keyword arguments for `fn`.
Returns:
Merged return value of `fn` across all replicas.
Raises:
RuntimeError: If fn() calls get_replica_context().merge_call() a different
number of times from the available devices.
"""
# TODO(josh11b): Add this option once we add synchronization to variable
# creation. Until then, this is pretty unsafe to use.
run_concurrently = False
if not context.executing_eagerly():
# Needed for per-thread device, etc. contexts in graph mode.
ops.get_default_graph().switch_to_thread_local()
coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
shared_variable_store = {}
# TODO(isaprykin): Create these threads once instead of during every call.
threads = []
for index in range(device_map.num_replicas_in_graph):
variable_creator_fn = shared_variable_creator.make_fn(
shared_variable_store, index)
t = _MirroredReplicaThread(
distribution, coord, index, device_map, variable_creator_fn, fn,
values.select_replica(index, args),
values.select_replica(index, kwargs))
threads.append(t)
for t in threads:
t.start()
# When `fn` starts `should_run` event is set on _MirroredReplicaThread
# (`MRT`) threads. The execution waits until
# `MRT.has_paused` is set, which indicates that either `fn` is
# complete or a `get_replica_context().merge_call()` is called. If `fn` is
# complete, then `MRT.done` is set to True. Otherwise, arguments
# of `get_replica_context().merge_call` from all paused threads are grouped
# and the `merge_fn` is performed. Results of the
# `get_replica_context().merge_call` are then set to `MRT.merge_result`.
# Each such `get_replica_context().merge_call` call returns the
# `MRT.merge_result` for that thread when `MRT.should_run` event
# is reset again. Execution of `fn` resumes.
try:
with coord.stop_on_exception():
all_done = False
while not all_done and not coord.should_stop():
done = []
if run_concurrently:
for t in threads:
t.should_run.set()
for t in threads:
t.has_paused.wait()
t.has_paused.clear()
if coord.should_stop():
return None
done.append(t.done)
else:
for t in threads:
t.should_run.set()
t.has_paused.wait()
t.has_paused.clear()
if coord.should_stop():
return None
done.append(t.done)
if coord.should_stop():
return None
all_done = all(done)
if not all_done:
if any(done):
raise RuntimeError("Some replicas made a different number of "
"replica_context().merge_call() calls.")
# get_replica_context().merge_call() case
merge_args = values.regroup(
device_map, tuple(t.merge_args for t in threads))
merge_kwargs = values.regroup(
device_map, tuple(t.merge_kwargs for t in threads))
# We capture the name_scope of the MRT when we call merge_fn
# to ensure that if we have opened a name scope in the MRT,
# it will be respected when executing the merge function. We only
# capture the name_scope from the first MRT and assume it is
# the same for all other MRTs.
mtt_captured_name_scope = threads[0].captured_name_scope
# Capture and merge the control dependencies from all the threads.
mtt_captured_control_deps = set()
for t in threads:
mtt_captured_control_deps.update(t.captured_control_deps)
with ops.name_scope(mtt_captured_name_scope),\
ops.control_dependencies(mtt_captured_control_deps):
#.........这里部分代码省略.........
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:101,代码来源:mirrored_strategy.py
示例16: _test_input_iteration
def _test_input_iteration(self,
                          input_type,
                          api_type,
                          iteration_type,
                          dataset_fn,
                          worker_device_pairs,
                          expected_values,
                          sess=None,
                          split_batch_by=None,
                          enable_get_next_as_optional=False):
  """Drives the distributed-input APIs and checks the produced elements.

  Depending on `api_type` the input is wrapped into an iterator or a
  dataset; depending on `iteration_type` elements are consumed with
  get_next() or with an eager for-loop. Each per-replica element is
  compared against `expected_values`. Unsupported combinations of
  parameters skip the test.
  """
  if iteration_type == "for_loop" and not context.executing_eagerly():
    self.skipTest("unsupported test combination.")

  if api_type == "wrap_into_iterator" and iteration_type == "for_loop":
    self.skipTest("unsupported test combination.")

  if api_type == "wrap_into_dataset" and input_type == "input_fn":
    self.skipTest("unsupported test combination.")

  devices = nest.flatten([ds for _, ds in worker_device_pairs])
  device_map = values.ReplicaDeviceMap(devices)
  input_workers = input_lib.InputWorkers(device_map, worker_device_pairs)

  if api_type == "wrap_into_iterator":
    iterator = self._wrap_iterator(
        input_type, dataset_fn, input_workers, devices, split_batch_by,
        enable_get_next_as_optional)
  else:
    # wrapping into a dataset:
    given_dataset = dataset_fn(distribute_lib.InputContext())
    dataset = self._wrap_dataset(input_type, given_dataset, input_workers,
                                 split_batch_by, enable_get_next_as_optional)

    if context.executing_eagerly():
      iterator = iter(dataset)
    else:
      # The dataset can be a tf.data.DatasetV1Adapter instance since we wrap
      # tf.data.DatasetV1 as a tf.data.DatasetV1Adapter instance when we
      # autoshard the dataset.
      if not isinstance(dataset, (dataset_ops.DatasetV1,
                                  dataset_ops.DatasetV1Adapter)):
        iterator = iter(dataset)
      else:
        iterator = dataset.make_one_shot_iterator()

  if iteration_type == "get_next":
    evaluate = lambda x: sess.run(x) if sess else self.evaluate(x)
    # V1 iterators expose initialize(); otherwise fall back to the
    # private initializer attribute.
    if isinstance(iterator, input_lib.DistributedIteratorV1):
      evaluate(control_flow_ops.group(iterator.initialize()))
    else:
      evaluate(control_flow_ops.group(iterator._initializer))

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_replica(r,
                                 next_element) for r in range(len(devices))])
      self.assertEqual(len(expected_value), len(computed_value))
      for i in range(len(expected_value)):
        self.assertAllEqual(expected_value[i], computed_value[i])

    # Exhausted iterators must raise OutOfRangeError.
    with self.assertRaises(errors.OutOfRangeError):
      next_element = iterator.get_next()
      evaluate(
          [values.select_replica(r,
                                 next_element) for r in range(len(devices))])

    # After re-initializing the iterator, should be able to iterate again.
    if isinstance(iterator, input_lib.DistributedIteratorV1):
      evaluate(control_flow_ops.group(iterator.initialize()))
    else:
      evaluate(control_flow_ops.group(iterator._initializer))

    for expected_value in expected_values:
      next_element = iterator.get_next()
      computed_value = evaluate(
          [values.select_replica(r,
                                 next_element) for r in range(len(devices))])
      self.assertEqual(len(expected_value), len(computed_value))
      for i in range(len(expected_value)):
        self.assertAllEqual(expected_value[i], computed_value[i])

  if iteration_type == "for_loop" and context.executing_eagerly():
    actual_values = []
    for x in dataset:
      computed_value = self.evaluate(
          [values.select_replica(r, x) for r in range(len(devices))])
      actual_values.append(computed_value)
    for i, expected_value in enumerate(expected_values):
      self.assertEqual(len(expected_value), len(actual_values[i]))
      for j in range(len(expected_value)):
        self.assertAllEqual(expected_value[j], actual_values[i][j])
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:92,代码来源:input_lib_test.py
注:本文中的tensorflow.python.distribute.values.select_replica函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论