• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python training_util.get_global_step函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.training.training_util.get_global_step函数的典型用法代码示例。如果您正苦于以下问题:Python get_global_step函数的具体用法?Python get_global_step怎么用?Python get_global_step使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了get_global_step函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: test_get_global_step

 def test_get_global_step(self):
   """Checks get_global_step() before and after a global step variable exists."""
   graph = ops.Graph()
   with graph.as_default():
     # A brand-new graph has no global step registered yet.
     self.assertIsNone(training_util.get_global_step())
     variables.VariableV1(
         0, trainable=False, dtype=dtypes.int32,
         name=ops.GraphKeys.GLOBAL_STEP)
     found_in_default = training_util.get_global_step()
     self._assert_global_step(found_in_default, expected_dtype=dtypes.int32)
   # Outside the `with` block the graph is passed explicitly.
   found_explicit = training_util.get_global_step(graph)
   self._assert_global_step(found_explicit, expected_dtype=dtypes.int32)
开发者ID:aeverall,项目名称:tensorflow,代码行数:12,代码来源:training_util_test.py


示例2: _ModelFn

    def _ModelFn(features, labels, mode):
      """Builds loss, accuracy and a mode-specific EstimatorSpec.

      Logits come from `self._BuildGraph` when `is_training` is set; otherwise
      they are the first returned element of a graph def imported with
      `features` mapped onto INPUT_NODE_NAME. For modes other than EVAL and
      TRAIN the function falls through and returns None.
      """
      if not is_training:
        graph_def = self._GetGraphDef(use_trt, batch_size, model_dir)
        imported = importer.import_graph_def(
            graph_def,
            input_map={INPUT_NODE_NAME: features},
            return_elements=[OUTPUT_NODE_NAME + ':0'],
            name='')
        logits_out = imported[0]
      else:
        logits_out = self._BuildGraph(features)

      loss = losses.sparse_softmax_cross_entropy(
          labels=labels, logits=logits_out)
      summary.scalar('loss', loss)

      predicted_classes = math_ops.argmax(
          logits_out, axis=1, name='classes_out')
      accuracy = metrics.accuracy(
          labels=labels, predictions=predicted_classes, name='acc_op')
      # The second element of the metric pair is what gets summarized.
      summary.scalar('accuracy', accuracy[1])

      if mode == ModeKeys.EVAL:
        eval_metrics = {'accuracy': accuracy}
        return EstimatorSpec(mode, loss=loss, eval_metric_ops=eval_metrics)
      if mode == ModeKeys.TRAIN:
        adam = AdamOptimizer(learning_rate=1e-2)
        train_op = adam.minimize(loss, global_step=get_global_step())
        return EstimatorSpec(mode, loss=loss, train_op=train_op)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:27,代码来源:quantization_mnist_test.py


示例3: _train_op_fn

  def _train_op_fn(loss):
    """Returns the op to optimize the loss.

    Minimizes `loss` with the DNN optimizer and/or the linear optimizer,
    depending on which logits are present, groups the resulting ops, and
    returns an op that adds 1 to the global step once the grouped op has run.
    """
    train_ops = []
    global_step = training_util.get_global_step()
    if dnn_logits is not None:
      # The DNN optimizer only updates trainable variables under the DNN scope.
      train_ops.append(
          dnn_optimizer.minimize(
              loss,
              var_list=ops.get_collection(
                  ops.GraphKeys.TRAINABLE_VARIABLES,
                  scope=dnn_parent_scope)))
    if linear_logits is not None:
      # The linear optimizer only updates trainable variables under the linear
      # scope.
      train_ops.append(
          linear_optimizer.minimize(
              loss,
              var_list=ops.get_collection(
                  ops.GraphKeys.TRAINABLE_VARIABLES,
                  scope=linear_parent_scope)))

    train_op = control_flow_ops.group(*train_ops)
    # The increment is created under a control dependency on the grouped train
    # op and colocated with the global step.
    with ops.control_dependencies([train_op]):
      with ops.colocate_with(global_step):
        return state_ops.assign_add(global_step, 1)

    # NOTE(review): as indented here, this statement follows an unconditional
    # return and is unreachable. It presumably belongs to the enclosing
    # function (one indentation level out) - confirm against the original file.
    return head.create_estimator_spec(
        features=features,
        mode=mode,
        labels=labels,
        train_op_fn=_train_op_fn,
        logits=logits)
开发者ID:m-colombo,项目名称:tensorflow,代码行数:30,代码来源:dnn_linear_combined.py


示例4: record_summaries_every_n_global_steps

def record_summaries_every_n_global_steps(n):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Generator that replaces the contents of the should-record-summaries
  collection with a single `global_step % n == 0` entry, yields once, and
  then puts the previous contents back. The restore runs in a `finally`
  clause so the collection is not left modified when an exception is thrown
  into the generator at the `yield`.

  NOTE(review): presumably wrapped by a context-manager decorator that is not
  visible in this snippet - confirm.

  Args:
    n: Divisor applied to the global step.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    collection_ref[:] = [training_util.get_global_step() % n == 0]
    yield
  finally:
    # Always restore the previous collection contents, even on error.
    collection_ref[:] = old
开发者ID:benoitsteiner,项目名称:tensorflow-opencl,代码行数:7,代码来源:summary_ops.py


示例5: before_run

 def before_run(self, run_context):
   """Requests the global step and the current loss for this run.

   Uses `self.loss_op` when it is set; otherwise looks up the operation named
   LOSS_NAME in the session graph and takes its first output.
   """
   if self.loss_op is not None:
     loss = self.loss_op
   else:
     loss_operation = run_context.session.graph.get_operation_by_name(
         LOSS_NAME)
     loss = loss_operation.outputs[0]
   fetches = {
       'global_step': training_util.get_global_step(),
       'current_loss': loss,
   }
   return session_run_hook.SessionRunArgs(fetches)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:7,代码来源:random_forest.py


示例6: __init__

 def __init__(self):
   """Prepares an op that adds 1 to the global step, if one exists.

   `self._global_step_incr_op` is set to None when no global step is found.
   """
   global_step = training_util.get_global_step()
   # Test for presence explicitly: the lookup yields None when nothing is
   # registered, and the truth value of a tensor-like object should not be
   # relied on.
   if global_step is not None:
     self._global_step_incr_op = state_ops.assign_add(
         global_step, 1, name="global_step_incr").op
   else:
     self._global_step_incr_op = None
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:7,代码来源:wals.py


示例7: begin

 def begin(self):
   """Resets the reporting state and captures the global step tensor.

   Raises:
     RuntimeError: If no global step is found.
   """
   self._last_reported_time = self._last_reported_step = None
   step_tensor = training_util.get_global_step()
   self._global_step_tensor = step_tensor
   if step_tensor is None:
     raise RuntimeError(
         "Global step should be created to use StepCounterHook.")
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:7,代码来源:basic_session_run_hooks.py


示例8: function

 def function(tag, scope):
   """Writes `tensor` as an image summary at the current global step.

   Args:
     tag: Summary tag forwarded to the write op.
     scope: Passed as the name of the write op.
   """
   # Bind the color on every path so that a caller-supplied `bad_color` is
   # used as-is instead of leaving the local name unassigned.
   if bad_color is None:
     bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
   else:
     bad_color_ = bad_color
   gen_summary_ops.write_image_summary(
       context.context().summary_writer_resource,
       training_util.get_global_step(), tag, tensor, bad_color_, max_images,
       name=scope)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:7,代码来源:summary_ops.py


示例9: get_updates

  def get_updates(self, loss, params):
    """Builds the list of update ops for `loss` and stores it on `self.updates`.

    Without a distribution strategy and with empty `params`, the only update
    is an increment of `self.iterations`. Otherwise gradients are computed and
    applied, using the global step (with a distribution strategy) or
    `self.iterations` (without one) as the step counter.
    """
    if not distribute_lib.has_distribution_strategy():
      if not params:
        self.updates = [state_ops.assign_add(self.iterations, 1)]
        return self.updates

      # Updates list starts out empty because the iterations variable is
      # incremented in optimizer.apply_gradients()
      self.updates = []
      grads = self.optimizer.compute_gradients(loss, params)
      opt_update = self.optimizer.apply_gradients(
          grads, global_step=self.iterations)
    else:
      self.updates = []
      if params:
        grads = self.optimizer.compute_gradients(loss, params)
      else:
        # After the model vars have been created, the second call to
        # get_updates is called with params as an empty list. This ensures
        # that we call compute_gradients with params=None.
        grads = self.optimizer.compute_gradients(loss)
      opt_update = self.optimizer.apply_gradients(
          grads, training_util.get_global_step())

    self.updates.append(opt_update)
    return self.updates
开发者ID:sonnyhu,项目名称:tensorflow,代码行数:27,代码来源:optimizers.py


示例10: __init__

    def __init__(self,
                 checkpoint_dir,
                 display_steps=100,
                 maximum_train_steps=None,
                 do_summary=True,
                 is_chief=True):
        """Sets up the display hook.

        Args:
            checkpoint_dir: A string, base directory for the checkpoint files.
            display_steps: A python integer, display every N steps.
            maximum_train_steps: A python integer, the maximum training steps.
            do_summary: Whether to save summaries when displaying.
            is_chief: Whether this is the chief process (stored, not used yet).
        """
        tf.logging.info("Create DisplayHook.")

        # Plain configuration values.
        self._checkpoint_dir = checkpoint_dir
        self._display_steps = display_steps
        self._maximum_train_steps = maximum_train_steps
        self._do_summary = do_summary
        self._is_chief = is_chief  # not used now

        # Values to display: the key/value collections paired up, plus the
        # global step under its own key.
        step = training_util.get_global_step()
        keys = ops.get_collection(Constants.DISPLAY_KEY_COLLECTION_NAME)
        values = ops.get_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME)
        display_args = {key: value for key, value in zip(keys, values)}
        display_args["global_step"] = step
        self._display_args = display_args

        # Timers and the summary writer start out unset.
        self._timer = self._logging_timer = self._summary_writer = None
开发者ID:KIngpon,项目名称:NJUNMT-tf,代码行数:34,代码来源:hooks.py


示例11: begin

 def begin(self):
   """Captures the global step tensor and notifies every listener.

   Raises:
     RuntimeError: If no global step is found.
   """
   step_tensor = training_util.get_global_step()
   self._global_step_tensor = step_tensor
   if step_tensor is None:
     raise RuntimeError(
         "Global step should be created to use CheckpointSaverHook.")
   for listener in self._listeners:
     listener.begin()
开发者ID:kadeng,项目名称:tensorflow,代码行数:7,代码来源:basic_session_run_hooks.py


示例12: begin

 def begin(self):
   """Resets the saving state and captures the global step tensor.

   Raises:
     RuntimeError: If no global step is found.
   """
   self._last_saved_step = None
   self._request_summary = True
   step_tensor = training_util.get_global_step()
   self._global_step_tensor = step_tensor
   if step_tensor is None:
     raise RuntimeError(
         "Global step should be created to use SummarySaverHook.")
开发者ID:KalraA,项目名称:tensorflow,代码行数:7,代码来源:basic_session_run_hooks.py


示例13: after_create_session

  def after_create_session(self, training_session, coord):  # pylint: disable=unused-argument
    """Clones the training session and configures worker heartbeats.

    Heartbeat support is recorded in `self._heartbeat_supported`; it is turned
    off when no workers are reported or when configuring them raises
    InvalidArgumentError.

    Raises:
      ValueError: If a saver is defined but no global step exists.
    """
    # N.B. The global step is looked up here because the graph is already
    # frozen at checkpoint time, which would make it unavailable then.
    if training_util.get_global_step() is None:
      if self.saver() is not None:
        raise ValueError(
            'Saver defined but no global step.  Run `get_or_create_global_step()`'
            ' in your model definition to allow checkpointing.')

    with self._graph.as_default():
      logging.info('Installing graceful shutdown hook.')
      self._session = _clone_session(training_session, self._graph)
      worker_devices = all_worker_devices(self._session)
      self._workers = WorkerHeartbeatManager.from_devices(
          self._session, worker_devices)
      self._heartbeat_supported = self._workers.num_workers() > 0

      if not self._heartbeat_supported:
        logging.warn(
            'No workers support hearbeats. Failure handling will be disabled.')
        return

      try:
        self._workers.configure(
            event_pb2.WorkerHeartbeatRequest(
                shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
      except errors.InvalidArgumentError:
        logging.warn(
            'TPU device does not support heartbeats. Failure '
            'handling will be disabled.')
        self._heartbeat_supported = False
开发者ID:aritratony,项目名称:tensorflow,代码行数:27,代码来源:session_support.py


示例14: _train_op_fn

 def _train_op_fn(unused_loss):
   """Builds the train op from the optimizer; the loss argument is ignored.

   When an update-weights hook is present, it is handed the model and the
   train op before the train op is returned.
   """
   step = training_util.get_global_step()
   sdca_model, train_op = optimizer.get_train_step(
       columns_to_variables,
       weight_column_name,
       loss_type,
       features,
       labels,
       step)
   if update_weights_hook is None:
     return train_op
   update_weights_hook.set_parameters(sdca_model, train_op)
   return train_op
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:8,代码来源:sdca_estimator.py


示例15: record_summaries_every_n_global_steps

def record_summaries_every_n_global_steps(n):
  """Sets the should_record_summaries Tensor to true if global_step % n == 0.

  Generator that replaces the contents of the should-record-summaries
  collection with a single `global_step % n == 0` tensor (built under the
  "cpu:0" device scope), yields once, and then puts the previous contents
  back. The restore runs in a `finally` clause so the collection is not left
  modified when an exception is thrown into the generator at the `yield`.

  NOTE(review): presumably wrapped by a context-manager decorator that is not
  visible in this snippet - confirm.

  Args:
    n: Divisor applied to the global step.
  """
  collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
  old = collection_ref[:]
  try:
    with ops.device("cpu:0"):
      collection_ref[:] = [
          math_ops.equal(training_util.get_global_step() % n, 0)]
    yield
  finally:
    # Always restore the previous collection contents, even on error.
    collection_ref[:] = old
开发者ID:dyoung418,项目名称:tensorflow,代码行数:8,代码来源:summary_ops.py


示例16: _train_op_fn

 def _train_op_fn(loss):
   """Applies gradients of `loss` to the variables collected under `parent_scope`.

   Gradients are clipped by global norm when `gradient_clip_norm` is truthy.
   The global step is passed on to `apply_gradients`.
   """
   step = training_util.get_global_step()
   scoped_vars = ops.get_collection(parent_scope)
   loss_grads = gradients.gradients(loss, scoped_vars)
   if gradient_clip_norm:
     loss_grads, _ = clip_ops.clip_by_global_norm(
         loss_grads, gradient_clip_norm)
   opt = _get_optimizer(optimizer)
   grads_and_vars = zip(loss_grads, scoped_vars)
   return opt.apply_gradients(grads_and_vars, global_step=step)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:linear.py


示例17: begin

  def begin(self):
    """Captures the global step tensor and, if requested, builds an override op.

    The override op assigns `self._override_global_step_value` to the global
    step; it is only created when that value is not None.

    Raises:
      RuntimeError: If no global step is found.
    """
    step_tensor = training_util.get_global_step()
    self._global_step_tensor = step_tensor
    if step_tensor is None:
      raise RuntimeError("Global step should be created.")

    override_value = self._override_global_step_value
    if override_value is None:
      return
    self._override_global_step_op = state_ops.assign(
        step_tensor, override_value)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:8,代码来源:trainer_hooks.py


示例18: model_fn

 def model_fn(features, labels):
   """Dummy model: predictions are features["x"] and the loss is the constant [2.].

   Returns:
     A (predictions, loss, train_op) tuple, where the train op adds 1 to the
     global step.
   """
   # A variable is created but never used afterwards.
   variables_lib.Variable([0.])
   del labels  # Unused.
   predictions = features["x"]
   loss = constant_op.constant([2.])
   global_step = training_util.get_global_step()
   increment_step = global_step.assign_add(1)
   return predictions, loss, increment_step
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:estimators_test.py


示例19: begin

 def begin(self):
   """Sets up the summary writer, the global step tensor and the summary tag.

   A summary writer is fetched from the cache only when none is set and an
   output directory is configured. The summary tag is the name of the global
   step op followed by "/sec".

   Raises:
     RuntimeError: If the global step read tensor is None.
   """
   writer_missing = self._summary_writer is None
   if writer_missing and self._output_dir:
     self._summary_writer = SummaryWriterCache.get(self._output_dir)

   step_read = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
   self._global_step_tensor = step_read
   if step_read is None:
     raise RuntimeError(
         "Global step should be created to use StepCounterHook.")

   step_op_name = training_util.get_global_step().op.name
   self._summary_tag = step_op_name + "/sec"
开发者ID:didukhle,项目名称:tensorflow,代码行数:8,代码来源:basic_session_run_hooks.py


示例20: _train_op_fn

  def _train_op_fn(loss):
    """Builds the model's train step and an increment of the global step.

    Returns:
      The operation that adds 1 to the global step. It is created under a
      control dependency on the train step and colocated with the global step.
    """
    global_step = training_util.get_global_step()
    # Check for presence explicitly rather than evaluating the truth value of
    # a tensor-like object.
    assert global_step is not None
    train_step = model.get_train_step(loss)

    with ops.control_dependencies(train_step):
      with ops.get_default_graph().colocate_with(global_step):
        return state_ops.assign_add(global_step, 1).op
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:composable_model_test.py



注:本文中的tensorflow.python.training.training_util.get_global_step函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap