This article collects typical usage examples of the Python function tensorflow.python.training.training_util._get_or_create_global_step_read. If you have been wondering what _get_or_create_global_step_read does, how to call it, or what it looks like in real code, the hand-picked examples below may help.
Seven code examples of _get_or_create_global_step_read are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code examples.
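Before the examples, here is a minimal, self-contained sketch of the typical call pattern, pieced together from the snippets below. It assumes the TF 1.x graph-mode internals shown in these examples (these are private, underscore-prefixed helpers, not a public API):

from tensorflow.python.framework import ops
from tensorflow.python.training import monitored_session
from tensorflow.python.training import training_util

with ops.Graph().as_default():
  training_util.create_global_step()
  # Lazily creates (and caches) a tensor that reads the current global step;
  # it returns None if no global step exists in the graph.
  read_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
  inc_op = training_util._increment_global_step(1)  # pylint: disable=protected-access
  with monitored_session.MonitoredTrainingSession() as sess:
    print(sess.run(read_tensor))  # 0
    sess.run(inc_op)
    print(sess.run(read_tensor))  # 1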
Example 1: begin
def begin(self):
  if self._summary_writer is None and self._output_dir:
    self._summary_writer = SummaryWriterCache.get(self._output_dir)
  self._next_step = None
  self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
  if self._global_step_tensor is None:
    raise RuntimeError(
        "Global step should be created to use SummarySaverHook.")
Developer ID: didukhle, Project: tensorflow, Lines of code: 8, Source: basic_session_run_hooks.py
Example 2: begin
def begin(self):
  self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
  self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
  if self._global_step_tensor is None:
    raise RuntimeError(
        "Global step should be created to use CheckpointSaverHook.")
  for l in self._listeners:
    l.begin()
Developer ID: becster, Project: tensorflow, Lines of code: 8, Source: async_checkpoint.py
Example 3: test_reads_before_increments
def test_reads_before_increments(self):
  with ops.Graph().as_default():
    training_util.create_global_step()
    read_tensor = training_util._get_or_create_global_step_read()
    inc_op = training_util._increment_global_step(1)
    inc_three_op = training_util._increment_global_step(3)
    with monitored_session.MonitoredTrainingSession() as sess:
      read_value, _ = sess.run([read_tensor, inc_op])
      self.assertEqual(0, read_value)
      read_value, _ = sess.run([read_tensor, inc_three_op])
      self.assertEqual(1, read_value)
      read_value = sess.run(read_tensor)
      self.assertEqual(4, read_value)
Developer ID: aeverall, Project: tensorflow, Lines of code: 13, Source: training_util_test.py
Example 4: _train_model
def _train_model(self, input_fn, hooks, saving_listeners):
  worker_hooks = []
  with ops.Graph().as_default() as g, g.device(self._device_fn):
    random_seed.set_random_seed(self._config.tf_random_seed)
    global_step_tensor = self._create_and_assert_global_step(g)
    global_step_read_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    with ops.control_dependencies([global_step_read_tensor]):
      features, labels = self._get_features_and_labels_from_input_fn(
          input_fn, model_fn_lib.ModeKeys.TRAIN)
    estimator_spec = self._call_model_fn(
        features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
    # Check if the user created a loss summary, and add one if they didn't.
    # We assume here that the summary is called 'loss'. If it is not, we will
    # make another one with the name 'loss' to ensure it shows up in the right
    # graph in TensorBoard.
    if not any([x.op.name == 'loss'
                for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
      summary.scalar('loss', estimator_spec.loss)
    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
    worker_hooks.extend(hooks)
    worker_hooks.extend([
        training.NanTensorHook(estimator_spec.loss),
        training.LoggingTensorHook(
            {
                'loss': estimator_spec.loss,
                'step': global_step_tensor
            },
            every_n_iter=100)
    ])
    worker_hooks.extend(estimator_spec.training_hooks)
    if not (estimator_spec.scaffold.saver or
            ops.get_collection(ops.GraphKeys.SAVERS)):
      ops.add_to_collection(
          ops.GraphKeys.SAVERS,
          training.Saver(
              sharded=True,
              max_to_keep=self._config.keep_checkpoint_max,
              keep_checkpoint_every_n_hours=(
                  self._config.keep_checkpoint_every_n_hours),
              defer_build=True,
              save_relative_paths=True))
    chief_hooks = []
    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
    saver_hooks = [
        h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
    if (self._config.save_checkpoints_secs or
        self._config.save_checkpoints_steps):
      if not saver_hooks:
        chief_hooks = [
            training.CheckpointSaverHook(
                self._model_dir,
                save_secs=self._config.save_checkpoints_secs,
                save_steps=self._config.save_checkpoints_steps,
                scaffold=estimator_spec.scaffold)
        ]
        saver_hooks = [chief_hooks[0]]
    if saving_listeners:
      if not saver_hooks:
        raise ValueError(
            'There should be a CheckpointSaverHook to use saving_listeners. '
            'Please set one of the RunConfig.save_checkpoints_steps or '
            'RunConfig.save_checkpoints_secs.')
      else:
        # It is expected to have one CheckpointSaverHook. If multiple, we pick
        # up the first one to add listener.
        saver_hooks[0]._listeners.extend(saving_listeners)  # pylint: disable=protected-access
    with training.MonitoredTrainingSession(
        master=self._config.master,
        is_chief=self._config.is_chief,
        checkpoint_dir=self._model_dir,
        scaffold=estimator_spec.scaffold,
        hooks=worker_hooks,
        chief_only_hooks=(
            tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
        save_checkpoint_secs=0,  # Saving is handled by a hook.
        save_summaries_steps=self._config.save_summary_steps,
        config=self._session_config,
        log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
      loss = None
      while not mon_sess.should_stop():
        _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
    return loss
Developer ID: ilya-edrenkin, Project: tensorflow, Lines of code: 84, Source: estimator.py
Example 5: begin
def begin(self):
  self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
  if self._global_step_tensor is None:
    raise RuntimeError(
        'Global step should be created to use StopAtCheckpointStepHook.')
Developer ID: Jordan1237, Project: tensorflow, Lines of code: 5, Source: hooks.py
Example 6: test_global_step_read_is_none_if_there_is_no_global_step
def test_global_step_read_is_none_if_there_is_no_global_step(self):
  with ops.Graph().as_default():
    self.assertIsNone(training_util._get_or_create_global_step_read())
    training_util.create_global_step()
    self.assertIsNotNone(training_util._get_or_create_global_step_read())
Developer ID: aeverall, Project: tensorflow, Lines of code: 5, Source: training_util_test.py
Example 7: test_reads_from_cache
def test_reads_from_cache(self):
  with ops.Graph().as_default():
    training_util.create_global_step()
    first = training_util._get_or_create_global_step_read()
    second = training_util._get_or_create_global_step_read()
    self.assertEqual(first, second)
Developer ID: aeverall, Project: tensorflow, Lines of code: 6, Source: training_util_test.py
Note: The tensorflow.python.training.training_util._get_or_create_global_step_read examples in this article were compiled by 纯净天空 from source code and documentation hosted on GitHub, MSDocs, and similar platforms. The snippets were selected from open-source projects contributed by their respective developers; copyright remains with the original authors, and distribution and use should follow the license of the corresponding project. Do not reproduce without permission.