This article collects typical usage examples of the Python function tensorflow.python.training.training_util.global_step. If you are wondering what the global_step function does, how to call it, or what it looks like in real code, the curated examples below should help.
The following presents 10 code examples of the global_step function, sorted by popularity by default.
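Before the examples, here is a minimal sketch of the function itself, assuming TensorFlow 1.x graph mode (variable names are illustrative): training_util.global_step(sess, global_step_tensor) evaluates the given global step tensor (or tensor name) in the session and returns its current value as a Python int.

import tensorflow as tf
from tensorflow.python.training import training_util

global_step = tf.train.get_or_create_global_step()
increment = tf.assign_add(global_step, 1)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(increment)
  # Evaluates the tensor in the session and returns its value as an int.
  print(training_util.global_step(sess, global_step))  # -> 1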
Example 1: run_loop
def run_loop(self):
  self._sv.saver.save(self._sess, self._sv.save_path,
                      global_step=self._sv.global_step)
  if self._sv.summary_writer and self._sv.global_step is not None:
    current_step = training_util.global_step(self._sess, self._sv.global_step)
    self._sv.summary_writer.add_session_log(
        SessionLog(status=SessionLog.CHECKPOINT,
                   checkpoint_path=self._sv.save_path),
        current_step)
Author: paolodedios, Project: tensorflow, Lines: 7, Source: supervisor.py
Example 2: end
def end(self, session):
  if self._summary_op is not None:
    global_step = training_util.global_step(session, self._global_step)
    summary_str = session.run(self._summary_op, self._feed_dict)
    if self._summary_writer:
      self._summary_writer.add_summary(summary_str, global_step)
  if self._summary_writer:
    self._summary_writer.flush()
Author: ahmedsaiduk, Project: tensorflow, Lines: 8, Source: evaluation.py
Example 3: save
def save(self, sess, save_path, global_step=None, latest_filename=None):
  """Saves variables.

  This method runs the ops added by the constructor for saving variables.
  It requires a session in which the graph was launched. The variables to
  save must also have been initialized.

  The method returns the path of the newly created checkpoint file. This
  path can be passed directly to a call to `restore()`.

  Args:
    sess: A Session to use to save the variables.
    save_path: String. Path to the checkpoint filename. If the saver is
      `sharded`, this is the prefix of the sharded checkpoint filename.
    global_step: If provided the global step number is appended to
      `save_path` to create the checkpoint filename. The optional argument
      can be a `Tensor`, a `Tensor` name or an integer.
    latest_filename: Optional name for the protocol buffer file that will
      contain the list of most recent checkpoint filenames. That file,
      kept in the same directory as the checkpoint files, is automatically
      managed by the saver to keep track of recent checkpoints. Defaults to
      'checkpoint'.

  Returns:
    A string: path at which the variables were saved. If the saver is
      sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
      is the number of shards created.

  Raises:
    TypeError: If `sess` is not a `Session`.
    ValueError: If `latest_filename` contains path components.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  if os.path.split(latest_filename)[0]:
    raise ValueError("'latest_filename' must not contain path components")
  if global_step is not None:
    if not isinstance(global_step, compat.integral_types):
      global_step = training_util.global_step(sess, global_step)
    checkpoint_file = "%s-%d" % (save_path, global_step)
  else:
    checkpoint_file = save_path
  save_path = os.path.dirname(save_path)
  if not isinstance(sess, session.SessionInterface):
    raise TypeError("'sess' must be a Session; %s" % sess)
  model_checkpoint_path = sess.run(
      self._save_tensor_name, {self._filename_tensor_name: checkpoint_file})
  model_checkpoint_path = compat.as_str(model_checkpoint_path)
  self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
  update_checkpoint_state(save_path, model_checkpoint_path,
                          self.last_checkpoints, latest_filename)
  return model_checkpoint_path
Author: hessenh, Project: Human-Activity-Recognition, Lines: 55, Source: saver.py
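For reference, a hedged sketch of the `global_step` argument in action through the public tf.train.Saver API (TensorFlow 1.x assumed; the variable and path are illustrative):

import tensorflow as tf

v = tf.Variable(42, name="v")
global_step = tf.train.get_or_create_global_step()
saver = tf.train.Saver()

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # The step value is resolved via training_util.global_step() and
  # appended to the path, yielding e.g. /tmp/model.ckpt-0.
  path = saver.save(sess, "/tmp/model.ckpt", global_step=global_step)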
Example 4: save
def save(self, session=None, checkpoint_number=None):
  """Creates a new checkpoint and manages it.

  Args:
    session: The session to evaluate variables in. Ignored when executing
      eagerly. If not provided when graph building, the default session is
      used.
    checkpoint_number: An optional integer, or an integer-dtype `Variable` or
      `Tensor`, used to number the checkpoint. If `None` (default),
      checkpoints are numbered using `checkpoint.save_counter`. Even if
      `checkpoint_number` is provided, `save_counter` is still incremented. A
      user-provided `checkpoint_number` is not incremented even if it is a
      `Variable`.

  Returns:
    The path to the new checkpoint. It is also recorded in the `checkpoints`
    and `latest_checkpoint` properties.
  """
  # Save counter logic duplicated from tf.train.Checkpoint, soon to diverge
  # slightly with a custom numbering option.
  if context.executing_eagerly():
    save_counter = self._checkpoint.save_counter
    save_counter.assign_add(1)
  else:
    if session is None:
      session = ops.get_default_session()

    def _initializing_creator(next_creator, **kwargs):
      """Initialize the save counter if it has been newly created."""
      v = next_creator(**kwargs)
      session.run(v.initializer)
      return v

    with variable_scope.variable_creator_scope(_initializing_creator):
      save_counter = self._checkpoint.save_counter
    if self._save_counter_assign is None:
      self._save_counter_assign = save_counter.assign_add(1, read_value=False)
    session.run(self._save_counter_assign)
  if checkpoint_number is None:
    checkpoint_number = save_counter
  if not isinstance(checkpoint_number, compat.integral_types):
    checkpoint_number = training_util.global_step(
        sess=session, global_step_tensor=checkpoint_number)
  prefix = "%s-%d" % (self._prefix, checkpoint_number)
  save_path = self._checkpoint.write(prefix)
  timestamp = time.time()
  # If this is an overwritten checkpoint we were previously tracking, delete
  # and reinsert it to make sure it goes to the end of the queue.
  if save_path in self._maybe_delete:
    del self._maybe_delete[save_path]
  self._maybe_delete[save_path] = timestamp
  self._latest_checkpoint = save_path
  self._sweep()
  self._record_state()
  return save_path
Author: AnishShah, Project: tensorflow, Lines: 55, Source: checkpoint_management.py
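The source file (checkpoint_management.py) suggests this is the `save` method of tf.train.CheckpointManager. Assuming so, a minimal sketch of calling it with eager execution enabled; the `step` variable is a stand-in trackable:

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)  # stand-in for real model state
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(checkpoint, directory="/tmp/ckpts",
                                     max_to_keep=3)

path = manager.save()                        # numbered from save_counter
path = manager.save(checkpoint_number=1000)  # explicit number, as documented above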
Example 5: start_standard_services
def start_standard_services(self, sess):
  """Start the standard services for 'sess'.

  This starts services in the background. The services started depend
  on the parameters to the constructor and may include:

    - A Summary thread computing summaries every save_summaries_secs.
    - A Checkpoint thread saving the model every save_model_secs.
    - A StepCounter thread measuring step time.

  Args:
    sess: A Session.

  Returns:
    A list of threads that are running the standard services. You can use
    the Supervisor's Coordinator to join these threads with:
      sv.coord.Join(<list of threads>)

  Raises:
    RuntimeError: If called with a non-chief Supervisor.
    ValueError: If no `logdir` was passed to the constructor, as the
      services need a log directory.
  """
  if not self._is_chief:
    raise RuntimeError("Only chief supervisor can start standard services. "
                       "Because only chief supervisors can write events.")

  if not self._logdir:
    logging.warning("Standard services need a 'logdir' "
                    "passed to the SessionManager")
    return

  if self._global_step is not None and self._summary_writer:
    # Only add the session log if we keep track of global step.
    # TensorBoard cannot use START message for purging expired events
    # if there is no step value.
    current_step = training_util.global_step(sess, self._global_step)
    self._summary_writer.add_session_log(
        SessionLog(status=SessionLog.START),
        current_step)

  threads = []
  if self._save_summaries_secs and self._summary_writer:
    if self._summary_op is not None:
      threads.append(SVSummaryThread(self, sess))
    if self._global_step is not None:
      threads.append(SVStepCounterThread(self, sess))
  if self.saver and self._save_model_secs:
    threads.append(SVTimerCheckpointThread(self, sess))
  for t in threads:
    t.start()
  self._started_threads.extend(threads)
  return threads
Author: 01bui, Project: tensorflow, Lines: 54, Source: supervisor.py
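These standard services are normally started indirectly. A minimal sketch of typical tf.train.Supervisor usage, assuming TensorFlow 1.x; managed_session() starts the standard services on the chief, and the assign_add op below is a stand-in for a real training op:

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)  # stand-in for a real training op

sv = tf.train.Supervisor(logdir="/tmp/train_logs")
with sv.managed_session() as sess:
  # managed_session() starts the standard services (summary thread,
  # checkpoint thread, step counter) on the chief supervisor.
  for _ in range(10):
    sess.run(train_op)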
Example 6: _wait_for_step
def _wait_for_step(sess, global_step, step):
  """Wait till the global step has reached at least 'step'.

  Args:
    sess: A session.
    global_step: A Tensor.
    step: Int. The global step to reach.
  """
  while True:
    if training_util.global_step(sess, global_step) >= step:
      break
    time.sleep(1.0)
Author: astorfi, Project: tensorflow, Lines: 12, Source: learning.py
Example 7: summary_computed
def summary_computed(self, sess, summary, global_step=None):
  """Indicate that a summary was computed.

  Args:
    sess: A `Session` object.
    summary: A Summary proto, or a string holding a serialized summary proto.
    global_step: Int. global step this summary is associated with. If `None`,
      it will try to fetch the current step.

  Raises:
    TypeError: if 'summary' is not a Summary proto or a string.
    RuntimeError: if the Supervisor was created without a `logdir`.
  """
  if not self._summary_writer:
    raise RuntimeError("Writing a summary requires a summary writer.")
  if global_step is None and self.global_step is not None:
    global_step = training_util.global_step(sess, self.global_step)
  self._summary_writer.add_summary(summary, global_step)
Author: 01bui, Project: tensorflow, Lines: 18, Source: supervisor.py
Example 8: evaluation
def evaluation(sess,
               num_evals=1,
               initial_op=None,
               initial_op_feed_dict=None,
               eval_op=None,
               eval_op_feed_dict=None,
               final_op=None,
               final_op_feed_dict=None,
               summary_op=None,
               summary_op_feed_dict=None,
               summary_writer=None,
               global_step=None):
  """Performs a single evaluation run.

  A single evaluation consists of several steps run in the following order:
  (1) an initialization op, (2) an evaluation op which is executed `num_evals`
  times, (3) a finalization op and (4) the execution of a summary op which is
  written out using a summary writer.

  Args:
    sess: The current TensorFlow `Session`.
    num_evals: The number of times to execute `eval_op`.
    initial_op: An operation run at the beginning of evaluation.
    initial_op_feed_dict: A feed dictionary to use when executing `initial_op`.
    eval_op: An operation run `num_evals` times.
    eval_op_feed_dict: The feed dictionary to use when executing the `eval_op`.
    final_op: An operation to execute after all of the `eval_op` executions. The
      value of `final_op` is returned.
    final_op_feed_dict: A feed dictionary to use when executing `final_op`.
    summary_op: A summary op executed after `eval_op` and `finalize_op`.
    summary_op_feed_dict: An optional feed dictionary to use when executing the
      `summary_op`.
    summary_writer: The summary writer used if `summary_op` is provided.
    global_step: the global step variable. If left as `None`, then
      slim.variables.global_step() is used.

  Returns:
    The value of `final_op` or `None` if `final_op` is `None`.

  Raises:
    ValueError: if `summary_op` is provided but `global_step` is `None`.
  """
  if initial_op is not None:
    logging.info('Executing initial eval op')
    sess.run(initial_op, initial_op_feed_dict)

  if eval_op is not None:
    logging.info('Executing eval ops')
    for i in range(int(num_evals)):
      logging.info('Executing eval_op %d/%d', i + 1, num_evals)
      sess.run(eval_op, eval_op_feed_dict)

  if final_op is not None:
    logging.info('Executing final op')
    final_op_value = sess.run(final_op, final_op_feed_dict)
  else:
    final_op_value = None

  if summary_op is not None:
    logging.info('Executing summary op')
    if global_step is None:
      global_step = variables.get_or_create_global_step()
    global_step = training_util.global_step(sess, global_step)
    summary = sess.run(summary_op, summary_op_feed_dict)
    summary_writer.add_summary(summary, global_step)
    summary_writer.flush()

  return final_op_value
Author: 821760408-sp, Project: tensorflow, Lines: 69, Source: evaluation.py
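A hedged sketch of calling the `evaluation` function above, assuming it is in scope; the streaming-accuracy ops are illustrative stand-ins for a real model's metrics:

import tensorflow as tf

labels = tf.constant([1, 0, 1])
predictions = tf.constant([1, 1, 1])
# tf.metrics.accuracy returns (value_tensor, update_op).
accuracy_value, accuracy_update = tf.metrics.accuracy(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # metric state lives in locals
  final_accuracy = evaluation(
      sess,
      num_evals=10,             # runs eval_op 10 times
      eval_op=accuracy_update,  # updates the streaming metric
      final_op=accuracy_value)  # evaluated once and returned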
Example 9: start_loop
def start_loop(self):
  self._last_time = time.time()
  self._last_step = training_util.global_step(
      self._sess, self._sv.global_step)
Author: Anandnitrate, Project: tensorflow, Lines: 4, Source: supervisor.py
Example 10: export
def export(self,
           export_dir_base,
           global_step_tensor,
           sess=None,
           exports_to_keep=None):
  """Exports the model.

  Args:
    export_dir_base: A string path to the base export dir.
    global_step_tensor: A `Tensor` or tensor name providing the
      global step counter to append to the export directory path and set
      in the manifest version.
    sess: A Session to use to save the parameters.
    exports_to_keep: a gc.Path filter function used to determine the set of
      exports to keep. If set to None, all versions will be kept.

  Returns:
    The string path to the exported directory.

  Raises:
    RuntimeError: if init is not called.
    RuntimeError: if the export would overwrite an existing directory.
  """
  if not self._has_init:
    raise RuntimeError("init must be called first")

  # Export dir must not end with / or it will break exports to keep. Strip /.
  if export_dir_base.endswith("/"):
    export_dir_base = export_dir_base[:-1]

  global_step = training_util.global_step(sess, global_step_tensor)
  export_dir = os.path.join(
      compat.as_bytes(export_dir_base),
      compat.as_bytes(constants.VERSION_FORMAT_SPECIFIER % global_step))

  # Prevent overwriting on existing exports which could lead to bad/corrupt
  # storage and loading of models. This is an important check that must be
  # done before any output files or directories are created.
  if gfile.Exists(export_dir):
    raise RuntimeError("Overwriting exports can cause corruption and are "
                       "not allowed. Duplicate export dir: %s" % export_dir)

  # Output to a temporary directory which is atomically renamed to the final
  # directory when complete.
  tmp_export_dir = compat.as_text(export_dir) + "-tmp"
  gfile.MakeDirs(tmp_export_dir)

  self._saver.save(sess,
                   os.path.join(
                       compat.as_text(tmp_export_dir),
                       compat.as_text(constants.EXPORT_BASE_NAME)),
                   meta_graph_suffix=constants.EXPORT_SUFFIX_NAME)

  # Run the asset callback.
  if self._assets_callback and self._assets_to_copy:
    assets_dir = os.path.join(
        compat.as_bytes(tmp_export_dir),
        compat.as_bytes(constants.ASSETS_DIRECTORY))
    gfile.MakeDirs(assets_dir)
    self._assets_callback(self._assets_to_copy, assets_dir)

  # TODO(b/27794910): Delete *checkpoint* file before rename.
  gfile.Rename(tmp_export_dir, export_dir)

  if exports_to_keep:
    # Create a simple parser that pulls the export_version from the directory.
    def parser(path):
      match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path)
      if not match:
        return None
      return path._replace(export_version=int(match.group(1)))

    paths_to_delete = gc.negation(exports_to_keep)
    for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
      gfile.DeleteRecursively(p.path)

  return export_dir
Author: 2020zyc, Project: tensorflow, Lines: 77, Source: exporter.py
Note: The tensorflow.python.training.training_util.global_step examples in this article were compiled by 纯净天空 from source code and documentation platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by various developers; copyright remains with the original authors, and use or redistribution is subject to each project's license. Do not reproduce without permission.