• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python training_util.global_step函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.training.training_util.global_step函数的典型用法代码示例。如果您正苦于以下问题：Python global_step函数的具体用法？Python global_step怎么用？Python global_step使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了global_step函数的10个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: run_loop

 def run_loop(self):
     """Save a checkpoint and, when possible, record it in the summary writer.

     The supervisor's saver is asked to save the session to the supervisor's
     `save_path`, tagged with the supervisor's `global_step`.  When the
     supervisor has both a summary writer and a global step, a `SessionLog`
     with status `CHECKPOINT` is added to the writer at the current value of
     the global step.
     """
     supervisor = self._sv
     supervisor.saver.save(
         self._sess, supervisor.save_path, global_step=supervisor.global_step
     )

     # Without a writer or a global step there is nothing to log.
     if not supervisor.summary_writer or supervisor.global_step is None:
         return

     step_value = training_util.global_step(self._sess, supervisor.global_step)
     checkpoint_log = SessionLog(
         status=SessionLog.CHECKPOINT, checkpoint_path=supervisor.save_path
     )
     supervisor.summary_writer.add_session_log(checkpoint_log, step_value)
开发者ID:paolodedios,项目名称:tensorflow,代码行数:7,代码来源:supervisor.py


示例2: end

 def end(self, session):
   """Write one final summary (if a summary op is set) and flush the writer.

   Args:
     session: The session used to read the global step and to run the
       summary op.
   """
   if self._summary_op is not None:
     step_value = training_util.global_step(session, self._global_step)
     serialized_summary = session.run(self._summary_op, self._feed_dict)
     if self._summary_writer:
       self._summary_writer.add_summary(serialized_summary, step_value)

   # Nothing to flush when no writer was configured.
   if not self._summary_writer:
     return
   self._summary_writer.flush()
开发者ID:ahmedsaiduk,项目名称:tensorflow,代码行数:8,代码来源:evaluation.py


示例3: save

  def save(self, sess, save_path, global_step=None, latest_filename=None):
    """Saves variables.

    This method runs the ops added by the constructor for saving variables.
    It requires a session in which the graph was launched.  The variables to
    save must also have been initialized.

    The method returns the path of the newly created checkpoint file.  This
    path can be passed directly to a call to `restore()`.

    Args:
      sess: A Session to use to save the variables.
      save_path: String.  Path to the checkpoint filename.  If the saver is
        `sharded`, this is the prefix of the sharded checkpoint filename.
      global_step: If provided the global step number is appended to
        `save_path` to create the checkpoint filename. The optional argument
        can be a `Tensor`, a `Tensor` name or an integer.
      latest_filename: Optional name for the protocol buffer file that will
        contain the list of most recent checkpoint filenames.  That file,
        kept in the same directory as the checkpoint files, is automatically
        managed by the saver to keep track of recent checkpoints.  Defaults to
        'checkpoint'.

    Returns:
      A string: path at which the variables were saved.  If the saver is
        sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
        is the number of shards created.

    Raises:
      TypeError: If `sess` is not a `Session`.
      ValueError: If `latest_filename` contains path components.
    """
    if latest_filename is None:
      latest_filename = "checkpoint"

    if os.path.split(latest_filename)[0]:
      raise ValueError("'latest_filename' must not contain path components")

    # Validate the session before it is used to evaluate `global_step`, so an
    # invalid `sess` always surfaces as the documented TypeError rather than
    # as whatever error the step lookup happens to raise.
    if not isinstance(sess, session.SessionInterface):
      # Wrap in a 1-tuple so a tuple-valued `sess` cannot break the formatting.
      raise TypeError("'sess' must be a Session; %s" % (sess,))

    if global_step is not None:
      if not isinstance(global_step, compat.integral_types):
        # Tensor or tensor name: resolve it to its current integer value.
        global_step = training_util.global_step(sess, global_step)
      checkpoint_file = "%s-%d" % (save_path, global_step)
    else:
      checkpoint_file = save_path
    # Directory holding the checkpoints; the checkpoint state file goes here.
    checkpoint_dir = os.path.dirname(save_path)

    model_checkpoint_path = sess.run(
        self._save_tensor_name, {self._filename_tensor_name: checkpoint_file})
    model_checkpoint_path = compat.as_str(model_checkpoint_path)
    self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
    update_checkpoint_state(checkpoint_dir, model_checkpoint_path,
                            self.last_checkpoints, latest_filename)
    return model_checkpoint_path
开发者ID:hessenh,项目名称:Human-Activity-Recognition,代码行数:55,代码来源:saver.py


示例4: save

  def save(self, session=None, checkpoint_number=None):
    """Writes a new checkpoint and updates this manager's bookkeeping.

    Args:
      session: The session used to run the save-counter ops when graph
        building. If omitted in that mode, the default session is looked up.
        It is also forwarded when resolving a non-integer checkpoint number.
      checkpoint_number: An optional integer, or an integer-dtype `Variable` or
        `Tensor`, used to number the checkpoint. If `None` (default), the
        checkpoint's `save_counter` is used. The `save_counter` is incremented
        on every call regardless; a user-provided `checkpoint_number` is never
        incremented here.

    Returns:
      The path of the new checkpoint, as returned by the underlying
      checkpoint's `write()`. It is also stored as the latest checkpoint.
    """
    # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
    # slightly with a custom numbering option.
    if not context.executing_eagerly():
      if session is None:
        session = ops.get_default_session()

      def _run_initializer_on_creation(next_creator, **kwargs):
        """Creates the variable and immediately runs its initializer."""
        created = next_creator(**kwargs)
        session.run(created.initializer)
        return created

      # Accessing `save_counter` inside this scope routes any variable it
      # creates through the creator above.
      with variable_scope.variable_creator_scope(_run_initializer_on_creation):
        save_counter = self._checkpoint.save_counter
      if self._save_counter_assign is None:
        self._save_counter_assign = save_counter.assign_add(1, read_value=False)
      session.run(self._save_counter_assign)
    else:
      save_counter = self._checkpoint.save_counter
      save_counter.assign_add(1)

    number = save_counter if checkpoint_number is None else checkpoint_number
    if not isinstance(number, compat.integral_types):
      number = training_util.global_step(
          sess=session, global_step_tensor=number)

    save_path = self._checkpoint.write("%s-%d" % (self._prefix, number))
    written_at = time.time()
    # If this is an overwritten checkpoint we were previously tracking, drop
    # and reinsert it to make sure it goes to the end of the queue.
    self._maybe_delete.pop(save_path, None)
    self._maybe_delete[save_path] = written_at
    self._latest_checkpoint = save_path
    self._sweep()
    self._record_state()
    return save_path
开发者ID:AnishShah,项目名称:tensorflow,代码行数:55,代码来源:checkpoint_management.py


示例5: start_standard_services

  def start_standard_services(self, sess):
    """Start the background service threads for 'sess'.

    Depending on how this object was configured, the started services may
    include:

      - A summary thread (when a summary op, a summary writer and
        `save_summaries_secs` are all set).
      - A step counter thread (when a global step, a summary writer and
        `save_summaries_secs` are all set).
      - A timer checkpoint thread (when a saver and `save_model_secs` are set).

    If a global step and a summary writer are available, a `SessionLog` with
    status `START` is written at the current global step before any thread is
    started.

    Args:
      sess: A Session.

    Returns:
      The list of threads that were started, or `None` (after logging a
      warning) when no log directory is configured.

    Raises:
      RuntimeError: If called with a non-chief Supervisor.
    """
    if not self._is_chief:
      raise RuntimeError("Only chief supervisor can start standard services. "
                         "Because only chief supervisors can write events.")

    if not self._logdir:
      logging.warning("Standard services need a 'logdir' "
                      "passed to the SessionManager")
      return

    tracks_step = self._global_step is not None
    if tracks_step and self._summary_writer:
      # Only add the session log if we keep track of global step.
      # TensorBoard cannot use START message for purging expired events
      # if there is no step value.
      step_now = training_util.global_step(sess, self._global_step)
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.START), step_now)

    services = []
    if self._save_summaries_secs and self._summary_writer:
      if self._summary_op is not None:
        services.append(SVSummaryThread(self, sess))
      if tracks_step:
        services.append(SVStepCounterThread(self, sess))
    if self.saver and self._save_model_secs:
      services.append(SVTimerCheckpointThread(self, sess))

    for service in services:
      service.start()
    self._started_threads.extend(services)

    return services
开发者ID:01bui,项目名称:tensorflow,代码行数:54,代码来源:supervisor.py


示例6: _wait_for_step

def _wait_for_step(sess, global_step, step):
    """Block until the global step is at least `step`.

    The current value is re-read through `training_util.global_step` once per
    second until the target is reached.

    Args:
      sess: A session.
      global_step: A Tensor.
      step: Int.  The global step to reach.
    """
    while training_util.global_step(sess, global_step) < step:
        time.sleep(1.0)
开发者ID:astorfi,项目名称:tensorflow,代码行数:12,代码来源:learning.py


示例7: summary_computed

  def summary_computed(self, sess, summary, global_step=None):
    """Record a computed summary with the summary writer.

    Args:
      sess: A `Session` object, used only to look up the current step when
        `global_step` is not given.
      summary: A Summary proto, or a string holding a serialized summary proto.
      global_step: Int. Global step this summary is associated with. If `None`
        and this object has a global step, the current value is fetched.

    Raises:
      RuntimeError: if there is no summary writer.
    """
    writer = self._summary_writer
    if not writer:
      raise RuntimeError("Writing a summary requires a summary writer.")

    if global_step is None:
      step_tensor = self.global_step
      if step_tensor is not None:
        global_step = training_util.global_step(sess, step_tensor)

    # Any error raised by the writer for an unsupported `summary` propagates.
    writer.add_summary(summary, global_step)
开发者ID:01bui,项目名称:tensorflow,代码行数:18,代码来源:supervisor.py


示例8: evaluation

def evaluation(sess,
               num_evals=1,
               initial_op=None,
               initial_op_feed_dict=None,
               eval_op=None,
               eval_op_feed_dict=None,
               final_op=None,
               final_op_feed_dict=None,
               summary_op=None,
               summary_op_feed_dict=None,
               summary_writer=None,
               global_step=None):
  """Performs a single evaluation run.

  A single evaluation consists of several steps run in the following order:
  (1) an initialization op, (2) an evaluation op which is executed `num_evals`
  times (3) a finalization op and (4) the execution of a summary op which is
  written out using a summary writer.

  Args:
    sess: The current TensorFlow `Session`.
    num_evals: The number of times to execute `eval_op`.
    initial_op: An operation run at the beginning of evaluation.
    initial_op_feed_dict: A feed dictionary to use when executing `initial_op`.
    eval_op: An operation run `num_evals` times.
    eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`.
    final_op: An operation to execute after all of the `eval_op` executions. The
      value of `final_op` is returned.
    final_op_feed_dict: A feed dictionary to use when executing `final_op`.
    summary_op: A summary op executed after `eval_op` and `final_op`.
    summary_op_feed_dict: An optional feed dictionary to use when executing the
      `summary_op`.
    summary_writer: The summary writer used if `summary_op` is provided.
      Required in that case.
    global_step: the global step variable. If left as `None`, then
      `variables.get_or_create_global_step()` is used.

  Returns:
    The value of `final_op` or `None` if `final_op` is `None`.

  Raises:
    ValueError: if `summary_op` is provided but `summary_writer` is `None`.
  """
  # Fail fast: without a writer the summary step would crash only after every
  # op had already run, discarding the computed final value.
  if summary_op is not None and summary_writer is None:
    raise ValueError(
        '`summary_writer` must be provided when `summary_op` is provided.')

  if initial_op is not None:
    logging.info('Executing initial eval op')
    sess.run(initial_op, initial_op_feed_dict)

  if eval_op is not None:
    logging.info('Executing eval ops')
    for i in range(int(num_evals)):
      logging.info('Executing eval_op %d/%d', i + 1, num_evals)
      sess.run(eval_op, eval_op_feed_dict)

  if final_op is not None:
    logging.info('Executing final op')
    final_op_value = sess.run(final_op, final_op_feed_dict)
  else:
    final_op_value = None

  if summary_op is not None:
    logging.info('Executing summary op')
    if global_step is None:
      global_step = variables.get_or_create_global_step()

    # Resolve the step tensor to its current value before writing the summary.
    global_step = training_util.global_step(sess, global_step)
    summary = sess.run(summary_op, summary_op_feed_dict)
    summary_writer.add_summary(summary, global_step)
    summary_writer.flush()

  return final_op_value
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:69,代码来源:evaluation.py


示例9: start_loop

 def start_loop(self):
   """Record the wall-clock time and the current global step value."""
   self._last_time = time.time()
   session, step_tensor = self._sess, self._sv.global_step
   self._last_step = training_util.global_step(session, step_tensor)
开发者ID:Anandnitrate,项目名称:tensorflow,代码行数:4,代码来源:supervisor.py


示例10: export

  def export(self,
             export_dir_base,
             global_step_tensor,
             sess=None,
             exports_to_keep=None):
    """Exports the model.

    Args:
      export_dir_base: A string path to the base export dir.
      global_step_tensor: A Tensor or tensor name providing the
        global step counter to append to the export directory path and set
        in the manifest version.
      sess: A Session to use to save the parameters.
      exports_to_keep: a gc.Path filter function used to determine the set of
        exports to keep. If set to None, all versions will be kept.

    Returns:
      The path to the exported directory, as produced by joining the
      byte-encoded base directory and version.

    Raises:
      RuntimeError: if init is not called.
      RuntimeError: if the export would overwrite an existing directory.
    """
    if not self._has_init:
      raise RuntimeError("init must be called first")

    # Export dir must not end with / or it will break exports to keep. Strip /.
    if export_dir_base.endswith("/"):
      export_dir_base = export_dir_base[:-1]

    global_step = training_util.global_step(sess, global_step_tensor)
    export_dir = os.path.join(
        compat.as_bytes(export_dir_base),
        compat.as_bytes(constants.VERSION_FORMAT_SPECIFIER % global_step))

    # Prevent overwriting on existing exports which could lead to bad/corrupt
    # storage and loading of models. This is an important check that must be
    # done before any output files or directories are created.
    if gfile.Exists(export_dir):
      raise RuntimeError("Overwriting exports can cause corruption and are "
                         "not allowed. Duplicate export dir: %s" % export_dir)

    # Output to a temporary directory which is atomically renamed to the final
    # directory when complete.
    tmp_export_dir = compat.as_text(export_dir) + "-tmp"
    gfile.MakeDirs(tmp_export_dir)

    self._saver.save(sess,
                     os.path.join(
                         compat.as_text(tmp_export_dir),
                         compat.as_text(constants.EXPORT_BASE_NAME)),
                     meta_graph_suffix=constants.EXPORT_SUFFIX_NAME)

    # Run the asset callback.
    if self._assets_callback and self._assets_to_copy:
      assets_dir = os.path.join(
          compat.as_bytes(tmp_export_dir),
          compat.as_bytes(constants.ASSETS_DIRECTORY))
      gfile.MakeDirs(assets_dir)
      self._assets_callback(self._assets_to_copy, assets_dir)

    # TODO(b/27794910): Delete *checkpoint* file before rename.
    gfile.Rename(tmp_export_dir, export_dir)

    if exports_to_keep:
      # The base directory is user-supplied text, not a pattern: escape it so
      # characters such as '.' or '+' in the path are matched literally.
      # Compiled once here instead of on every parsed path.
      version_pattern = re.compile(
          "^" + re.escape(export_dir_base) + "/(\\d{8})$")

      # create a simple parser that pulls the export_version from the directory.
      def parser(path):
        match = version_pattern.match(path.path)
        if not match:
          return None
        return path._replace(export_version=int(match.group(1)))

      paths_to_delete = gc.negation(exports_to_keep)
      for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
        gfile.DeleteRecursively(p.path)

    return export_dir
开发者ID:2020zyc,项目名称:tensorflow,代码行数:77,代码来源:exporter.py



注:本文中的tensorflow.python.training.training_util.global_step函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python training_util.write_graph函数代码示例发布时间:2022-05-27
下一篇:
Python training_util.get_or_create_global_step函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap