本文整理汇总了Python中tensorflow.python.distribute.distribution_strategy_context.get_strategy函数的典型用法代码示例。如果您正苦于以下问题：Python get_strategy函数的具体用法？Python get_strategy怎么用？Python get_strategy使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_strategy函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: call_replica_local_fn
def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context.

  Arguments:
    fn: The function to call.
    *args: Positional arguments to the `fn`.
    **kwargs: Keyword arguments to `fn`. If a `strategy` keyword argument is
      present it is removed from `kwargs` (it is not forwarded to `fn`) and
      used instead of the current strategy from `ds_context.get_strategy()`.

  Returns:
    The result of calling `fn`.
  """
  # TODO(b/120571621): We want to avoid reductions here since
  # TPUStrategy does not implement replica local variables.
  # Remove this hack once we support TPUReplicaLocalVariables.
  if 'strategy' in kwargs:
    strategy = kwargs.pop('strategy')
  else:
    # Query the current strategy once; a falsy result is normalized to None,
    # matching the previous two-call implementation.
    strategy = ds_context.get_strategy() or None
  is_tpu = is_tpu_strategy(strategy)
  if (not is_tpu) and strategy and ds_context.in_cross_replica_context():
    # In a cross-replica context, run `fn` once per replica under the
    # strategy's scope.
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)
开发者ID:aritratony,项目名称:tensorflow,代码行数:29,代码来源:distributed_training_utils.py
示例2: add_slot
def add_slot(self, var, slot_name, initializer="zeros"):
  """Return the slot variable `slot_name` for `var`, creating it if absent.

  Args:
    var: The primary variable the slot is attached to.
    slot_name: Name of the slot; recorded in `self._slot_names` if new.
    initializer: A string or callable, resolved through `initializers.get`
      and bound to `var`'s shape and dtype; any other value is used as the
      initial value directly.

  Returns:
    The existing slot variable, or the newly created and tracked one.
  """
  if slot_name not in self._slot_names:
    self._slot_names.append(slot_name)
  slots_for_var = self._slots.setdefault(_var_key(var), {})
  existing = slots_for_var.get(slot_name, None)
  if existing is not None:
    return existing

  needs_resolution = (isinstance(initializer, six.string_types)
                      or callable(initializer))
  if needs_resolution:
    # Bind the slot's shape/dtype into the resolved initializer; the callable
    # is handed to `tf_variables.Variable` as `initial_value`.
    init_fn = initializers.get(initializer)
    initial_value = functools.partial(
        init_fn, shape=var.shape, dtype=var.dtype)
  else:
    initial_value = initializer

  # Create the slot under the strategy's `colocate_vars_with(var)` scope.
  with distribute_ctx.get_strategy().extended.colocate_vars_with(var):
    slot_var = tf_variables.Variable(
        name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
        dtype=var.dtype,
        trainable=False,
        initial_value=initial_value)
  backend.track_variable(slot_var)
  slots_for_var[slot_name] = slot_var
  self._restore_slot_variable(
      slot_name=slot_name, variable=var, slot_variable=slot_var)
  self._weights.append(slot_var)
  return slot_var
开发者ID:aritratony,项目名称:tensorflow,代码行数:28,代码来源:optimizer_v2.py
示例3: create_slot
def create_slot(primary, val, name, colocate_with_primary=True):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is located
      on the same device as `primary`.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as default name, so the scope name of
  # optimizer can be shared when reuse is True. Meanwhile when reuse is False
  # and the same name has been previously used, the scope name will add '_N'
  # as suffix for unique identifications.
  validate_shape = val.get_shape().is_fully_defined()
  if context.executing_eagerly():
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribution_strategy_context.get_strategy()
      # Colocation goes through the `extended` API, like the other
      # get_strategy() callers that create colocated variables.
      with distribution_strategy.extended.colocate_vars_with(primary):
        return _create_slot_var(primary, val, "", validate_shape, None, None)
    else:
      return _create_slot_var(primary, val, "", validate_shape, None, None)
开发者ID:pyjennings,项目名称:tensorflow,代码行数:32,代码来源:slot_creator.py
示例4: _fused_batch_norm
def _fused_batch_norm(self, inputs, training):
  """Returns the output of fused batch norm.

  Moving statistics are also updated unless `training` is constant False.

  Args:
    inputs: Input tensor passed to `nn.fused_batch_norm`.
    training: Python bool or tensor. A constant value is read with
      `tf_utils.constant_value`; otherwise the training/inference branch is
      chosen at run time with `tf_utils.smart_cond`.

  Returns:
    The normalized output tensor.
  """
  beta = self.beta if self.center else self._beta_const
  gamma = self.gamma if self.scale else self._gamma_const

  def _fused_batch_norm_training():
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        epsilon=self.epsilon,
        data_format=self._data_format)

  def _fused_batch_norm_inference():
    return nn.fused_batch_norm(
        inputs,
        gamma,
        beta,
        mean=self.moving_mean,
        variance=self.moving_variance,
        epsilon=self.epsilon,
        is_training=False,
        data_format=self._data_format)

  output, mean, variance = tf_utils.smart_cond(
      training, _fused_batch_norm_training, _fused_batch_norm_inference)
  if not self._bessels_correction_test_only:
    # Remove Bessel's correction to be consistent with non-fused batch norm.
    # Note that the variance computed by fused batch norm is
    # with Bessel's correction.
    sample_size = math_ops.cast(
        array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
    factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
    variance *= factor

  training_value = tf_utils.constant_value(training)
  if training_value is None:
    # `training` is only known at run time: use momentum 1.0 on the
    # inference branch (presumably leaving the moving statistics unchanged
    # in `_assign_moving_average` -- TODO confirm).
    momentum = tf_utils.smart_cond(training,
                                   lambda: self.momentum,
                                   lambda: 1.0)
  else:
    momentum = ops.convert_to_tensor(self.momentum)
  if training_value or training_value is None:
    if distribution_strategy_context.in_cross_replica_context():
      strategy = distribution_strategy_context.get_strategy()
      # Pass the same `momentum` as the replica-context branch below. Using
      # `self.momentum` here ignored the 1.0 inference-mode momentum when
      # `training` is a tensor.
      mean_update = strategy.extended.update(
          self.moving_mean, self._assign_moving_average,
          (mean, momentum))
      variance_update = strategy.extended.update(
          self.moving_variance, self._assign_moving_average,
          (variance, momentum))
    else:
      mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                momentum)
      variance_update = self._assign_moving_average(self.moving_variance,
                                                    variance, momentum)
    self.add_update(mean_update, inputs=True)
    self.add_update(variance_update, inputs=True)

  return output
开发者ID:gautam1858,项目名称:tensorflow,代码行数:60,代码来源:normalization.py
示例5: scale_loss_for_distribution
def scale_loss_for_distribution(loss_value):
  """Scale `loss_value` by 1 / (number of replicas in sync).

  When the current strategy has at most one replica, the value is returned
  untouched.
  """
  replica_count = (
      distribution_strategy_context.get_strategy().num_replicas_in_sync)
  if replica_count <= 1:
    return loss_value
  loss_value *= (1. / replica_count)
  return loss_value
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:losses_utils.py
示例6: _create_non_slot_variable
def _create_non_slot_variable(self, initial_value, name, colocate_with):
  """Return the extra (non-slot) variable `name`, creating it on first use.

  Variables are cached in `self._non_slot_dict` under the key
  `(name, graph)`, where `graph` is None when executing eagerly and
  `colocate_with.graph` otherwise.

  Args:
    initial_value: Initial value used when the variable is created. In eager
      mode it is replaced by the result of `_preload_simple_restoration` when
      that result is not None.
    name: Name of the variable.
    colocate_with: Variable whose colocation scope (and resource-ness, via
      `use_resource`) the new variable follows.

  Returns:
    The cached or newly created variable.
  """
  # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
  executing_eagerly = context.executing_eagerly()
  cache_key = (name, None if executing_eagerly else colocate_with.graph)
  cached = self._non_slot_dict.get(cache_key, None)
  if cached is not None:
    return cached

  self._maybe_initialize_trackable()
  strategy = distribute_ctx.get_strategy()
  with strategy.extended.colocate_vars_with(colocate_with):
    if executing_eagerly:
      # Presumably a value preloaded from a checkpoint; use it if present.
      preloaded = self._preload_simple_restoration(name=name, shape=None)
      if preloaded is not None:
        initial_value = preloaded
    new_var = variable_scope.variable(
        initial_value,
        name=name,
        trainable=False,
        use_resource=resource_variable_ops.is_resource_variable(
            colocate_with))

  # Restore this variable by name if necessary, but don't add a
  # Trackable dependency. Optimizers return the current graph's
  # non-slot variables from _checkpoint_dependencies explicitly rather
  # than unconditionally adding dependencies (since there may be multiple
  # non-slot variables with the same name in different graphs, trying to
  # save all of them would result in errors).
  self._handle_deferred_dependencies(name=name, trackable=new_var)
  self._non_slot_dict[cache_key] = new_var
  return new_var
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:31,代码来源:optimizer.py
示例7: _assert_in_default_state
def _assert_in_default_state(t):
  """Use test case `t` to check that only the default strategy is active.

  Args:
    t: A test case object providing `assertIs` and `assertFalse`.
  """
  default_replica_ctx = ds_context._get_default_replica_context()  # pylint: disable=protected-access
  t.assertIs(default_replica_ctx, ds_context.get_replica_context())
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  default_strategy = ds_context._get_default_strategy()  # pylint: disable=protected-access
  t.assertIs(default_strategy, ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:distribute_lib_test.py
示例8: merge_fn
def merge_fn(dist, s):
  """Merge function that checks it runs in the default strategy's cross-replica context.

  `self` is assumed to be the enclosing test case, captured from the
  surrounding scope (not visible here).

  Returns:
    The string "foo_" concatenated with `s`.
  """
  default_strategy = ds_context._get_default_strategy()  # pylint: disable=protected-access
  self.assertIs(default_strategy, dist)
  self.assertIs(None, ds_context.get_replica_context())
  cross_replica_ctx = ds_context.get_cross_replica_context()
  self.assertIs(dist, cross_replica_ctx)
  self.assertTrue(ds_context.in_cross_replica_context())
  current_strategy = ds_context.get_strategy()
  self.assertIs(dist, current_strategy)
  self.assertFalse(ds_context.has_strategy())
  prefix = "foo_"
  return prefix + s
开发者ID:aritratony,项目名称:tensorflow,代码行数:8,代码来源:distribute_lib_test.py
示例9: _scale_loss
def _scale_loss(loss_value):
  """Scale `loss_value` by 1/num_replicas when the loss reduction is MEAN.

  The default graph's `_is_loss_scaled_by_optimizer` flag records whether
  the scaling was applied (True) or not (False).
  """
  default_graph = ops.get_default_graph()
  default_graph._is_loss_scaled_by_optimizer = False  # pylint: disable=protected-access
  is_mean = (distribute_lib.get_loss_reduction() ==
             ds_reduce_util.ReduceOp.MEAN)
  if not is_mean:
    return loss_value
  replica_count = distribute_ctx.get_strategy().num_replicas_in_sync
  if replica_count > 1:
    loss_value *= (1. / replica_count)
    default_graph._is_loss_scaled_by_optimizer = True  # pylint: disable=protected-access
  return loss_value
开发者ID:aritratony,项目名称:tensorflow,代码行数:8,代码来源:optimizer.py
示例10: testSetStrategy
def testSetStrategy(self):
  """`experimental_set_strategy` activates a strategy; passing None resets it."""
  _assert_in_default_state(self)
  dist = _TestStrategy()
  dist2 = _TestStrategy()
  ds_context.experimental_set_strategy(dist)
  try:
    # Once set, the strategy is current in cross-replica mode without
    # entering an explicit scope.
    self.assertIs(None, ds_context.get_replica_context())
    self.assertIs(dist, ds_context.get_cross_replica_context())
    self.assertTrue(ds_context.in_cross_replica_context())
    self.assertTrue(ds_context.has_strategy())
    self.assertIs(dist, ds_context.get_strategy())
    expected_value = _get_test_variable(
        "baz", variable_scope.VariableSynchronization.AUTO,
        variable_scope.VariableAggregation.NONE)
    self.assertDictEqual(expected_value,
                         variable_scope.variable(1.0, name="baz"))
    # Setting another strategy replaces the current one.
    ds_context.experimental_set_strategy(dist2)
    self.assertIs(dist2, ds_context.get_strategy())
  finally:
    # Always reset, even if an assertion above fails, so the strategy set
    # here does not stay active for later tests.
    ds_context.experimental_set_strategy(None)
  _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:19,代码来源:distribute_lib_test.py
示例11: _reduce_weighted_loss
def _reduce_weighted_loss(
    weighted_losses, reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE):
  """Reduces the individual weighted loss measurements.

  Args:
    weighted_losses: Per-element weighted losses.
    reduction: `NONE` returns `weighted_losses` unchanged;
      `SUM_OVER_BATCH_SIZE` divides the sum by the global element count
      (local element count times replicas in sync) via `_safe_mean`; any
      other value returns the plain sum.

  Returns:
    The reduced loss.
  """
  if reduction == losses_impl.ReductionV2.NONE:
    return weighted_losses
  total = math_ops.reduce_sum(weighted_losses)
  if reduction == losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE:
    # Used to convert from local to global batch size.
    replicas = (
        distribution_strategy_context.get_strategy().num_replicas_in_sync)
    return _safe_mean(total, replicas * _num_elements(weighted_losses))
  return total
开发者ID:ziky90,项目名称:tensorflow,代码行数:12,代码来源:losses_utils.py
示例12: _get_tensor
def _get_tensor(self, is_finite):
  """Build a value that is 1. when `is_finite` is true and NaN otherwise.

  When a distribution strategy is active, only replica 0 produces that value;
  every other replica produces 1.
  """
  value = control_flow_ops.cond(is_finite, lambda: 1., lambda: float('NaN'))
  if not distribution_strategy_context.has_strategy():
    return value

  def _only_on_first_replica():
    replica_id = (distribution_strategy_context.get_replica_context()
                  .replica_id_in_sync_group)
    is_first = math_ops.equal(replica_id, 0)
    return control_flow_ops.cond(is_first, lambda: value, lambda: 1.)

  strategy = distribution_strategy_context.get_strategy()
  return strategy.extended.call_for_each_replica(_only_on_first_replica)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:12,代码来源:loss_scale_test.py
示例13: run_fn
def run_fn():
  """Per-replica function checking replica-context state under `dist`.

  `self` and `dist` are assumed to be captured from the enclosing test
  (not visible here).
  """
  replica_ctx = ds_context.get_replica_context()
  self.assertIsNotNone(replica_ctx)
  self.assertIs(None, ds_context.get_cross_replica_context())
  self.assertFalse(ds_context.in_cross_replica_context())
  self.assertTrue(ds_context.has_strategy())
  self.assertIs(dist, ds_context.get_strategy())
  merged = replica_ctx.merge_call(None, test_arg="foo")
  self.assertEqual("foo", merged)
  sync = variable_scope.VariableSynchronization.AUTO
  aggregation = variable_scope.VariableAggregation.NONE
  expected_value = _get_test_variable("bar", sync, aggregation)
  created = variable_scope.variable(1.0, name="bar")
  self.assertDictEqual(expected_value, created)
开发者ID:aritratony,项目名称:tensorflow,代码行数:13,代码来源:distribute_lib_test.py
示例14: testScopeDeviceNestingError
def testScopeDeviceNestingError(self):
  """Exiting a strategy scope inside a later device scope raises RuntimeError."""
  _assert_in_default_state(self)
  dist = _TestStrategy()
  # Open a device scope with dist.scope().
  dist.extended._default_device = "/device:GPU:0"  # pylint: disable=protected-access
  # Enter the scope manually so it can be exited at an improperly nested point.
  scope = dist.scope()
  scope.__enter__()
  self.assertIs(dist, ds_context.get_strategy())
  with ops.device("/device:CPU:0"):
    # Exiting while the CPU device scope (opened after the strategy scope) is
    # still active must fail with the nesting error.
    with self.assertRaisesRegexp(RuntimeError, "Device scope nesting error"):
      scope.__exit__(None, None, None)
  # Exit again once the device scope is closed; the final assertion checks
  # that this brings the default state back.
  scope.__exit__(None, None, None)
  _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:13,代码来源:distribute_lib_test.py
示例15: testScopeVarScopeNestingError
def testScopeVarScopeNestingError(self):
  """Exiting a strategy scope inside a later variable scope raises RuntimeError."""
  # We create a new graph here to simplify clean-up, since the error
  # we are triggering happens in the middle of scope.__exit__() and
  # leaves us in a weird state.
  with ops.Graph().as_default():
    _assert_in_default_state(self)
    dist = _TestStrategy()
    # Enter the scope manually so it can be exited at an improperly nested
    # point.
    scope = dist.scope()
    scope.__enter__()
    self.assertIs(dist, ds_context.get_strategy())
    with variable_scope.variable_scope("AA"):
      # The "AA" variable scope was opened after the strategy scope, so
      # exiting the strategy scope here must fail.
      with self.assertRaisesRegexp(RuntimeError,
                                   "Variable scope nesting error"):
        scope.__exit__(None, None, None)
  # Checked after leaving the temporary graph (see the note above).
  _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:15,代码来源:distribute_lib_test.py
示例16: testScopeVarCreatorNestingError
def testScopeVarCreatorNestingError(self):
  """Exiting a strategy scope inside a later variable-creator scope raises RuntimeError."""

  # Pass-through variable creator, used only to open a creator scope.
  def creator(next_creator, **kwargs):
    return next_creator(**kwargs)
  _assert_in_default_state(self)
  dist = _TestStrategy()
  # Enter the scope manually so it can be exited at an improperly nested point.
  scope = dist.scope()
  scope.__enter__()
  self.assertIs(dist, ds_context.get_strategy())
  with variable_scope.variable_creator_scope(creator):
    # The creator scope was opened after the strategy scope, so exiting the
    # strategy scope here must fail.
    with self.assertRaisesRegexp(RuntimeError,
                                 "Variable creator scope nesting error"):
      scope.__exit__(None, None, None)
  # Exit again once the creator scope is closed; the final assertion checks
  # that this brings the default state back.
  scope.__exit__(None, None, None)
  _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:16,代码来源:distribute_lib_test.py
示例17: testScope
def testScope(self):
  """Inside `scope()` the test strategy is current, in cross-replica mode."""
  ctx = distribution_strategy_context
  _assert_in_default_state(self)
  strategy = _TestStrategy()
  with strategy.scope():
    self.assertIs(None, ctx.get_replica_context())
    self.assertIs(strategy, ctx.get_cross_replica_context())
    self.assertTrue(ctx.in_cross_replica_context())
    self.assertTrue(ctx.has_strategy())
    self.assertIs(strategy, ctx.get_strategy())
    sync = variable_scope.VariableSynchronization.AUTO
    aggregation = variable_scope.VariableAggregation.NONE
    expected_value = _get_test_variable("baz", sync, aggregation)
    created = variable_scope.variable(1.0, name="baz")
    self.assertDictEqual(expected_value, created)
  _assert_in_default_state(self)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:17,代码来源:distribute_lib_test.py
示例18: set_last_step_output
def set_last_step_output(self, name, output, reduce_op=None):
  """Set `output` with `name` to be outputted from the last step.

  Args:
    name: String, name to identify the output. Doesn't need to match tensor
      name.
    output: The tensors that should be outputted with `name`. See below for
      actual types supported.
    reduce_op: Reduction method to use to reduce outputs from multiple
      replicas. Required if `set_last_step_output` is called in a replica
      context. Optional in cross_replica_context.
      When present, the outputs from all the replicas are reduced using the
      current distribution strategy's `reduce` method. Hence, the type of
      `output` must be what's supported by the corresponding `reduce` method.
      For example, if using MirroredStrategy and reduction is set, output
      must be a `PerReplica` value.
      The reduce method is also recorded in a dictionary
      `_last_step_outputs_reduce_ops` for later interpreting of the
      outputs as already reduced or not.

  Raises:
    ValueError: If called in a replica context with `reduce_op=None`.
  """
  if distribution_strategy_context.in_cross_replica_context():
    self._last_step_outputs_reduce_ops[name] = reduce_op
    if reduce_op is None:
      self._last_step_outputs[name] = output
    else:
      distribution = distribution_strategy_context.get_strategy()
      self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                          axis=None)
  else:
    # An explicit check instead of `assert`, which is stripped under
    # `python -O` and would let `reduce(None, ...)` run below.
    if reduce_op is None:
      raise ValueError(
          "`reduce_op` must be provided when `set_last_step_output` is "
          "called in a replica context (output name: %r)." % (name,))

    def merge_fn(distribution, value):
      self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                          axis=None)
      # Setting this inside the `merge_fn` because all replicas share the same
      # context object, so it's more robust to set it only once (even if all
      # the replicas are trying to set the same value).
      self._last_step_outputs_reduce_ops[name] = reduce_op

    distribution_strategy_context.get_replica_context().merge_call(
        merge_fn, args=(output,))
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:40,代码来源:input_lib.py
示例19: testSameScopeNesting
def testSameScopeNesting(self):
  """Scopes of one strategy nest freely; entering another strategy's scope fails."""
  _assert_in_default_state(self)
  strategy = _TestStrategy()

  def assert_current():
    self.assertIs(strategy, ds_context.get_strategy())

  outer_scope = strategy.scope()
  with outer_scope:
    assert_current()
    inner_scope = strategy.scope()
    with inner_scope:
      assert_current()
      # Re-entering the already-open outer scope is allowed.
      with outer_scope:
        assert_current()
      assert_current()
    assert_current()
    other_strategy = _TestStrategy()
    other_scope = other_strategy.scope()
    with self.assertRaisesRegexp(
        RuntimeError,
        "Mixing different tf.distribute.Strategy objects"):
      with other_scope:
        pass
  _assert_in_default_state(self)
  # A scope object can be entered again after all scopes have been exited.
  with inner_scope:
    assert_current()
  _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:24,代码来源:distribute_lib_test.py
示例20: _scale_loss
def _scale_loss(loss_value):
  """Scale `loss_value` by 1/num_replicas when the loss reduction is MEAN.

  The value is returned unchanged for other reductions, or when there is at
  most one replica.
  """
  is_mean = (distribute_lib.get_loss_reduction() ==
             ds_reduce_util.ReduceOp.MEAN)
  if not is_mean:
    return loss_value
  replica_count = distribute_ctx.get_strategy().num_replicas_in_sync
  if replica_count > 1:
    loss_value *= (1. / replica_count)
  return loss_value
开发者ID:terrytangyuan,项目名称:tensorflow,代码行数:6,代码来源:optimizer_v2.py
注：本文中的tensorflow.python.distribute.distribution_strategy_context.get_strategy函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论