• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python distribution_strategy_context.get_strategy函数代码示例

原作者: 见各示例下方的“开发者ID” 来自: GitHub 开源项目(见各示例的“项目名称”与“代码来源”) 收藏 邀请

本文整理汇总了Python中tensorflow.python.distribute.distribution_strategy_context.get_strategy函数的典型用法代码示例。如果您正苦于以下问题:Python get_strategy函数的具体用法?Python get_strategy怎么用?Python get_strategy使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了get_strategy函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: call_replica_local_fn

def call_replica_local_fn(fn, *args, **kwargs):
  """Call a function that uses replica-local variables.

  This function correctly handles calling `fn` in a cross-replica
  context: when a non-TPU strategy is active and we are in cross-replica
  context, `fn` is run through `strategy.extended.call_for_each_replica`
  inside `strategy.scope()`; otherwise `fn` is called directly.

  Arguments:
    fn: The function to call.
    *args: Positional arguments to the `fn`.
    **kwargs: Keyword arguments to `fn`. An optional `strategy` keyword is
      consumed here (it is NOT forwarded to `fn`) and overrides the current
      distribution strategy.

  Returns:
    The result of calling `fn` (or of `call_for_each_replica`).
  """
  # TODO(b/120571621): We want to avoid reductions here since
  # TPUStrategy does not implement replica local variables.
  # Remove this hack once we support TPUReplicaLocalVariables.
  if 'strategy' in kwargs:
    strategy = kwargs.pop('strategy')
  else:
    # Query the current strategy once; a falsy result is normalized to None,
    # matching the previous check-then-fetch behavior.
    strategy = ds_context.get_strategy() or None

  is_tpu = is_tpu_strategy(strategy)
  if not is_tpu and strategy and ds_context.in_cross_replica_context():
    with strategy.scope():
      return strategy.extended.call_for_each_replica(fn, args, kwargs)
  return fn(*args, **kwargs)
开发者ID:aritratony,项目名称:tensorflow,代码行数:29,代码来源:distributed_training_utils.py


示例2: add_slot

 def add_slot(self, var, slot_name, initializer="zeros"):
   """Add a new slot variable for `var`, or return the one already created.

   Args:
     var: The variable the slot belongs to.
     slot_name: Name of the slot; appended to `self._slot_names` on first use.
     initializer: A string or callable (resolved through `initializers.get`
       and bound to `var`'s shape and dtype), or a value used directly as the
       initial value.

   Returns:
     The slot variable stored under `slot_name` for `var`.
   """
   if slot_name not in self._slot_names:
     self._slot_names.append(slot_name)
   per_var_slots = self._slots.setdefault(_var_key(var), {})
   existing = per_var_slots.get(slot_name, None)
   if existing is not None:
     return existing

   resolvable = (isinstance(initializer, six.string_types) or
                 callable(initializer))
   if resolvable:
     init_fn = initializers.get(initializer)
     initial_value = functools.partial(
         init_fn, shape=var.shape, dtype=var.dtype)
   else:
     initial_value = initializer
   strategy = distribute_ctx.get_strategy()
   with strategy.extended.colocate_vars_with(var):
     new_slot = tf_variables.Variable(
         name="%s/%s" % (var._shared_name, slot_name),  # pylint: disable=protected-access
         dtype=var.dtype,
         trainable=False,
         initial_value=initial_value)
   backend.track_variable(new_slot)
   per_var_slots[slot_name] = new_slot
   self._restore_slot_variable(
       slot_name=slot_name, variable=var, slot_variable=new_slot)
   self._weights.append(new_slot)
   return new_slot
开发者ID:aritratony,项目名称:tensorflow,代码行数:28,代码来源:optimizer_v2.py


示例3: create_slot

def create_slot(primary, val, name, colocate_with_primary=True):
  """Create a slot initialized to the given value.

  The type of the slot is determined by `val`; its shape is validated only
  when `val`'s static shape is fully defined.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean.  If True the slot is created under the
      current strategy's `colocate_vars_with(primary)`, i.e. located on the
      same device as `primary`.

  Returns:
    A `Variable` object.
  """
  validate_shape = val.get_shape().is_fully_defined()
  # Eager tensors have no op, so the shared name is used as the prefix.
  if context.executing_eagerly():
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  # Scope the slot name in the namespace of the primary variable.
  # Set "primary.op.name + '/' + name" as default name, so the scope name of
  # optimizer can be shared when reuse is True. Meanwhile when reuse is False
  # and the same name has been previously used, the scope name will add '_N'
  # as suffix for unique identifications.
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if not colocate_with_primary:
      return _create_slot_var(primary, val, "", validate_shape, None, None)
    strategy = distribution_strategy_context.get_strategy()
    with strategy.colocate_vars_with(primary):
      return _create_slot_var(primary, val, "", validate_shape, None, None)
开发者ID:pyjennings,项目名称:tensorflow,代码行数:32,代码来源:slot_creator.py


示例4: _fused_batch_norm

  def _fused_batch_norm(self, inputs, training):
    """Returns the output of fused batch norm.

    When `training` is True or not statically known, this also registers
    updates of `self.moving_mean` / `self.moving_variance` via `add_update`.

    Args:
      inputs: Input tensor.
      training: Python bool or boolean tensor selecting the training or the
        inference branch.

    Returns:
      The normalized output tensor.
    """
    beta = self.beta if self.center else self._beta_const
    gamma = self.gamma if self.scale else self._gamma_const

    def _fused_batch_norm_training():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          epsilon=self.epsilon,
          data_format=self._data_format)

    def _fused_batch_norm_inference():
      return nn.fused_batch_norm(
          inputs,
          gamma,
          beta,
          mean=self.moving_mean,
          variance=self.moving_variance,
          epsilon=self.epsilon,
          is_training=False,
          data_format=self._data_format)

    output, mean, variance = tf_utils.smart_cond(
        training, _fused_batch_norm_training, _fused_batch_norm_inference)
    if not self._bessels_correction_test_only:
      # Remove Bessel's correction to be consistent with non-fused batch norm.
      # Note that the variance computed by fused batch norm is
      # with Bessel's correction.
      sample_size = math_ops.cast(
          array_ops.size(inputs) / array_ops.size(variance), variance.dtype)
      factor = (sample_size - math_ops.cast(1.0, variance.dtype)) / sample_size
      variance *= factor

    training_value = tf_utils.constant_value(training)
    if training_value is None:
      # `training` is only known at run time. Presumably a momentum of 1.0
      # makes `_assign_moving_average` keep the old value on the inference
      # branch (assumption about that helper -- confirm).
      momentum = tf_utils.smart_cond(training,
                                     lambda: self.momentum,
                                     lambda: 1.0)
    else:
      momentum = ops.convert_to_tensor(self.momentum)
    if training_value or training_value is None:
      if distribution_strategy_context.in_cross_replica_context():
        strategy = distribution_strategy_context.get_strategy()
        # Use the computed `momentum` (not `self.momentum`) so the
        # cross-replica path honours a run-time `training` flag exactly like
        # the per-replica path below.
        mean_update = strategy.extended.update(
            self.moving_mean, self._assign_moving_average,
            (mean, momentum))
        variance_update = strategy.extended.update(
            self.moving_variance, self._assign_moving_average,
            (variance, momentum))
      else:
        mean_update = self._assign_moving_average(self.moving_mean, mean,
                                                  momentum)
        variance_update = self._assign_moving_average(self.moving_variance,
                                                      variance, momentum)
      self.add_update(mean_update, inputs=True)
      self.add_update(variance_update, inputs=True)

    return output
开发者ID:gautam1858,项目名称:tensorflow,代码行数:60,代码来源:normalization.py


示例5: scale_loss_for_distribution

def scale_loss_for_distribution(loss_value):
  """Divide `loss_value` by the number of replicas in sync.

  The value is returned unchanged when at most one replica is in sync.
  """
  strategy = distribution_strategy_context.get_strategy()
  replica_count = strategy.num_replicas_in_sync
  if not replica_count > 1:
    return loss_value
  loss_value *= (1. / replica_count)
  return loss_value
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:losses_utils.py


示例6: _create_non_slot_variable

  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot.

    Variables are cached in `self._non_slot_dict` under `(name, graph)`,
    where `graph` is None when executing eagerly; a cached variable is
    returned as-is on later calls.
    """
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    cache_key = (name, None if eager else colocate_with.graph)
    cached = self._non_slot_dict.get(cache_key, None)
    if cached is not None:
      return cached

    self._maybe_initialize_trackable()
    strategy = distribute_ctx.get_strategy()
    with strategy.extended.colocate_vars_with(colocate_with):
      if eager:
        # Presumably a pending checkpoint restoration value; used in place of
        # `initial_value` when available.
        restored = self._preload_simple_restoration(name=name, shape=None)
        if restored is not None:
          initial_value = restored
      new_var = variable_scope.variable(
          initial_value, name=name, trainable=False,
          use_resource=resource_variable_ops.is_resource_variable(
              colocate_with))
    # Restore this variable by name if necessary, but don't add a
    # Trackable dependency. Optimizers return the current graph's
    # non-slot variables from _checkpoint_dependencies explicitly rather
    # than unconditionally adding dependencies (since there may be multiple
    # non-slot variables with the same name in different graphs, trying to
    # save all of them would result in errors).
    self._handle_deferred_dependencies(name=name, trackable=new_var)
    self._non_slot_dict[cache_key] = new_var
    return new_var
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:31,代码来源:optimizer.py


示例7: _assert_in_default_state

def _assert_in_default_state(t):
  """Assert, through test case `t`, that no strategy has been set.

  Checks that the default replica context and the default strategy are
  current, that there is no cross-replica context, and that `has_strategy()`
  reports False.
  """
  default_replica_ctx = ds_context._get_default_replica_context()  # pylint: disable=protected-access
  t.assertIs(default_replica_ctx, ds_context.get_replica_context())
  t.assertIs(None, ds_context.get_cross_replica_context())
  t.assertFalse(ds_context.in_cross_replica_context())
  default_strategy = ds_context._get_default_strategy()  # pylint: disable=protected-access
  t.assertIs(default_strategy, ds_context.get_strategy())
  t.assertFalse(ds_context.has_strategy())
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:distribute_lib_test.py


示例8: merge_fn

 def merge_fn(dist, s):
   # Expects to run in cross-replica context of the *default* strategy
   # (presumably invoked via merge_call); has_strategy() stays False.
   default_strategy = ds_context._get_default_strategy()  # pylint: disable=protected-access
   self.assertIs(default_strategy, dist)
   self.assertIsNone(ds_context.get_replica_context())
   self.assertIs(dist, ds_context.get_cross_replica_context())
   self.assertTrue(ds_context.in_cross_replica_context())
   self.assertIs(dist, ds_context.get_strategy())
   self.assertFalse(ds_context.has_strategy())
   return "foo_" + s
开发者ID:aritratony,项目名称:tensorflow,代码行数:8,代码来源:distribute_lib_test.py


示例9: _scale_loss

 def _scale_loss(loss_value):
   """Scale `loss_value` by 1/num_replicas when the loss reduction is MEAN.

   The default graph's `_is_loss_scaled_by_optimizer` flag records whether
   the scaling was actually applied.
   """
   graph = ops.get_default_graph()
   graph._is_loss_scaled_by_optimizer = False  # pylint: disable=protected-access
   is_mean = (distribute_lib.get_loss_reduction() ==
              ds_reduce_util.ReduceOp.MEAN)
   if not is_mean:
     return loss_value
   num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
   if num_replicas > 1:
     loss_value *= (1. / num_replicas)
     graph._is_loss_scaled_by_optimizer = True  # pylint: disable=protected-access
   return loss_value
开发者ID:aritratony,项目名称:tensorflow,代码行数:8,代码来源:optimizer.py


示例10: testSetStrategy

 def testSetStrategy(self):
   """experimental_set_strategy sets, replaces and finally clears a strategy."""
   _assert_in_default_state(self)
   dist = _TestStrategy()
   dist2 = _TestStrategy()
   ds_context.experimental_set_strategy(dist)
   # Reset the global strategy even if an assertion below fails, so a failure
   # here cannot leak a non-default strategy into subsequent tests.
   try:
     self.assertIs(None, ds_context.get_replica_context())
     self.assertIs(dist, ds_context.get_cross_replica_context())
     self.assertTrue(ds_context.in_cross_replica_context())
     self.assertTrue(ds_context.has_strategy())
     self.assertIs(dist, ds_context.get_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     self.assertDictEqual(expected_value,
                          variable_scope.variable(1.0, name="baz"))
     ds_context.experimental_set_strategy(dist2)
     self.assertIs(dist2, ds_context.get_strategy())
   finally:
     ds_context.experimental_set_strategy(None)
   _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:19,代码来源:distribute_lib_test.py


示例11: _reduce_weighted_loss

def _reduce_weighted_loss(
    weighted_losses, reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE):
  """Reduces the individual weighted loss measurements.

  Args:
    weighted_losses: The weighted per-element losses.
    reduction: A `ReductionV2` value. `NONE` returns `weighted_losses`
      unchanged; `SUM_OVER_BATCH_SIZE` passes the sum and the global element
      count (local elements times replicas in sync) to `_safe_mean`; any
      other value returns the plain sum.

  Returns:
    The reduced loss.
  """
  if reduction == losses_impl.ReductionV2.NONE:
    return weighted_losses
  total = math_ops.reduce_sum(weighted_losses)
  if reduction == losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE:
    # Used to convert from local to global batch size.
    replicas = (
        distribution_strategy_context.get_strategy().num_replicas_in_sync)
    return _safe_mean(total, replicas * _num_elements(weighted_losses))
  return total
开发者ID:ziky90,项目名称:tensorflow,代码行数:12,代码来源:losses_utils.py


示例12: _get_tensor

  def _get_tensor(self, is_finite):
    """Return a tensor that is 1.0 when `is_finite` is True, else NaN.

    Under a distribution strategy only replica 0 yields that value; every
    other replica yields 1.0.
    """
    tensor = control_flow_ops.cond(is_finite, lambda: 1., lambda: float('NaN'))

    if not distribution_strategy_context.has_strategy():
      return tensor

    def _value_for_this_replica():
      replica_ctx = distribution_strategy_context.get_replica_context()
      is_first = math_ops.equal(replica_ctx.replica_id_in_sync_group, 0)
      return control_flow_ops.cond(is_first, lambda: tensor, lambda: 1.)

    distribution = distribution_strategy_context.get_strategy()
    return distribution.extended.call_for_each_replica(_value_for_this_replica)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:12,代码来源:loss_scale_test.py


示例13: run_fn

 def run_fn():
   # Expects a replica context for `dist` (no cross-replica context) in which
   # merge_call works and variable creation goes through the test strategy.
   replica_context = ds_context.get_replica_context()
   self.assertIsNotNone(replica_context)
   self.assertIsNone(ds_context.get_cross_replica_context())
   self.assertFalse(ds_context.in_cross_replica_context())
   self.assertTrue(ds_context.has_strategy())
   self.assertIs(dist, ds_context.get_strategy())
   merged = replica_context.merge_call(None, test_arg="foo")
   self.assertEqual("foo", merged)
   expected_value = _get_test_variable(
       "bar", variable_scope.VariableSynchronization.AUTO,
       variable_scope.VariableAggregation.NONE)
   created = variable_scope.variable(1.0, name="bar")
   self.assertDictEqual(expected_value, created)
开发者ID:aritratony,项目名称:tensorflow,代码行数:13,代码来源:distribute_lib_test.py


示例14: testScopeDeviceNestingError

 def testScopeDeviceNestingError(self):
   """Exiting `dist.scope()` inside a different device scope must raise."""
   _assert_in_default_state(self)
   dist = _TestStrategy()
   # Open a device scope with dist.scope().
   dist.extended._default_device = "/device:GPU:0"
   scope = dist.scope()
   # Enter/exit manually so that __exit__ can be called while an inner
   # ops.device scope is still open.
   scope.__enter__()
   self.assertIs(dist, ds_context.get_strategy())
   with ops.device("/device:CPU:0"):
     with self.assertRaisesRegexp(RuntimeError, "Device scope nesting error"):
       scope.__exit__(None, None, None)
   # Exit again once the device scope is closed. Presumably the failed exit
   # above did not fully unwind the strategy scope (the default-state check
   # below relies on this second exit) -- confirm.
   scope.__exit__(None, None, None)
   _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:13,代码来源:distribute_lib_test.py


示例15: testScopeVarScopeNestingError

 def testScopeVarScopeNestingError(self):
   """Exiting `dist.scope()` inside an inner variable_scope must raise."""
   # We create a new graph here to simplify clean-up, since the error
   # we are triggering happens in the middle of scope.__exit__() and
   # leaves us in a weird state.
   with ops.Graph().as_default():
     _assert_in_default_state(self)
     dist = _TestStrategy()
     scope = dist.scope()
     # Entered manually so __exit__ can be called out of order below.
     scope.__enter__()
     self.assertIs(dist, ds_context.get_strategy())
     with variable_scope.variable_scope("AA"):
       with self.assertRaisesRegexp(RuntimeError,
                                    "Variable scope nesting error"):
         scope.__exit__(None, None, None)
   # No second __exit__ here: leaving the temporary graph above is what
   # restores the default state checked below.
   _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:15,代码来源:distribute_lib_test.py


示例16: testScopeVarCreatorNestingError

  def testScopeVarCreatorNestingError(self):
    """Exiting `dist.scope()` inside an inner variable creator scope raises."""

    # Pass-through creator: only its presence as a nested scope matters.
    def creator(next_creator, **kwargs):
      return next_creator(**kwargs)

    _assert_in_default_state(self)
    dist = _TestStrategy()
    scope = dist.scope()
    # Entered manually so __exit__ can be called while the creator scope is
    # still open.
    scope.__enter__()
    self.assertIs(dist, ds_context.get_strategy())
    with variable_scope.variable_creator_scope(creator):
      with self.assertRaisesRegexp(RuntimeError,
                                   "Variable creator scope nesting error"):
        scope.__exit__(None, None, None)
    # Exit again after the creator scope closes; presumably the failed exit
    # left the strategy scope partially open -- confirm.
    scope.__exit__(None, None, None)
    _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:16,代码来源:distribute_lib_test.py


示例17: testScope

 def testScope(self):
   """Inside `dist.scope()` we are in the cross-replica context of `dist`."""
   _assert_in_default_state(self)
   dist = _TestStrategy()
   ctx = distribution_strategy_context
   with dist.scope():
     self.assertIsNone(ctx.get_replica_context())
     self.assertIs(dist, ctx.get_cross_replica_context())
     self.assertTrue(ctx.in_cross_replica_context())
     self.assertTrue(ctx.has_strategy())
     self.assertIs(dist, ctx.get_strategy())
     expected_value = _get_test_variable(
         "baz", variable_scope.VariableSynchronization.AUTO,
         variable_scope.VariableAggregation.NONE)
     created = variable_scope.variable(1.0, name="baz")
     self.assertDictEqual(expected_value, created)
   _assert_in_default_state(self)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:17,代码来源:distribute_lib_test.py


示例18: set_last_step_output

  def set_last_step_output(self, name, output, reduce_op=None):
    """Set `output` with `name` to be outputted from the last step.

    Args:
      name: String, name to identify the output. Doesn't need to match the
        tensor name.
      output: The tensors that should be outputted with `name`. See below for
        actual types supported.
      reduce_op: Reduction method to use to reduce outputs from multiple
        replicas. Required if `set_last_step_output` is called in a replica
        context. Optional in cross-replica context.
        When present, the outputs from all the replicas are reduced using the
        current distribution strategy's `reduce` method. Hence, the type of
        `output` must be what's supported by the corresponding `reduce` method.
        For example, if using MirroredStrategy and a reduction is set, `output`
        must be a `PerReplica` value.
        The reduce method is also recorded in the dictionary
        `_last_step_outputs_reduce_ops` so the outputs can later be
        interpreted as already reduced or not.

    Raises:
      ValueError: If called in a replica context without `reduce_op`.
    """
    if distribution_strategy_context.in_cross_replica_context():
      self._last_step_outputs_reduce_ops[name] = reduce_op
      if reduce_op is None:
        self._last_step_outputs[name] = output
      else:
        distribution = distribution_strategy_context.get_strategy()
        self._last_step_outputs[name] = distribution.reduce(reduce_op, output,
                                                            axis=None)
      return

    # Per-replica values must be combined, so a reduction is mandatory here.
    # Raise explicitly: an `assert` would be stripped under `python -O`.
    if reduce_op is None:
      raise ValueError(
          "`reduce_op` must be provided when `set_last_step_output` is called "
          "in a replica context (output name: %r)." % (name,))

    def merge_fn(distribution, value):
      self._last_step_outputs[name] = distribution.reduce(reduce_op, value,
                                                          axis=None)
      # Setting this inside the `merge_fn` because all replicas share the same
      # context object, so it's more robust to set it only once (even if all
      # the replicas are trying to set the same value).
      self._last_step_outputs_reduce_ops[name] = reduce_op

    distribution_strategy_context.get_replica_context().merge_call(
        merge_fn, args=(output,))
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:40,代码来源:input_lib.py


示例19: testSameScopeNesting

 def testSameScopeNesting(self):
   """Scopes of one strategy nest freely; mixing two strategies raises."""
   _assert_in_default_state(self)
   dist = _TestStrategy()
   scope_a = dist.scope()
   with scope_a:
     self.assertIs(dist, ds_context.get_strategy())
     # A second scope object from the same strategy may be nested, and the
     # first one may even be re-entered inside it.
     scope_b = dist.scope()
     with scope_b:
       self.assertIs(dist, ds_context.get_strategy())
       with scope_a:
         self.assertIs(dist, ds_context.get_strategy())
       self.assertIs(dist, ds_context.get_strategy())
     self.assertIs(dist, ds_context.get_strategy())
     # Entering a different strategy's scope while `dist` is active fails.
     dist2 = _TestStrategy()
     scope2 = dist2.scope()
     with self.assertRaisesRegexp(
         RuntimeError,
         "Mixing different tf.distribute.Strategy objects"):
       with scope2:
         pass
   _assert_in_default_state(self)
   # A scope object that has already been exited can be entered again.
   with scope_b:
     self.assertIs(dist, ds_context.get_strategy())
   _assert_in_default_state(self)
开发者ID:aritratony,项目名称:tensorflow,代码行数:24,代码来源:distribute_lib_test.py


示例20: _scale_loss

 def _scale_loss(loss_value):
   """Scale `loss_value` by 1/num_replicas when the loss reduction is MEAN."""
   is_mean = (distribute_lib.get_loss_reduction() ==
              ds_reduce_util.ReduceOp.MEAN)
   if not is_mean:
     return loss_value
   replica_count = distribute_ctx.get_strategy().num_replicas_in_sync
   if replica_count > 1:
     loss_value *= (1. / replica_count)
   return loss_value
开发者ID:terrytangyuan,项目名称:tensorflow,代码行数:6,代码来源:optimizer_v2.py



注:本文中的tensorflow.python.distribute.distribution_strategy_context.get_strategy函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap