
Python distribute_coordinator_context.get_current_worker_context function code examples


This article collects typical usage examples of the Python function tensorflow.python.distribute.distribute_coordinator_context.get_current_worker_context. If you are wondering how get_current_worker_context is used in practice, the selected examples below may help.



The 17 code examples of get_current_worker_context shown below are sorted by popularity by default.

Example 1: _worker_fn

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    context = dc_context.get_current_worker_context()
    _init_run_config_from_worker_context(local_estimator._config, context)
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    # In the standalone client, we don't need to run hooks on all threads
    # because logging hooks on all threads may be too much on the screen; also
    # tensor passed to one hook can only be fetched with the graph where the
    # tensor is defined. Other hooks such as checkpointing hooks will be added
    # by MonitoredTrainingSession.
    # TODO(yuefengz): Is there a hook that does need to run on all threads in
    # standalone client mode?
    if (run_config._distribute_coordinator_mode ==  # pylint: disable=protected-access
        dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
      hooks = list(train_spec.hooks)
    else:
      hooks = []

    # Prevent estimator.train from calling distribute coordinator again. This
    # function calls estimator.train which will use distribute coordinator path
    # again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access
    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=hooks)
Developer: abhinav-upadhyay, Project: tensorflow, Lines of code: 32, Source: estimator_training.py
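
The hook-selection rule described in the comments above (every worker keeps its hooks in independent-worker mode, otherwise only the chief does) can be sketched without TensorFlow. `CoordinatorMode` and `select_hooks` here are local stand-ins written for illustration, not TensorFlow APIs, and the enum values are assumed.

```python
import enum


class CoordinatorMode(enum.Enum):
  """Local stand-in for the coordinator modes mentioned in the snippet."""
  STANDALONE_CLIENT = "standalone_client"
  INDEPENDENT_WORKER = "independent_worker"


def select_hooks(mode, is_chief, hooks):
  """Returns the hooks that a worker with this mode and role should run."""
  if mode == CoordinatorMode.INDEPENDENT_WORKER or is_chief:
    return list(hooks)
  # Non-chief workers in standalone client mode run no user hooks.
  return []
```

Copying the hooks with `list(...)` mirrors the snippet and keeps the caller's sequence from being mutated later.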


Example 2: configure_and_create_session

def configure_and_create_session(distribution_strategy):
  """Configure session config and create a session with it."""
  # TODO(priyag): Throw error if a session already exists.
  session_config = K.get_default_session_config()

  if is_tpu_strategy(distribution_strategy):
    # TODO(priyag, yuefengz): Remove this workaround when Distribute
    # Coordinator is integrated with keras and we can create a session from
    # there.
    distribution_strategy.configure(session_config)
    master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
    session = session_module.Session(config=session_config, target=master)
  else:
    worker_context = dc_context.get_current_worker_context()
    if worker_context:
      dc_session_config = worker_context.session_config
      # Merge the default session config to the one from distribute coordinator,
      # which is fine for now since they don't have conflicting configurations.
      dc_session_config.MergeFrom(session_config)
      session = session_module.Session(
          config=dc_session_config, target=worker_context.master_target)
    else:
      session = session_module.Session(config=session_config)

  K.set_session(session)
Developer: JonathanRaiman, Project: tensorflow, Lines of code: 25, Source: distributed_training_utils.py


Example 3: filter_distributed_callbacks

def filter_distributed_callbacks(callbacks_list):
  """Filter Callbacks based on the worker context when running multi-worker.

  Arguments:
    callbacks_list: A list of `Callback` instances.

  Returns:
    The list of `Callback` instances that should be run on this worker.
  """

  if not K.in_multi_worker_mode():
    raise ValueError(
        'filter_distributed_callbacks() should only be called when Keras '
        'is in multi worker mode.')

  worker_context = dc_context.get_current_worker_context()
  callbacks_list = callbacks_list or []
  if not [
      c for c in callbacks_list if isinstance(c, callbacks.ModelCheckpoint)
  ]:
    # TODO(rchao): Consider providing a ModelCheckpoint here if the user
    # fails to.
    logging.warning('ModelCheckpoint callback is not provided. '
                    'Workers will need to restart training if any fails.')
  # TODO(rchao): Add similar warning for restoring callback (to be designed).

  if callbacks_list is None or worker_context.is_chief:
    return callbacks_list

  # Some Callbacks should only run on the chief worker.
  return [
      callback for callback in callbacks_list if not callback._chief_worker_only
  ]  # pylint: disable=protected-access
Developer: kylin9872, Project: tensorflow, Lines of code: 33, Source: distributed_training_utils.py
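
The chief-only filtering step at the end of the function above can be tried in isolation. The following is a minimal sketch in plain Python: the `Callback` and `CheckpointCallback` classes and the `filter_callbacks` helper are hypothetical stand-ins, and only the `_chief_worker_only` attribute name is taken from the snippet.

```python
class Callback(object):
  """Hypothetical callback base; runs on every worker by default."""
  _chief_worker_only = False


class CheckpointCallback(Callback):
  """Hypothetical callback that should run only on the chief worker."""
  _chief_worker_only = True


def filter_callbacks(callbacks_list, is_chief):
  """Returns the callbacks that a worker with this role should run."""
  callbacks_list = callbacks_list or []
  if is_chief:
    return callbacks_list
  # Drop callbacks that are marked as chief-only.
  return [c for c in callbacks_list if not c._chief_worker_only]
```

Passing `is_chief` as an argument, instead of reading it from a worker context, keeps the sketch runnable outside a distributed job.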


Example 4: _between_graph_with_monitored_session

  def _between_graph_with_monitored_session(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with ops.device("/job:ps/task:0"):
      # TODO(yuefengz): investigate why not using resource variable will make
      # the test flaky.
      x = variable_scope.get_variable("xx", initializer=10.0, use_resource=True)
    with ops.device("/job:ps/task:1"):
      y = variable_scope.get_variable("yy", initializer=20.0, use_resource=True)

    x_add = x.assign_add(2.0)
    y_sub = y.assign_sub(2.0)
    train_op = control_flow_ops.group([x_add, y_sub])

    # The monitored session will run init or ready ops.
    with monitored_session.MonitoredSession() as sess:
      sess.run(train_op)

      # Synchronize workers after one step to make sure they all have finished
      # training.
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
        self._barrier.wait()

      x_val, y_val = sess.run([x, y])

    self.assertEqual(x_val, 16.0)
    self.assertEqual(y_val, 14.0)
    if x_val == 16.0 and y_val == 14.0:
      with self._lock:
        self._result_correct += 1
Developer: aritratony, Project: tensorflow, Lines of code: 32, Source: distribute_coordinator_test.py
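
The fallback branch above synchronizes workers with `self._barrier.wait()` before reading the variables. A rough thread-based analogue using only the standard library is sketched here; the worker count of 3, the `shared` dict, and the `observed` list are assumptions made for the illustration and do not come from the test file.

```python
import threading

NUM_WORKERS = 3  # Assumed worker count for this sketch.

barrier = threading.Barrier(NUM_WORKERS)
lock = threading.Lock()
shared = {"x": 10.0}
observed = []


def worker_fn():
  # Each worker applies one update to the shared value.
  with lock:
    shared["x"] += 2.0
  # Block until every worker has finished its update.
  barrier.wait()
  # After the barrier, every worker reads the fully updated value.
  with lock:
    observed.append(shared["x"])


threads = [threading.Thread(target=worker_fn) for _ in range(NUM_WORKERS)]
for t in threads:
  t.start()
for t in threads:
  t.join()
```

Without the barrier, a fast worker could read `shared["x"]` before slower workers have applied their updates.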


Example 5: __enter__

  def __enter__(self):
    old_context = distribute_coordinator_context.get_current_worker_context()
    if old_context:
      raise ValueError(
          "You cannot run distribute coordinator in a `worker_fn`.\t" +
          self._debug_message())
    # pylint: disable=protected-access
    distribute_coordinator_context._worker_context.current = self
Developer: terrytangyuan, Project: tensorflow, Lines of code: 8, Source: distribute_coordinator.py
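
The `__enter__` method above stores the context on a module-level holder, which is what a later call to `get_current_worker_context()` reads back. The following is a simplified model of that pattern built on `threading.local`. It is a sketch under assumptions, not TensorFlow's implementation: the `WorkerContext` class, its `__exit__` method, and `describe_current_worker` are hypothetical.

```python
import threading

# Hypothetical stand-in for the module-level holder; the real module keeps
# its own private state, which may differ in detail.
_worker_context = threading.local()


def get_current_worker_context():
  """Returns the context set by the enclosing `with` block, or None."""
  return getattr(_worker_context, "current", None)


class WorkerContext(object):
  """Toy context object exposing a single `is_chief` attribute."""

  def __init__(self, is_chief):
    self.is_chief = is_chief

  def __enter__(self):
    if get_current_worker_context():
      raise ValueError("Worker contexts cannot be nested in this sketch.")
    _worker_context.current = self
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    _worker_context.current = None


def describe_current_worker():
  """Handles the no-context case before reading any attribute."""
  context = get_current_worker_context()
  if not context:
    return "no worker context"
  return "chief" if context.is_chief else "non-chief"
```

Because the holder is thread-local in this sketch, a context entered on one thread is not visible from another, so callers that may run outside a worker function should check for `None` first.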


Example 6: _worker_fn

  def _worker_fn(self, strategy):
    worker_context = distribute_coordinator_context.get_current_worker_context()
    session_config = worker_context._session_config
    self._device_filters.extend(session_config.device_filters)
    self._intra_op_parallelism_threads = (
        session_config.intra_op_parallelism_threads)
    self._inter_op_parallelism_threads = (
        session_config.inter_op_parallelism_threads)
    return MockServer()
Developer: aritratony, Project: tensorflow, Lines of code: 9, Source: distribute_coordinator_test.py


Example 7: init_restore_or_wait_for_variables

def init_restore_or_wait_for_variables():
  """Initialize or restore variables or wait for variables to be initialized."""
  session = K._get_session()  # pylint: disable=protected-access
  worker_context = dc_context.get_current_worker_context()
  if not worker_context or worker_context.experimental_should_init:
    # TODO(yuefengz): if checkpoints exist, restore from checkpoint.
    K._initialize_variables(session)  # pylint: disable=protected-access
  else:
    _wait_for_variable_initialization(session)
Developer: kylin9872, Project: tensorflow, Lines of code: 9, Source: distributed_training_utils.py


Example 8: _eval_fn

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._eval_distribution = strategy

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
Developer: AnishShah, Project: tensorflow, Lines of code: 11, Source: estimator_training.py


Example 9: _worker_fn

  def _worker_fn(strategy):
    """Function for worker task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._train_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    local_estimator._train_distribution = strategy
    # pylint: enable=protected-access

    local_estimator.train(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks))
Developer: AnishShah, Project: tensorflow, Lines of code: 14, Source: estimator_training.py


Example 10: worker_fn

  def worker_fn(strategy):
    with ops.Graph().as_default():
      batch_size = 64
      steps = 2

      with strategy.scope():
        train_ds, _ = _mnist_synthetic_dataset(batch_size, steps)
        model = _clone_and_build_model(orig_model, strategy)

        orig_loss, orig_acc = model.evaluate(train_ds, steps=steps)

        # Workaround for the metrics issue (b/122928955) in async training. This
        # can only be used in standalone client mode.
        dc_context.get_current_worker_context().wait_for_other_workers()

        model.fit(x=train_ds, epochs=2, steps_per_epoch=steps)

        dc_context.get_current_worker_context().wait_for_other_workers()

        trained_loss, trained_acc = model.evaluate(train_ds, steps=steps)

      test_obj.assertLessEqual(trained_loss, orig_loss)
      test_obj.assertGreaterEqual(trained_acc, orig_acc)
Developer: aritratony, Project: tensorflow, Lines of code: 23, Source: multi_worker_test.py


Example 11: _between_graph_worker_fn

  def _between_graph_worker_fn(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      with ops.device("/job:ps/task:0"):
        # TODO(yuefengz): investigate why not using resource variable will make
        # the test flaky.
        x = variable_scope.get_variable(
            "x", initializer=10.0, use_resource=True)
      with ops.device("/job:ps/task:1"):
        y = variable_scope.get_variable(
            "y", initializer=20.0, use_resource=True)

      x_add = x.assign_add(2.0)
      y_sub = y.assign_sub(2.0)
      train_op = control_flow_ops.group([x_add, y_sub])

      if context.is_chief:
        self.evaluate(variables.global_variables_initializer())

      # Synchronize workers after initialization.
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
        while True:
          uninit_vars = sess.run(variables.report_uninitialized_variables())
          # pylint: disable=g-explicit-length-test
          if len(uninit_vars) == 0:
            break

      sess.run(train_op)

      # Synchronize workers after one step to make sure they all have finished
      # training.
      if context.has_barrier:
        context.wait_for_other_workers()
      else:
        self._barrier.wait()

      x_val, y_val = sess.run([x, y])

      self.assertEqual(x_val, 16.0)
      self.assertEqual(y_val, 14.0)
      if x_val == 16.0 and y_val == 14.0:
        with self._lock:
          self._result_correct += 1
Developer: aritratony, Project: tensorflow, Lines of code: 46, Source: distribute_coordinator_test.py


Example 12: _eval_fn

  def _eval_fn(strategy):
    """Function for evaluator task."""
    local_estimator = copy.deepcopy(estimator)
    # pylint: disable=protected-access
    local_estimator._config._eval_distribute = strategy
    _init_run_config_from_worker_context(
        local_estimator._config, dc_context.get_current_worker_context())
    logging.info('Updated config: %s', str(vars(local_estimator._config)))
    local_estimator._eval_distribution = strategy

    # Prevent estimator.evaluate from calling distribute coordinator again. This
    # function calls estimator.evaluate which will use distribute coordinator
    # path again if `_distribute_coordinator_mode` is set.
    local_estimator._config._distribute_coordinator_mode = None  # pylint: disable=protected-access

    executor = executor_cls(local_estimator, train_spec, eval_spec)
    executor._start_continuous_evaluation()
Developer: abhinav-upadhyay, Project: tensorflow, Lines of code: 17, Source: estimator_training.py


Example 13: _dump_strategy_property

  def _dump_strategy_property(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)

    self.assertEqual(context._strategy.should_init, strategy.should_init)
    self.assertEqual(context.should_checkpoint, strategy.should_checkpoint)
    self.assertEqual(context.should_save_summary, strategy.should_save_summary)

    task_type = str(context.task_type)
    task_id = context.task_id or 0
    with self._lock:
      if task_type not in self._strategy_property:
        self._strategy_property[task_type] = []
      while len(self._strategy_property[task_type]) <= task_id:
        self._strategy_property[task_type].append(None)
      self._strategy_property[task_type][task_id] = (
          context._strategy.should_init, context.should_checkpoint,
          context.should_save_summary)
Developer: mrlittlepig, Project: tensorflow, Lines of code: 18, Source: distribute_coordinator_test.py


Example 14: _in_graph_worker_fn

  def _in_graph_worker_fn(self, strategy):
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    with self._test_session(target=context.master_target) as sess:
      xs = []
      expected = 0.0
      for i in range(context.num_workers):
        with ops.device("/job:worker/task:%d" % i):
          x = variable_scope.get_variable("x_%d" % i, initializer=10.0)
          x_add = x.assign_add(float(i))
          xs.append(x_add)
          expected += i + 10.0

      with ops.device("/job:worker/task:0"):
        result = math_ops.add_n(xs)

      self.evaluate(variables.global_variables_initializer())
      result_value = sess.run(result)
    self.assertEqual(result_value, expected)
    if result_value == expected:
      self._result_correct += 1
Developer: aritratony, Project: tensorflow, Lines of code: 21, Source: distribute_coordinator_test.py


Example 15: _dump_worker_context

  def _dump_worker_context(self, strategy):
    """Dumps the properties of each worker context.

    It dumps the context properties to a dict mapping from task_type to a list
    of tuples of master_target, num_workers, is_chief and distributed_mode,
    where the list is indexed by the task_id.

    Args:
      strategy: a `DistributionStrategy` object.
    """
    context = distribute_coordinator_context.get_current_worker_context()
    self.assertTrue(context is not None)
    task_type = str(context.task_type)
    task_id = context.task_id or 0
    with self._lock:
      if task_type not in self._worker_context:
        self._worker_context[task_type] = []
      while len(self._worker_context[task_type]) <= task_id:
        self._worker_context[task_type].append(None)
      self._worker_context[task_type][task_id] = (context.master_target,
                                                  context.num_workers,
                                                  context.is_chief,
                                                  context.distributed_mode)
Developer: aritratony, Project: tensorflow, Lines of code: 23, Source: distribute_coordinator_test.py


Example 16: is_current_worker_chief

def is_current_worker_chief():
  return dc_context.get_current_worker_context().is_chief
Developer: aritratony, Project: tensorflow, Lines of code: 2, Source: distributed_training_utils.py


Example 17: on_train_begin

  def on_train_begin(self, logs):
    if not dc_context.get_current_worker_context().is_chief:
      # Non-chief workers shouldn't run this callback.
      self.filtered_correctly = False
Developer: aritratony, Project: tensorflow, Lines of code: 4, Source: multi_worker_callback_test.py



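Note: The tensorflow.python.distribute.distribute_coordinator_context.get_current_worker_context examples in this article were compiled by 纯净天空 from source-code and documentation hosting platforms such as GitHub and MSDocs. The snippets were selected from open-source projects, and copyright in the source code belongs to the original authors. For distribution and use, refer to the License of the corresponding project. Do not reproduce without permission.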

