本文整理汇总了Python中tensorflow.python.training.training_util.get_global_step函数的典型用法代码示例。如果您正苦于以下问题:Python get_global_step函数的具体用法?Python get_global_step怎么用?Python get_global_step使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_global_step函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test_get_global_step
def test_get_global_step(self):
    """Checks get_global_step() before and after a global step exists.

    With no global-step variable in the graph the lookup must return None.
    After an int32, non-trainable variable named ops.GraphKeys.GLOBAL_STEP is
    created, the lookup result (via the default graph, then via an explicit
    graph argument) is checked with self._assert_global_step.
    """
    with ops.Graph().as_default() as graph:
        # Nothing has been created yet.
        self.assertIsNone(training_util.get_global_step())

        variables.VariableV1(
            0,
            trainable=False,
            dtype=dtypes.int32,
            name=ops.GraphKeys.GLOBAL_STEP)

        # First look up implicitly (default graph), then pass the graph.
        for lookup_args in ((), (graph,)):
            self._assert_global_step(
                training_util.get_global_step(*lookup_args),
                expected_dtype=dtypes.int32)
开发者ID:aeverall,项目名称:tensorflow,代码行数:12,代码来源:training_util_test.py
示例2: _ModelFn
def _ModelFn(features, labels, mode):
    """Estimator model_fn used by the MNIST quantization test.

    When the enclosing `is_training` flag is set, logits come from
    self._BuildGraph(features). Otherwise the GraphDef returned by
    self._GetGraphDef(use_trt, batch_size, model_dir) is imported with
    `features` mapped to INPUT_NODE_NAME, and the OUTPUT_NODE_NAME tensor is
    used as the logits.

    Args:
      features: Input tensor fed to the model.
      labels: Class labels used for the loss and the accuracy metric.
      mode: A ModeKeys value; only EVAL and TRAIN are supported.

    Returns:
      An EstimatorSpec for EVAL (loss + accuracy metric) or TRAIN
      (loss + Adam train op that also advances the global step).

    Raises:
      ValueError: if `mode` is neither EVAL nor TRAIN. Previously the
        function fell through and returned None, which an Estimator cannot
        use as a model_fn result.
    """
    if is_training:
        logits_out = self._BuildGraph(features)
    else:
        graph_def = self._GetGraphDef(use_trt, batch_size, model_dir)
        logits_out = importer.import_graph_def(
            graph_def,
            input_map={INPUT_NODE_NAME: features},
            return_elements=[OUTPUT_NODE_NAME + ':0'],
            name='')[0]

    loss = losses.sparse_softmax_cross_entropy(
        labels=labels, logits=logits_out)
    summary.scalar('loss', loss)

    classes_out = math_ops.argmax(logits_out, axis=1, name='classes_out')
    accuracy = metrics.accuracy(
        labels=labels, predictions=classes_out, name='acc_op')
    # metrics.accuracy returns a (value, update_op) pair; the summary is
    # attached to the update op.
    summary.scalar('accuracy', accuracy[1])

    if mode == ModeKeys.EVAL:
        return EstimatorSpec(
            mode, loss=loss, eval_metric_ops={'accuracy': accuracy})
    if mode == ModeKeys.TRAIN:
        optimizer = AdamOptimizer(learning_rate=1e-2)
        train_op = optimizer.minimize(loss, global_step=get_global_step())
        return EstimatorSpec(mode, loss=loss, train_op=train_op)
    raise ValueError('Unsupported mode: %s' % (mode,))
开发者ID:kylin9872,项目名称:tensorflow,代码行数:27,代码来源:quantization_mnist_test.py
示例3: _train_op_fn
def _train_op_fn(loss):
    """Returns the op to optimize the loss.

    For each tower whose logits exist (DNN and/or linear), minimizes `loss`
    with that tower's optimizer over the trainable variables under the
    tower's parent scope. It then increments the global step by one, after
    all tower updates have run.
    """
    def _minimize_within(tower_optimizer, tower_scope):
        # Restrict the update to trainable variables created under the scope.
        scoped_vars = ops.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES, scope=tower_scope)
        return tower_optimizer.minimize(loss, var_list=scoped_vars)

    global_step = training_util.get_global_step()
    tower_updates = []
    if dnn_logits is not None:
        tower_updates.append(_minimize_within(dnn_optimizer, dnn_parent_scope))
    if linear_logits is not None:
        tower_updates.append(
            _minimize_within(linear_optimizer, linear_parent_scope))

    # The grouped op is built before entering the control-dependency scope;
    # the increment depends on it and is colocated with the global step.
    with ops.control_dependencies([control_flow_ops.group(*tower_updates)]):
        with ops.colocate_with(global_step):
            return state_ops.assign_add(global_step, 1)
return head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
开发者ID:m-colombo,项目名称:tensorflow,代码行数:30,代码来源:dnn_linear_combined.py
示例4: record_summaries_every_n_global_steps
def record_summaries_every_n_global_steps(n):
    """Sets the should_record_summaries Tensor to true if global_step % n == 0.

    Generator body of a context manager (presumably wrapped by a
    contextmanager decorator that is not part of this snippet). While the
    context is active, the collection named _SHOULD_RECORD_SUMMARIES_NAME
    holds a single boolean tensor. The collection's previous contents are
    restored on exit, including when the managed block raises.

    Args:
      n: Summaries are recorded on global steps that are multiples of `n`.
    """
    # Local import: an element-wise equality op is needed (see below).
    from tensorflow.python.ops import math_ops

    collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
    old = collection_ref[:]
    # `tensor == 0` does not build an equality op in TF1. Tensor.__eq__ there
    # compares object identity and returns the Python bool False, so
    # summaries would never be recorded. Use an explicit element-wise op.
    collection_ref[:] = [
        math_ops.equal(training_util.get_global_step() % n, 0)]
    try:
        yield
    finally:
        # Restore even if the managed block raises, so the override does not
        # leak into later graph construction.
        collection_ref[:] = old
开发者ID:benoitsteiner,项目名称:tensorflow-opencl,代码行数:7,代码来源:summary_ops.py
示例5: before_run
def before_run(self, run_context):
    """Requests the global step and the current loss for this run.

    Uses self.loss_op when it is set. Otherwise it fetches the first output
    of the operation named LOSS_NAME in the session's graph.
    """
    if self.loss_op is None:
        graph = run_context.session.graph
        loss = graph.get_operation_by_name(LOSS_NAME).outputs[0]
    else:
        loss = self.loss_op
    fetches = {
        'global_step': training_util.get_global_step(),
        'current_loss': loss,
    }
    return session_run_hook.SessionRunArgs(fetches)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:7,代码来源:random_forest.py
示例6: __init__
def __init__(self):
    """Prepares an op that increments the global step by one, if one exists.

    Sets self._global_step_incr_op to the increment op (named
    "global_step_incr"). It is set to None when get_global_step() finds no
    global step in the default graph.
    """
    global_step = training_util.get_global_step()
    # Test for existence explicitly. Truth-testing a TF variable/tensor is
    # not an existence check and can raise TypeError in graph mode.
    if global_step is not None:
        self._global_step_incr_op = state_ops.assign_add(
            global_step, 1, name="global_step_incr").op
    else:
        self._global_step_incr_op = None
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:7,代码来源:wals.py
示例7: begin
def begin(self):
    """Clears the last-reported time/step and caches the global step tensor.

    Raises:
      RuntimeError: if no global step has been created.
    """
    self._last_reported_time = self._last_reported_step = None
    global_step_tensor = training_util.get_global_step()
    self._global_step_tensor = global_step_tensor
    if global_step_tensor is None:
        raise RuntimeError(
            "Global step should be created to use StepCounterHook.")
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:7,代码来源:basic_session_run_hooks.py
示例8: function
def function(tag, scope):
    """Writes an image summary of `tensor` under `tag` at the global step.

    Args:
      tag: Summary tag.
      scope: Name passed to the summary op.
    """
    # Default to [255, 0, 0, 255] (uint8) when no bad_color was given;
    # otherwise use the caller's value. Previously `bad_color_` was only
    # assigned in the None case, so passing an explicit bad_color raised
    # UnboundLocalError.
    if bad_color is None:
        bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8)
    else:
        bad_color_ = bad_color
    gen_summary_ops.write_image_summary(
        context.context().summary_writer_resource,
        training_util.get_global_step(), tag, tensor, bad_color_, max_images,
        name=scope)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:7,代码来源:summary_ops.py
示例9: get_updates
def get_updates(self, loss, params):
    """Builds the update ops that apply the wrapped optimizer to `loss`.

    If a distribution strategy is active, the gradients are applied with the
    graph's global step. Otherwise they are applied with self.iterations as
    the step counter; in that case an empty `params` only increments
    self.iterations.

    Returns:
      self.updates, the list of update ops (also stored on the instance).
    """
    if not distribute_lib.has_distribution_strategy():
        if not params:
            self.updates = [state_ops.assign_add(self.iterations, 1)]
            return self.updates
        # Updates list starts out empty because the iterations variable is
        # incremented in optimizer.apply_gradients().
        self.updates = []
        grads = self.optimizer.compute_gradients(loss, params)
        opt_update = self.optimizer.apply_gradients(
            grads, global_step=self.iterations)
    else:
        self.updates = []
        # After the model vars have been created, the second call to
        # get_updates passes an empty `params`; compute_gradients is then
        # called without params (i.e. with its default of None).
        if params:
            grads = self.optimizer.compute_gradients(loss, params)
        else:
            grads = self.optimizer.compute_gradients(loss)
        global_step = training_util.get_global_step()
        opt_update = self.optimizer.apply_gradients(grads, global_step)
    self.updates.append(opt_update)
    return self.updates
开发者ID:sonnyhu,项目名称:tensorflow,代码行数:27,代码来源:optimizers.py
示例10: __init__
def __init__(self,
             checkpoint_dir,
             display_steps=100,
             maximum_train_steps=None,
             do_summary=True,
             is_chief=True):
    """Initializes the hook.

    Args:
      checkpoint_dir: A string, base directory for the checkpoint files.
      display_steps: A python integer, display every N steps.
      maximum_train_steps: A python integer, the maximum training steps.
      do_summary: Whether to save summaries when displaying.
      is_chief: Whether this is the chief process (currently unused).
    """
    tf.logging.info("Create DisplayHook.")
    self._checkpoint_dir = checkpoint_dir
    self._display_steps = display_steps
    self._maximum_train_steps = maximum_train_steps
    self._do_summary = do_summary
    self._is_chief = is_chief  # not used now

    # Values to display: key/value pairs registered in the display
    # collections, with "global_step" added (overriding any same-named key).
    global_step = training_util.get_global_step()
    keys = ops.get_collection(Constants.DISPLAY_KEY_COLLECTION_NAME)
    values = ops.get_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME)
    self._display_args = dict(zip(keys, values), global_step=global_step)

    # Timers and summary writer are created later.
    self._timer = self._logging_timer = self._summary_writer = None
开发者ID:KIngpon,项目名称:NJUNMT-tf,代码行数:34,代码来源:hooks.py
示例11: begin
def begin(self):
    """Caches the global step tensor and calls begin() on every listener.

    Raises:
      RuntimeError: if no global step has been created.
    """
    global_step_tensor = training_util.get_global_step()
    self._global_step_tensor = global_step_tensor
    if global_step_tensor is None:
        raise RuntimeError(
            "Global step should be created to use CheckpointSaverHook.")
    for listener in self._listeners:
        listener.begin()
开发者ID:kadeng,项目名称:tensorflow,代码行数:7,代码来源:basic_session_run_hooks.py
示例12: begin
def begin(self):
    """Resets save bookkeeping and caches the global step tensor.

    Raises:
      RuntimeError: if no global step has been created.
    """
    self._last_saved_step = None
    self._request_summary = True
    global_step_tensor = training_util.get_global_step()
    self._global_step_tensor = global_step_tensor
    if global_step_tensor is None:
        raise RuntimeError(
            "Global step should be created to use SummarySaverHook.")
开发者ID:KalraA,项目名称:tensorflow,代码行数:7,代码来源:basic_session_run_hooks.py
示例13: after_create_session
def after_create_session(self, training_session, coord):  # pylint: disable=unused-argument
    """Installs the graceful-shutdown hook after the training session exists.

    Creates self._session via _clone_session(training_session, self._graph)
    and builds a WorkerHeartbeatManager over all worker devices. If any
    workers exist, they are configured with shutdown_mode
    WAIT_FOR_COORDINATOR. Heartbeat support is disabled (with a warning) when
    there are no workers or when configuration raises InvalidArgumentError.

    Args:
      training_session: The session created for training.
      coord: Unused.

    Raises:
      ValueError: if a saver is configured but the graph has no global step.
    """
    # N.B. We have to pull the global step here to avoid it being unavailable
    # at checkpoint time; the graph has been frozen at that point.
    if training_util.get_global_step() is None and self.saver() is not None:
        raise ValueError(
            'Saver defined but no global step. Run `get_or_create_global_step()`'
            ' in your model definition to allow checkpointing.')
    with self._graph.as_default():
        logging.info('Installing graceful shutdown hook.')
        self._session = _clone_session(training_session, self._graph)
        self._workers = WorkerHeartbeatManager.from_devices(
            self._session, all_worker_devices(self._session))
        self._heartbeat_supported = self._workers.num_workers() > 0
        if self._heartbeat_supported:
            try:
                self._workers.configure(
                    event_pb2.WorkerHeartbeatRequest(
                        shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR))
            except errors.InvalidArgumentError:
                # `warn` is a deprecated alias; `warning` is the supported name.
                logging.warning(
                    'TPU device does not support heartbeats. Failure '
                    'handling will be disabled.')
                self._heartbeat_supported = False
        else:
            logging.warning(
                'No workers support hearbeats. Failure handling will be disabled.')
开发者ID:aritratony,项目名称:tensorflow,代码行数:27,代码来源:session_support.py
示例14: _train_op_fn
def _train_op_fn(unused_loss):
    """Builds the SDCA train op; the loss argument is ignored.

    Calls optimizer.get_train_step with the column/feature configuration and
    the global step. If update_weights_hook is set, it receives the returned
    SDCA model and train op.

    Returns:
      The train op from optimizer.get_train_step.
    """
    step = training_util.get_global_step()
    sdca_model, train_op = optimizer.get_train_step(
        columns_to_variables, weight_column_name, loss_type, features, labels,
        step)
    if update_weights_hook is None:
        return train_op
    update_weights_hook.set_parameters(sdca_model, train_op)
    return train_op
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:8,代码来源:sdca_estimator.py
示例15: record_summaries_every_n_global_steps
def record_summaries_every_n_global_steps(n):
    """Sets the should_record_summaries Tensor to true if global_step % n == 0.

    Generator body of a context manager (presumably wrapped by a
    contextmanager decorator that is not part of this snippet). While the
    context is active, the collection named _SHOULD_RECORD_SUMMARIES_NAME
    holds a single boolean tensor created under device "cpu:0". The
    collection's previous contents are restored on exit, including when the
    managed block raises.

    Args:
      n: Summaries are recorded on global steps that are multiples of `n`.
    """
    collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
    old = collection_ref[:]
    with ops.device("cpu:0"):
        collection_ref[:] = [
            math_ops.equal(training_util.get_global_step() % n, 0)]
    try:
        yield
    finally:
        # Restore even if the managed block raises; otherwise the override
        # would leak past the `with` statement.
        collection_ref[:] = old
开发者ID:dyoung418,项目名称:tensorflow,代码行数:8,代码来源:summary_ops.py
示例16: _train_op_fn
def _train_op_fn(loss):
    """Returns an op that applies (optionally clipped) gradients of `loss`.

    Differentiates `loss` w.r.t. the variables in the `parent_scope`
    collection. The gradients are clipped by global norm when
    gradient_clip_norm is truthy, then applied with
    _get_optimizer(optimizer), passing the global step.
    """
    global_step = training_util.get_global_step()
    model_vars = ops.get_collection(parent_scope)
    grads = gradients.gradients(loss, model_vars)
    if gradient_clip_norm:
        # clip_by_global_norm returns (clipped_list, global_norm).
        grads = clip_ops.clip_by_global_norm(grads, gradient_clip_norm)[0]
    opt = _get_optimizer(optimizer)
    grads_and_vars = zip(grads, model_vars)
    return opt.apply_gradients(grads_and_vars, global_step=global_step)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:linear.py
示例17: begin
def begin(self):
    """Caches the global step tensor and builds the optional override op.

    When self._override_global_step_value is not None, an op assigning
    that value to the global step is stored in
    self._override_global_step_op.

    Raises:
      RuntimeError: if no global step has been created.
    """
    step_tensor = training_util.get_global_step()
    self._global_step_tensor = step_tensor
    if step_tensor is None:
        raise RuntimeError("Global step should be created.")
    if self._override_global_step_value is None:
        return
    self._override_global_step_op = state_ops.assign(
        step_tensor, self._override_global_step_value)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:8,代码来源:trainer_hooks.py
示例18: model_fn
def model_fn(features, labels):
    """Minimal model function for estimator tests.

    Creates one dummy variable and ignores `labels`.

    Returns:
      A (predictions, loss, train_op) tuple: features["x"], the constant
      [2.], and an op incrementing the global step by one.
    """
    del labels  # Unused.
    _ = variables_lib.Variable([0.])  # Dummy variable.
    predictions = features["x"]
    loss = constant_op.constant([2.])
    global_step = training_util.get_global_step()
    return predictions, loss, global_step.assign_add(1)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:estimators_test.py
示例19: begin
def begin(self):
    """Prepares the summary writer, the global step read and summary tag.

    A summary writer is fetched from SummaryWriterCache only when none is set
    and an output directory is configured. The summary tag is the global
    step op's name followed by "/sec".

    Raises:
      RuntimeError: if _get_or_create_global_step_read() returns None.
    """
    if self._summary_writer is None and self._output_dir:
        self._summary_writer = SummaryWriterCache.get(self._output_dir)
    step_read = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    self._global_step_tensor = step_read
    if step_read is None:
        raise RuntimeError(
            "Global step should be created to use StepCounterHook.")
    step_op_name = training_util.get_global_step().op.name
    self._summary_tag = step_op_name + "/sec"
开发者ID:didukhle,项目名称:tensorflow,代码行数:8,代码来源:basic_session_run_hooks.py
示例20: _train_op_fn
def _train_op_fn(loss):
    """Returns an op that runs the model's train step, then bumps the step.

    Args:
      loss: Loss tensor passed to model.get_train_step.

    Returns:
      The op of an assign_add(global_step, 1). It runs only after
      `train_step`'s ops (control dependency) and is colocated with the
      global step.
    """
    global_step = training_util.get_global_step()
    # Check existence explicitly. Truth-testing a TF variable/tensor is not
    # an existence check and can raise TypeError in graph mode.
    assert global_step is not None
    train_step = model.get_train_step(loss)
    # NOTE(review): train_step is passed directly to control_dependencies, so
    # it is presumably a list of ops; confirm against model.get_train_step.
    with ops.control_dependencies(train_step):
        with ops.get_default_graph().colocate_with(global_step):
            return state_ops.assign_add(global_step, 1).op
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:composable_model_test.py
注:本文中的tensorflow.python.training.training_util.get_global_step函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论