本文整理汇总了Python中tensorflow.python.distribute.distribution_strategy_context.get_replica_context函数的典型用法代码示例。如果您正苦于以下问题：Python get_replica_context函数的具体用法是什么？Python get_replica_context怎么用？有没有Python get_replica_context的使用例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_replica_context函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: set_non_tensor_output
def set_non_tensor_output(self, name, output):
  """Set `output` with `name` to be captured as a non tensor output.

  In cross-replica context `output` is stored as-is. In replica context a
  `merge_call` is issued and the per-replica values (as returned by
  `experimental_local_results`) are stored instead; no reduction is applied.

  Args:
    name: Key under which the value is stored in `self._non_tensor_outputs`.
    output: The non-tensor value to capture.
  """
  if distribution_strategy_context.in_cross_replica_context():
    self._non_tensor_outputs[name] = output
  else:
    def merge_fn(distribution, value):
      # NOTE(priyag): For non tensor outputs, we simply collect all the
      # per-replica values, as reduction doesn't make sense on non tensors.
      # `experimental_local_results` is the replacement for the deprecated
      # `unwrap`.
      self._non_tensor_outputs[name] = (
          distribution.experimental_local_results(value))
    distribution_strategy_context.get_replica_context().merge_call(
        merge_fn, args=(output,))
开发者ID:kylin9872,项目名称:tensorflow,代码行数:11,代码来源:input_lib.py
示例2: _assert_in_default_state
def _assert_in_default_state(t):
  """Assert, through test case `t`, that no distribution strategy is active.

  The current replica context must be the default one, there must be no
  cross-replica context, and the current strategy must be the default
  strategy with `has_strategy()` reporting False.
  """
  default_replica_context = ds_context._get_default_replica_context()
  t.assertIs(default_replica_context, ds_context.get_replica_context())
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  default_strategy = ds_context._get_default_strategy()
  t.assertIs(default_strategy, ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:distribute_lib_test.py
示例3: decorated
def decorated(_, *args):
  """Decorated function with merge_call.

  Evaluates `result_fn(*args)` (`result_fn` comes from the enclosing scope)
  and wraps the value in `array_ops.identity`. Outside replica context the
  function is called directly; in replica context the evaluation is moved to
  cross-replica mode through `merge_call`.
  """
  replica_context = distribution_strategy_context.get_replica_context()
  if replica_context is not None:
    # TODO(psv): Test distribution of metrics using different distribution
    # strategies.
    def merge_fn_wrapper(distribution, merge_fn, *args):
      # merge_call passes the distribution object as the first argument, so
      # this wrapper spares `result_fn` from having that parameter. The merge
      # function arrives as a `PerReplica` of identical copies; the first
      # local copy is the one evaluated.
      first_fn = distribution.experimental_local_results(merge_fn)[0]
      # The identity keeps the control dependency between the update_op from
      # `update_state` and the result working when the result is a tensor.
      return array_ops.identity(first_fn(*args))

    # merge_call leaves replica mode and computes the value in cross replica
    # mode.
    return replica_context.merge_call(
        merge_fn_wrapper, args=(result_fn,) + args)

  # Already in cross replica context.
  return array_ops.identity(result_fn(*args))
开发者ID:aritratony,项目名称:tensorflow,代码行数:28,代码来源:metrics_utils.py
示例4: apply_gradients
def apply_gradients(self, grads_and_vars, name=None):
  """Apply gradients to variables.

  This is the second part of `minimize()`. It returns an `Operation` that
  applies gradients.

  Must be called in replica context: the update is dispatched through
  `get_replica_context().merge_call(...)`.

  Args:
    grads_and_vars: List of (gradient, variable) pairs.
    name: Optional name for the returned operation. Defaults to the name
      passed to the `Optimizer` constructor.

  Returns:
    An `Operation` that applies the specified gradients.

  Raises:
    TypeError: If `grads_and_vars` is malformed.
    ValueError: If none of the variables have gradients.
  """
  grads_and_vars = _filter_grads(grads_and_vars)
  var_list = [v for (_, v) in grads_and_vars]
  self._create_hypers()
  # `init_scope` lifts slot creation out of any function-building graph or
  # control-flow context that may enclose this call.
  with ops.init_scope():
    self._create_slots(var_list)
  self._prepare(var_list)
  # The actual update runs in cross-replica context via `merge_call`.
  # NOTE(review): `get_replica_context()` is None in cross-replica context,
  # so calling this method there fails with an AttributeError.
  return distribute_ctx.get_replica_context().merge_call(
      self._distributed_apply, args=(grads_and_vars,), kwargs={"name": name})
开发者ID:terrytangyuan,项目名称:tensorflow,代码行数:30,代码来源:optimizer_v2.py
示例5: merge_fn
def merge_fn(dist, s):
  """Merge function asserting it runs in the default cross-replica context.

  Returns:
    The string "foo_" followed by `s`.
  """
  default_strategy = ds_context._get_default_strategy()
  self.assertIs(default_strategy, dist)
  self.assertIs(None, ds_context.get_replica_context())
  self.assertIs(dist, ds_context.get_cross_replica_context())
  self.assertTrue(ds_context.in_cross_replica_context())
  self.assertIs(dist, ds_context.get_strategy())
  self.assertFalse(ds_context.has_strategy())
  result = "foo_" + s
  return result
开发者ID:aritratony,项目名称:tensorflow,代码行数:8,代码来源:distribute_lib_test.py
示例6: merge_grads
def merge_grads(grads_and_vars):
  """Merge gradients from different replicas.

  Issues a `merge_call` from the current replica context; inside it the
  gradients are summed across replicas with `batch_reduce_to`.

  Args:
    grads_and_vars: Gradient/variable pairs handed to `batch_reduce_to`.

  Returns:
    The value returned by `strategy.extended.batch_reduce_to` for the SUM
    reduction.
  """
  def merge_grad_fn(strategy, grads_and_vars):
    return strategy.extended.batch_reduce_to(
        ds_reduce_util.ReduceOp.SUM, grads_and_vars)

  replica_context = distribute_ctx.get_replica_context()
  return replica_context.merge_call(merge_grad_fn, args=(grads_and_vars,))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:10,代码来源:optimizer_v2.py
示例7: set_last_step_output
def set_last_step_output(self, name, output, reduce_op=None):
  """Set `output` with `name` to be outputted from the last step.

  Args:
    name: String, name to identify the output. Doesn't need to match tensor
      name.
    output: The tensors that should be outputted with `name`. See below for
      actual types supported.
    reduce_op: Reduction method to use to reduce outputs from multiple
      replicas. Required if `set_last_step_output` is called in a replica
      context. Optional in cross_replica_context.
      When present, the outputs from all the replicas are reduced using the
      current distribution strategy's `reduce` method. Hence, the type of
      `output` must be what's supported by the corresponding `reduce` method.
      For example, if using MirroredStrategy and reduction is set, output
      must be a `PerReplica` value.
      The reduce method is also recorded in a dictionary
      `_last_step_outputs_reduce_ops` for later interpreting of the
      outputs as already reduced or not.
  """
  if not distribution_strategy_context.in_cross_replica_context():
    # Replica context: a reduction is mandatory to combine replica outputs.
    assert reduce_op is not None

    def merge_fn(distribution, value):
      self._last_step_outputs[name] = distribution.reduce(
          reduce_op, value, axis=None)
      # Recorded inside `merge_fn` because all replicas share the same
      # context object, so setting it once here is more robust (even if all
      # the replicas would set the same value).
      self._last_step_outputs_reduce_ops[name] = reduce_op

    distribution_strategy_context.get_replica_context().merge_call(
        merge_fn, args=(output,))
    return

  # Cross-replica context: store as-is, or reduce when a reduce_op is given.
  self._last_step_outputs_reduce_ops[name] = reduce_op
  if reduce_op is None:
    self._last_step_outputs[name] = output
    return
  strategy = distribution_strategy_context.get_strategy()
  self._last_step_outputs[name] = strategy.reduce(reduce_op, output, axis=None)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:40,代码来源:input_lib.py
示例8: run_fn
def run_fn():
  """Replica function checking the context state while `dist` is running."""
  replica_context = ds_context.get_replica_context()
  # assertIsNotNone / assertIsNone report the offending value on failure,
  # unlike assertTrue(x is not None).
  self.assertIsNotNone(replica_context)
  self.assertIsNone(ds_context.get_cross_replica_context())
  self.assertFalse(ds_context.in_cross_replica_context())
  self.assertTrue(ds_context.has_strategy())
  self.assertIs(dist, ds_context.get_strategy())
  # NOTE(review): presumably the test strategy's replica context echoes
  # `test_arg` from merge_call; the assertion expects "foo" back.
  self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
  expected_value = _get_test_variable(
      "bar", variable_scope.VariableSynchronization.AUTO,
      variable_scope.VariableAggregation.NONE)
  self.assertDictEqual(expected_value,
                       variable_scope.variable(1.0, name="bar"))
开发者ID:aritratony,项目名称:tensorflow,代码行数:13,代码来源:distribute_lib_test.py
示例9: _test_run
def _test_run(self, strategy):
  """Checks `experimental_run_v2` with no args, positional args and kwargs.

  The expected values assume exactly two replicas (replica ids 0 and 1).
  Per-replica results are read with `experimental_local_results`, the
  replacement for the deprecated `strategy.unwrap`.
  """
  out1 = strategy.experimental_run_v2(
      lambda: ds_context.get_replica_context().replica_id_in_sync_group + 1)
  self.assertAllEqual(
      [1, 2], self.evaluate(strategy.experimental_local_results(out1)))

  out2 = strategy.experimental_run_v2(
      lambda x: {"a": x * 2, "b": x * x}, args=(out1,))
  out2_vals = self.evaluate(
      nest.map_structure(strategy.experimental_local_results, out2))
  self.assertAllEqual([2, 4], out2_vals["a"])
  self.assertAllEqual([1, 4], out2_vals["b"])

  # `out2` is a dict, so its entries are passed as keyword arguments.
  out3 = strategy.experimental_run_v2(lambda b, a: a + 2 * b + 2, kwargs=out2)
  self.assertAllEqual(
      [6, 14], self.evaluate(strategy.experimental_local_results(out3)))
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:13,代码来源:strategy_test_lib.py
示例10: merge_update_step
def merge_update_step(update_ops, local_step):
  """Merge local step counter update from different replicas.

  Inside a `merge_call`, each entry of `update_ops` is grouped with
  `strategy.group`, and `local_step` is incremented by one only after all of
  those grouped ops.

  Returns:
    The op that increments `local_step`.
  """
  def merge_update_step_fn(strategy, update_ops, local_step):
    grouped_ops = [strategy.group(op) for op in update_ops]
    with ops.control_dependencies(grouped_ops):
      return local_step.assign_add(1).op

  return distribute_ctx.get_replica_context().merge_call(
      merge_update_step_fn, args=(update_ops, local_step))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:13,代码来源:optimizer_v2.py
示例11: testScope
def testScope(self):
  """Inside `dist.scope()`, `dist` is the active cross-replica strategy."""
  _assert_in_default_state(self)
  strategy = _TestStrategy()
  with strategy.scope():
    self.assertIs(None, ds_context.get_replica_context())
    self.assertIs(strategy, ds_context.get_cross_replica_context())
    self.assertTrue(ds_context.in_cross_replica_context())
    self.assertTrue(ds_context.has_strategy())
    self.assertIs(strategy, ds_context.get_strategy())
    expected = _get_test_variable(
        "baz", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    created = variable_scope.variable(1.0, name="baz")
    self.assertDictEqual(expected, created)
  # Leaving the scope must restore the default state.
  _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:15,代码来源:distribute_lib_test.py
示例12: testMergeCall
def testMergeCall(self):
  """merge_call from the default replica context runs in the default strategy."""
  _assert_in_default_state(self)

  def merge_fn(dist, s):
    # Inside merge_call we must be in the default cross-replica context.
    self.assertIs(ds_context._get_default_strategy(), dist)
    self.assertIs(None, ds_context.get_replica_context())
    self.assertIs(dist, ds_context.get_cross_replica_context())
    self.assertTrue(ds_context.in_cross_replica_context())
    self.assertIs(dist, ds_context.get_strategy())
    self.assertFalse(ds_context.has_strategy())
    return "foo_" + s

  replica_ctx = ds_context.get_replica_context()
  default_ctx = ds_context._get_default_replica_context()
  self.assertIs(default_ctx, replica_ctx)
  merged = replica_ctx.merge_call(merge_fn, args=("bar",))
  self.assertEqual("foo_bar", merged)
  _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:16,代码来源:distribute_lib_test.py
示例13: _test_step_fn
def _test_step_fn(inputs):
  """A fn that returns output of single test step.

  A 2-element tuple/list is unpacked as `(inputs, targets)`; anything else is
  treated as inputs only, with `targets` set to None. `model` and `mode` come
  from the enclosing scope.
  """
  looks_like_pair = isinstance(inputs, (tuple, list)) and len(inputs) == 2
  if looks_like_pair:
    inputs, targets = inputs
  else:
    targets = None

  # `_build_model` is run in cross-replica context; its result is unused.
  distribution_strategy_context.get_replica_context().merge_call(
      _build_model, args=(model, mode, inputs, targets))

  distributed_model = distributed_training_utils.get_distributed_model(
      model, mode)
  _, outputs, updates, _ = _per_replica_execution_function(
      distributed_model, mode)
  # Outputs are produced only after the model updates have run.
  with ops.control_dependencies([updates]):
    return outputs
开发者ID:aritratony,项目名称:tensorflow,代码行数:16,代码来源:training_distributed.py
示例14: skip_summary
def skip_summary():
  """Determines if summary should be skipped.

  If using multiple replicas in distributed strategy, skip summaries on all
  replicas except the first one (replica_id=0).

  Returns:
    True if the summary is skipped; False otherwise.
  """
  # TODO(priyag): Add a new optional argument that will provide multiple
  # alternatives to override default behavior. (e.g. run on last replica,
  # compute sum or mean across replicas).
  replica_context = distribution_strategy_context.get_replica_context()
  if not replica_context:
    return False
  # TODO(b/118385803): when replica_id of _TPUReplicaContext is properly
  # initialized, remember to change here as well.
  replica_id = replica_context.replica_id_in_sync_group
  if isinstance(replica_id, ops.Tensor):
    # `constant_value` yields None when the id is not statically known; the
    # summary is then not skipped.
    replica_id = tensor_util.constant_value(replica_id)
  # Return a real bool as documented: the previous `replica_id and
  # replica_id > 0` leaked 0 or None instead of False.
  return replica_id is not None and bool(replica_id > 0)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:22,代码来源:summary_op_util.py
示例15: decorated
def decorated(_, *args):
  """Decorated function with merge_call.

  Evaluates `result_fn(*args)` (`result_fn` comes from the enclosing scope).
  Outside replica context it is called directly; in replica context the
  evaluation is moved to cross-replica mode through `merge_call`.
  """
  replica_context = distribution_strategy_context.get_replica_context()
  if replica_context is None:
    # Already in cross replica context.
    return result_fn(*args)

  # TODO(psv): Test distribution of metrics using different distribution
  # strategies.
  def merge_fn_wrapper(distribution, merge_fn, *args):
    # merge_call passes the distribution object as the first argument, so
    # this wrapper spares `result_fn` from having that parameter. The merge
    # function arrives as identical per-replica copies; the first is used.
    # NOTE(review): newer TF deprecates `unwrap` in favor of
    # `experimental_local_results`.
    first_fn = distribution.unwrap(merge_fn)[0]
    return first_fn(*args)

  # merge_call leaves replica mode and computes the value in cross replica
  # mode.
  return replica_context.merge_call(
      merge_fn_wrapper, args=(result_fn,) + args)
开发者ID:rmlarsen,项目名称:tensorflow,代码行数:22,代码来源:metrics_utils.py
示例16: mark_devices_fn
def mark_devices_fn():
  """Records that the current replica ran, asserting each id is seen once."""
  replica_context = ds_context.get_replica_context()
  replica_id = self.evaluate(replica_context.replica_id_in_sync_group)
  num_worker_devices = len(d.extended.worker_devices)
  self.assertLess(replica_id, num_worker_devices)
  self.assertFalse(expected_devices[replica_id])
  expected_devices[replica_id] = True
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:6,代码来源:strategy_test_lib.py
示例17: apply_gradients
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""Apply gradients to variables.
This is the second part of `minimize()`. It returns an `Operation` that
applies gradients.
Args:
grads_and_vars: List of (gradient, variable) pairs as returned by
`compute_gradients()`.
global_step: Optional `Variable` to increment by one after the
variables have been updated.
name: Optional name for the returned operation. Defaults to the
name passed to the `Optimizer` constructor.
Returns:
An `Operation` that applies the specified gradients. If `global_step`
was not None, that operation also increments `global_step`.
Raises:
TypeError: If a gradient cannot be converted to a `Tensor` or
`IndexedSlices`.
ValueError: If `grads_and_vars` is empty or none of the variables have
gradients.
RuntimeError: If called in a cross-replica context while a distribution
strategy is active; use `_distributed_apply()` there instead.
"""
# This is a default implementation of apply_gradients() that can be shared
# by most optimizers. It relies on the subclass implementing the following
# methods: _create_slots(), _prepare(), _apply_dense(), and _apply_sparse().
# TODO(isaprykin): Get rid of `has_strategy()` check by
# always calling _distributed_apply(), using the default distribution
# as needed.
if distribute_ctx.has_strategy():
# Handle DistributionStrategy case.
if distribute_ctx.in_cross_replica_context():
raise RuntimeError("Use `_distributed_apply()` instead of "
"`apply_gradients()` in a cross-replica context.")
# In replica context, the update itself is delegated to
# `_distributed_apply` through `merge_call`.
grads_and_vars = get_filtered_grad_fn(lambda: grads_and_vars)()
return distribute_ctx.get_replica_context().merge_call(
self._distributed_apply, args=(grads_and_vars, global_step, name))
# No DistributionStrategy case.
grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
if not grads_and_vars:
raise ValueError("No variables provided.")
converted_grads_and_vars = []
for g, v in grads_and_vars:
if g is not None:
try:
# Convert the grad to Tensor or IndexedSlices if necessary.
g = ops.convert_to_tensor_or_indexed_slices(g)
except TypeError:
raise TypeError(
"Gradient must be convertible to a Tensor"
" or IndexedSlices, or None: %s" % g)
if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
raise TypeError(
"Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
# Pairs with a None gradient are kept here but skipped when applying.
p = _get_processor(v)
converted_grads_and_vars.append((g, v, p))
converted_grads_and_vars = tuple(converted_grads_and_vars)
var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
if not var_list:
raise ValueError("No gradients provided for any variable: %s." %
([str(v) for _, v, _ in converted_grads_and_vars],))
with ops.init_scope():
self._create_slots(var_list)
update_ops = []
# `name` is rebound to the resolved name scope and reused by the ops below.
with ops.name_scope(name, self._name) as name:
self._prepare()
for grad, var, processor in converted_grads_and_vars:
if grad is None:
continue
# We colocate all ops created in _apply_dense or _apply_sparse
# on the same device as the variable.
# TODO(apassos): figure out how to get the variable name here.
# `and` binds tighter than `or`: this reads
# eager OR (ResourceVariable AND not in graph mode).
if context.executing_eagerly() or isinstance(
var,
resource_variable_ops.ResourceVariable) and not var._in_graph_mode: # pylint: disable=protected-access
scope_name = ""
else:
scope_name = var.op.name
with ops.name_scope("update_" + scope_name), ops.colocate_with(var):
update_ops.append(processor.update_op(self, grad))
if global_step is None:
apply_updates = self._finish(update_ops, name)
else:
# The global_step increment runs only after `_finish(update_ops, ...)`.
with ops.control_dependencies([self._finish(update_ops, "update")]):
with ops.colocate_with(global_step):
if isinstance(global_step, resource_variable_ops.ResourceVariable):
# TODO(apassos): the implicit read in assign_add is slow; consider
# making it less so.
apply_updates = resource_variable_ops.assign_add_variable_op(
global_step.handle,
ops.convert_to_tensor(1, dtype=global_step.dtype),
name=name)
else:
apply_updates = state_ops.assign_add(global_step, 1, name=name)
if not context.executing_eagerly():
#......... (the rest of this function was omitted from the listing) .........
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:101,代码来源:optimizer.py
示例18: _replica_id
def _replica_id():
  """Return the current replica id, always as a tensor."""
  replica_id = ds_context.get_replica_context().replica_id_in_sync_group
  # Some replica contexts already report the id as a tensor; others give a
  # plain value that is wrapped in a constant.
  return (replica_id if isinstance(replica_id, ops.Tensor)
          else constant_op.constant(replica_id))
开发者ID:pyjennings,项目名称:tensorflow,代码行数:5,代码来源:keras_optimizer_v2_test.py
示例19: _merge_call_merge_raises_fn
def _merge_call_merge_raises_fn():
  """Issue a merge_call whose merge function is `_call_merge_raises_fn`."""
  replica_context = ds_context.get_replica_context()
  replica_context.merge_call(_call_merge_raises_fn)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:2,代码来源:strategy_test_lib.py
示例20: _merge_raises_fn
def _merge_raises_fn():
  """Issue a merge_call whose merge function is `_raise_exception_fn`."""
  replica_context = ds_context.get_replica_context()
  replica_context.merge_call(_raise_exception_fn)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:2,代码来源:strategy_test_lib.py
注：本文中的tensorflow.python.distribute.distribution_strategy_context.get_replica_context函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论