本文整理汇总了 Python 中 tensorflow.python.training.training_util.create_global_step 函数的典型用法代码示例。如果您想了解 create_global_step 函数的具体用法、调用方式或使用示例，这里精选的函数代码示例或许可以为您提供帮助。
下文一共展示了 create_global_step 函数的 20 个代码示例，这些示例默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将有助于我们的系统推荐更好的 Python 代码示例。
示例1: _test_logits
def _test_logits(self, mode, rnn_units, logits_dimension, features_fn,
                 sequence_feature_columns, context_feature_columns,
                 expected_logits):
    """Tests that the expected logits are calculated.

    Builds the RNN logit function in a fresh graph under an 'rnn' variable
    scope, runs it in a MonitoredTrainingSession pointed at
    ``self._model_dir``, and compares the result with ``expected_logits``.
    Variable values are expected to come from a checkpoint in that directory
    (presumably written beforehand, e.g. by ``create_checkpoint`` -- confirm
    in the calling test).

    Args:
      mode: Mode key forwarded unchanged to the logit function.
      rnn_units: Forwarded to ``rnn._make_rnn_cell_fn`` to build the cell fn.
      logits_dimension: Number of output units of the logit function.
      features_fn: Zero-argument callable producing the features passed to
        the logit function (presumably a dict of Tensors -- verify against
        callers). It is called inside the new graph.
      sequence_feature_columns: Sequence feature columns for the model.
      context_feature_columns: Context feature columns for the model.
      expected_logits: Expected logits value, compared with ``atol=1e-4``.
    """
    with ops.Graph().as_default():
        # Global step needed for MonitoredSession, which is in turn used to
        # explicitly set variable weights through a checkpoint.
        training_util.create_global_step()
        # Use a variable scope here with 'rnn', emulating the rnn model_fn, so
        # the checkpoint naming is shared.
        with variable_scope.variable_scope('rnn'):
            # NOTE(review): 64 << 20 == 67108864; presumably a byte count
            # (64 MiB) for min_slice_size -- confirm against the partitioner
            # docs.
            input_layer_partitioner = (
                partitioned_variables.min_max_variable_partitioner(
                    max_partitions=0, min_slice_size=64 << 20))
            logit_fn = rnn._rnn_logit_fn_builder(
                output_units=logits_dimension,
                rnn_cell_fn=rnn._make_rnn_cell_fn(rnn_units),
                sequence_feature_columns=sequence_feature_columns,
                context_feature_columns=context_feature_columns,
                input_layer_partitioner=input_layer_partitioner)
            # Features are constructed within this function, otherwise the
            # Tensors containing the features would be defined outside this
            # graph.
            logits = logit_fn(features=features_fn(), mode=mode)
            with monitored_session.MonitoredTrainingSession(
                    checkpoint_dir=self._model_dir) as sess:
                self.assertAllClose(expected_logits, sess.run(logits),
                                    atol=1e-4)
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:26,代码来源:rnn_test.py
示例2: _save_first_checkpoint
def _save_first_checkpoint(keras_model, estimator, custom_objects,
                           keras_weights):
    """Save first checkpoint for the keras Estimator.

    If ``estimator.model_dir`` already contains a checkpoint, this does
    nothing. Otherwise it performs these steps:

    1. Builds a TRAIN-mode clone of ``keras_model`` in a fresh graph that also
       holds a global step.
    2. Optionally copies ``keras_weights`` into the clone.
    3. Makes the train function and initializes variables.
    4. Writes ``keras_model.ckpt`` into ``estimator.model_dir``.

    Args:
      keras_model: an instance of compiled keras model.
      estimator: keras estimator.
      custom_objects: Dictionary for custom objects.
      keras_weights: A flat list of Numpy arrays for weights of given
        keras_model. When falsy (e.g. None or an empty list) ``set_weights``
        is skipped.

    Returns:
      None. The checkpoint is written as a side effect. (The previous
      docstring wrongly claimed a model_fn was returned.)
    """
    # Load weights and save to checkpoint if there is no checkpoint.
    latest_path = saver_lib.latest_checkpoint(estimator.model_dir)
    if latest_path:
        return
    with ops.Graph().as_default():
        random_seed.set_random_seed(estimator.config.tf_random_seed)
        training_util.create_global_step()
        model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
                                       custom_objects)
        # Save to checkpoint. The protected-access region also covers
        # `estimator._session_config`, which the old placement left out.
        # pylint: disable=protected-access
        with session.Session(config=estimator._session_config) as sess:
            if keras_weights:
                model.set_weights(keras_weights)
            # Make update ops and initialize all variables.
            if not model.train_function:
                model._make_train_function()
            K._initialize_variables(sess)
            # pylint: enable=protected-access
            saver = saver_lib.Saver()
            checkpoint_path = os.path.join(estimator.model_dir,
                                           'keras_model.ckpt')
            saver.save(sess, checkpoint_path)
开发者ID:LiuCKind,项目名称:tensorflow,代码行数:33,代码来源:keras.py
示例3: _test_logits
def _test_logits(
        self, mode, hidden_units, logits_dimension, inputs, expected_logits):
    """Runs the DNN model_fn in one mode with a mock head.

    The mock head is built from ``expected_logits``; presumably it is where
    the logits are checked (see ``_mock_head``). This helper builds the graph
    and runs the EstimatorSpec field that corresponds to ``mode``. Any other
    mode fails the test.
    """
    # EstimatorSpec attribute to run for each supported mode.
    spec_attr_for_mode = (
        (model_fn.ModeKeys.TRAIN, 'train_op'),
        (model_fn.ModeKeys.EVAL, 'loss'),
        (model_fn.ModeKeys.PREDICT, 'predictions'),
    )
    with ops.Graph().as_default():
        training_util.create_global_step()
        mocked_head = _mock_head(
            self,
            hidden_units=hidden_units,
            logits_dimension=logits_dimension,
            expected_logits=expected_logits)
        features = {'age': constant_op.constant(inputs)}
        labels = constant_op.constant([[1]])
        age_column = feature_column.numeric_column(
            'age', shape=np.array(inputs).shape[1:])
        optimizer = _mock_optimizer(self, hidden_units)
        estimator_spec = dnn._dnn_model_fn(
            features=features,
            labels=labels,
            mode=mode,
            head=mocked_head,
            hidden_units=hidden_units,
            feature_columns=[age_column],
            optimizer=optimizer)
        with monitored_session.MonitoredTrainingSession(
                checkpoint_dir=self._model_dir) as sess:
            for known_mode, attr_name in spec_attr_for_mode:
                if mode == known_mode:
                    sess.run(getattr(estimator_spec, attr_name))
                    break
            else:
                self.fail('Invalid mode: {}'.format(mode))
开发者ID:cameronphchen,项目名称:tensorflow,代码行数:30,代码来源:dnn_test.py
示例4: test_features_tensor_raises_value_error
def test_features_tensor_raises_value_error(self):
    """Tests that passing a Tensor for features raises a ValueError.

    ``features`` is given as a single constant Tensor built from ``inputs``
    rather than a dict. The model_fn is expected to reject it with a message
    matching 'features should be a dict'.
    """
    hidden_units = (2, 2)
    logits_dimension = 3
    inputs = ([[10.]], [[8.]])
    expected_logits = [[0, 0, 0]]
    with ops.Graph().as_default():
        training_util.create_global_step()
        head = mock_head(
            self,
            hidden_units=hidden_units,
            logits_dimension=logits_dimension,
            expected_logits=expected_logits)
        # `assertRaisesRegexp` is a deprecated alias that was removed from
        # unittest in Python 3.12; `assertRaisesRegex` is the supported name.
        with self.assertRaisesRegex(ValueError, 'features should be a dict'):
            self._dnn_model_fn(
                features=constant_op.constant(inputs),
                labels=constant_op.constant([[1]]),
                mode=model_fn.ModeKeys.TRAIN,
                head=head,
                hidden_units=hidden_units,
                feature_columns=[
                    feature_column.numeric_column(
                        'age', shape=np.array(inputs).shape[1:])
                ],
                optimizer=mock_optimizer(self, hidden_units))
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:26,代码来源:dnn_testing_utils.py
示例5: run_session
def run_session(self, hooks, should_stop):
    """Runs one no-op step under ``hooks`` and checks ``should_stop()``.

    Args:
      hooks: A single hook, or a list or tuple of hooks, handed to
        SingularMonitoredSession.
      should_stop: Expected result of ``mon_sess.should_stop()`` after the
        single run.
    """
    # Normalize to a list. Tuples are treated as sequences of hooks too; the
    # previous list-only check wrapped a tuple as though it were one hook.
    if isinstance(hooks, (list, tuple)):
        hooks = list(hooks)
    else:
        hooks = [hooks]
    with ops.Graph().as_default():
        training_util.create_global_step()
        no_op = control_flow_ops.no_op()
        with monitored_session.SingularMonitoredSession(
                hooks=hooks) as mon_sess:
            mon_sess.run(no_op)
            self.assertEqual(mon_sess.should_stop(), should_stop)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:8,代码来源:early_stopping_test.py
示例6: test_create_global_step
def test_create_global_step(self):
    """Checks global-step creation and the duplicate-creation error.

    The test makes these checks in order:

    1. The default graph starts without a global step.
    2. Creating one in a new graph succeeds.
    3. Creating another in that same graph raises a ValueError matching
       'already exists'. This holds both implicitly (default graph) and when
       the graph ``g`` is passed explicitly.
    4. A separate ``ops.Graph()`` can still get its own global step.
    """
    self.assertIsNone(training_util.get_global_step())
    with ops.Graph().as_default() as g:
        global_step = training_util.create_global_step()
        self._assert_global_step(global_step)
        # `assertRaisesRegexp` is a deprecated alias removed from unittest in
        # Python 3.12; `assertRaisesRegex` is the supported name.
        self.assertRaisesRegex(ValueError, 'already exists',
                               training_util.create_global_step)
        self.assertRaisesRegex(ValueError, 'already exists',
                               training_util.create_global_step, g)
        self._assert_global_step(training_util.create_global_step(ops.Graph()))
开发者ID:aeverall,项目名称:tensorflow,代码行数:10,代码来源:training_util_test.py
示例7: test_multi_feature_column_multi_dim_logits
def test_multi_feature_column_multi_dim_logits(self):
    """Tests multiple feature columns and multi-dimensional logits.

    All numbers are the same as test_multi_dim_input_multi_dim_logits. The only
    difference is that the input consists of two 1D feature columns, instead of
    one 2D feature column. Each of TRAIN, EVAL and PREDICT is exercised in its
    own graph against the same checkpoint.
    """
    base_global_step = 100
    # (weights, biases) for hidden layer 0, hidden layer 1, then the logits.
    checkpoint_layers = (
        ([[.6, .5], [-.6, -.5]], [.1, -.1]),
        ([[1., .8], [-.8, -1.]], [.2, -.2]),
        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),
    )
    create_checkpoint(checkpoint_layers, base_global_step, self._model_dir)
    hidden_units = (2, 2)
    logits_dimension = 3
    age_values, height_values = [[10.]], [[8.]]
    expected_logits = [[-0.48, 0.48, 0.39]]
    # EstimatorSpec attribute to run for each supported mode.
    spec_attr_for_mode = (
        (model_fn.ModeKeys.TRAIN, 'train_op'),
        (model_fn.ModeKeys.EVAL, 'loss'),
        (model_fn.ModeKeys.PREDICT, 'predictions'),
    )
    modes = (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL,
             model_fn.ModeKeys.PREDICT)
    for mode in modes:
        with ops.Graph().as_default():
            training_util.create_global_step()
            mocked_head = mock_head(
                self,
                hidden_units=hidden_units,
                logits_dimension=logits_dimension,
                expected_logits=expected_logits)
            features = {
                'age': constant_op.constant(age_values),
                'height': constant_op.constant(height_values),
            }
            labels = constant_op.constant([[1]])
            columns = [
                feature_column.numeric_column('age'),
                feature_column.numeric_column('height'),
            ]
            optimizer = mock_optimizer(self, hidden_units)
            estimator_spec = self._dnn_model_fn(
                features=features,
                labels=labels,
                mode=mode,
                head=mocked_head,
                hidden_units=hidden_units,
                feature_columns=columns,
                optimizer=optimizer)
            with monitored_session.MonitoredTrainingSession(
                    checkpoint_dir=self._model_dir) as sess:
                for known_mode, attr_name in spec_attr_for_mode:
                    if mode == known_mode:
                        sess.run(getattr(estimator_spec, attr_name))
                        break
                else:
                    self.fail('Invalid mode: {}'.format(mode))
开发者ID:ajaybhat,项目名称:tensorflow,代码行数:52,代码来源:dnn_testing_utils.py
示例8: test_reads_before_increments
def test_reads_before_increments(self):
    """A global-step read sees the value from before same-run increments.

    Running a read together with an increment returns the pre-increment value.
    A later standalone read shows the accumulated total (1 + 3 == 4).
    """
    with ops.Graph().as_default():
        training_util.create_global_step()
        read_tensor = training_util._get_or_create_global_step_read()
        increment_by_one = training_util._increment_global_step(1)
        increment_by_three = training_util._increment_global_step(3)
        with monitored_session.MonitoredTrainingSession() as sess:
            # (increment run alongside the read, value the read must return)
            for increment, expected in ((increment_by_one, 0),
                                        (increment_by_three, 1)):
                observed, _ = sess.run([read_tensor, increment])
                self.assertEqual(expected, observed)
            self.assertEqual(4, sess.run(read_tensor))
开发者ID:aeverall,项目名称:tensorflow,代码行数:13,代码来源:training_util_test.py
示例9: create_checkpoint
def create_checkpoint(rnn_weights, rnn_biases, logits_weights, logits_biases,
                      global_step, model_dir):
    """Create checkpoint file with provided model weights.

    Each value becomes a float32 variable named by the corresponding module
    constant. A global step set to ``global_step`` is added as well.

    Args:
      rnn_weights: Iterable of values of weights for the RNN cell.
      rnn_biases: Iterable of values of biases for the RNN cell.
      logits_weights: Iterable of values for matrix connecting RNN output to
        logits.
      logits_biases: Iterable of values for logits bias term.
      global_step: Initial global step to save in checkpoint.
      model_dir: Directory into which checkpoint is saved.
    """
    model_weights = {
        CELL_WEIGHTS_NAME: rnn_weights,
        CELL_BIAS_NAME: rnn_biases,
        LOGITS_WEIGHTS_NAME: logits_weights,
        LOGITS_BIAS_NAME: logits_biases,
    }
    with ops.Graph().as_default():
        # Model variables.
        for var_name, initial_value in six.iteritems(model_weights):
            variables_lib.Variable(initial_value, name=var_name,
                                   dtype=dtypes.float32)
        # Non-model variable: the global step, assigned the requested value.
        set_global_step = training_util.create_global_step().assign(
            global_step)
        # The session initializes the variables and saves the checkpoint.
        with monitored_session.MonitoredTrainingSession(
                checkpoint_dir=model_dir) as sess:
            sess.run(set_global_step)
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:32,代码来源:rnn_test.py
示例10: _create_checkpoint
def _create_checkpoint(weights_and_biases, global_step, model_dir):
    """Create checkpoint file with provided model weights.

    Every (weight, bias) pair except the last is stored under the hidden-layer
    name patterns, indexed from 0. The final pair is stored as the logits
    layer. The checkpoint is written to ``model_dir/model.ckpt`` together
    with a global step equal to ``global_step``.

    Args:
      weights_and_biases: Iterable of tuples of weight and bias values.
      global_step: Initial global step to save in checkpoint.
      model_dir: Directory into which checkpoint is saved.
    """
    weights, biases = zip(*weights_and_biases)
    model_weights = {}
    # Hidden layer weights: every pair but the last.
    for layer_index, (layer_weights, layer_biases) in enumerate(
            zip(weights[:-1], biases[:-1])):
        model_weights[_HIDDEN_WEIGHTS_NAME_PATTERN % layer_index] = (
            layer_weights)
        model_weights[_HIDDEN_BIASES_NAME_PATTERN % layer_index] = (
            layer_biases)
    # Output layer weights: the final pair.
    model_weights[_LOGITS_WEIGHTS_NAME] = weights[-1]
    model_weights[_LOGITS_BIASES_NAME] = biases[-1]
    with ops.Graph().as_default():
        # Model variables.
        for var_name, initial_value in six.iteritems(model_weights):
            variables_lib.Variable(initial_value, name=var_name,
                                   dtype=dtypes.float32)
        # Non-model variable.
        global_step_var = training_util.create_global_step()
        # Initialize vars, set the step, and save the checkpoint.
        with tf_session.Session() as sess:
            variables_lib.global_variables_initializer().run()
            global_step_var.assign(global_step).eval()
            checkpoint_prefix = os.path.join(model_dir, 'model.ckpt')
            saver.Saver().save(sess, checkpoint_prefix)
开发者ID:cameronphchen,项目名称:tensorflow,代码行数:33,代码来源:dnn_test.py
示例11: test_stop
def test_stop(self):
    """Checks _StopOnPredicateHook with a False and then a True predicate.

    After one run, both ``should_stop()`` and the hook's ``_stop_var`` must
    match the predicate's result.
    """
    cases = (
        (lambda: False, self.assertFalse),
        (lambda: True, self.assertTrue),
    )
    for predicate, check in cases:
        hook = early_stopping._StopOnPredicateHook(
            should_stop_fn=predicate, run_every_secs=0)
        with ops.Graph().as_default():
            training_util.create_global_step()
            no_op = control_flow_ops.no_op()
            with monitored_session.SingularMonitoredSession(
                    hooks=[hook]) as mon_sess:
                mon_sess.run(no_op)
                check(mon_sess.should_stop())
                check(mon_sess.raw_session().run(hook._stop_var))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:20,代码来源:early_stopping_test.py
示例12: global_step
def global_step(self):
    """Returns the cached global step, creating it on the first call.

    Creation goes through ``training_util.create_global_step``. A custom
    getter in an empty-named variable scope routes the variable creation
    through ``self.add_variable``.
    """
    if self._global_step is not None:
        return self._global_step

    def _create_via_add_variable(getter, *args, **kwargs):
        return self.add_variable(*args, getter=getter, **kwargs)

    with variable_scope.variable_scope(
            "", custom_getter=_create_via_add_variable):
        self._global_step = training_util.create_global_step()
    return self._global_step
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:11,代码来源:checkpointable_test.py
示例13: _gan_train_ops
def _gan_train_ops(self, generator_add, discriminator_add):
    """Returns GANTrainOps whose ops only add to the global step.

    Each run of the ops adds to the global step as follows:

    * the generator op adds ``generator_add``;
    * the discriminator op adds ``discriminator_add``;
    * the global-step op adds 1.

    The final step value therefore encodes how many times each op ran.
    """
    step = training_util.create_global_step()
    # NOTE: `use_locking=True` is required to avoid race conditions with
    # joint training.
    generator_op = step.assign_add(generator_add, use_locking=True)
    discriminator_op = step.assign_add(discriminator_add, use_locking=True)
    step_increment_op = step.assign_add(1)
    return namedtuples.GANTrainOps(
        generator_train_op=generator_op,
        discriminator_train_op=discriminator_op,
        global_step_inc_op=step_increment_op)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:12,代码来源:train_test.py
示例14: testGlobalStepIsWrapped
def testGlobalStepIsWrapped(self):
    """Checks the global step under a 2-GPU ParameterServerStrategy scope.

    The step is an AggregatingVariable, and create_global_step and
    get_global_step return values that compare equal.
    """
    distribution = parameter_server_strategy.ParameterServerStrategy(
        num_gpus_per_worker=2)
    with ops.Graph().as_default(), distribution.scope():
        created_step = training_util.create_global_step()
        get_step = training_util.get_global_step()
        mismatch_msg = ('created_step %s type %s vs. get_step %s type %s' %
                        (id(created_step), created_step.__class__.__name__,
                         id(get_step), get_step.__class__.__name__))
        self.assertEqual(created_step, get_step, msg=mismatch_msg)
        for step in (created_step, get_step):
            self.assertIs(values.AggregatingVariable, type(step))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:12,代码来源:parameter_server_strategy_test.py
示例15: _save_first_checkpoint
def _save_first_checkpoint(keras_model, custom_objects, config):
    """Save first checkpoint for the keras Estimator.

    The checkpoint lives in a ``keras`` subdirectory of ``config.model_dir``.
    If that directory already has a checkpoint, its path is returned and
    nothing is written.

    Args:
      keras_model: an instance of compiled keras model.
      custom_objects: Dictionary for custom objects.
      config: Estimator config.

    Returns:
      The path where keras model checkpoint is saved.
    """
    # save checkpoint into subdirectory to allow warm start
    keras_model_dir = os.path.join(config.model_dir, 'keras')
    existing_path = checkpoint_management.latest_checkpoint(keras_model_dir)
    if existing_path:
        return existing_path

    # Weights are copied over only when _any_weight_initialized says so.
    keras_weights = (keras_model.get_weights()
                     if _any_weight_initialized(keras_model) else None)
    if not gfile.IsDirectory(keras_model_dir):
        gfile.MakeDirs(keras_model_dir)
    with ops.Graph().as_default():
        random_seed.set_random_seed(config.tf_random_seed)
        training_util.create_global_step()
        model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model,
                                       custom_objects)
        # save to checkpoint
        with session.Session(config=config.session_config) as sess:
            if keras_weights:
                model.set_weights(keras_weights)
            # Make update ops and initialize all variables.
            # pylint: disable=protected-access
            if not model.train_function:
                model._make_train_function()
            K._initialize_variables(sess)
            # pylint: enable=protected-access
            checkpoint_saver = saver_lib.Saver()
            checkpoint_path = os.path.join(keras_model_dir, 'keras_model.ckpt')
            checkpoint_saver.save(sess, checkpoint_path)
    return checkpoint_path
开发者ID:AnishShah,项目名称:tensorflow,代码行数:40,代码来源:keras.py
示例16: testGlobalStepIsWrappedOnTwoGPUs
def testGlobalStepIsWrappedOnTwoGPUs(self, use_core_strategy):
    """With two GPUs the global step is an AggregatingVariable.

    Also checks that the created and fetched steps compare equal and that the
    step reports ``strategy`` as its ``distribute_strategy``.
    """
    strategy, _, _ = create_test_objects(
        num_gpus=2, use_core_strategy=use_core_strategy)
    with ops.Graph().as_default(), strategy.scope():
        created_step = training_util.create_global_step()
        get_step = training_util.get_global_step()
        mismatch_msg = ('created_step %s type %s vs. get_step %s type %s' %
                        (id(created_step), created_step.__class__.__name__,
                         id(get_step), get_step.__class__.__name__))
        self.assertEqual(created_step, get_step, msg=mismatch_msg)
        for step in (created_step, get_step):
            self.assertIs(values.AggregatingVariable, type(step))
        self.assertIs(strategy, created_step.distribute_strategy)
开发者ID:hdyen,项目名称:tensorflow,代码行数:13,代码来源:parameter_server_strategy_test.py
示例17: test_requests
def test_requests(self):
    """Checks that each monitor receives the tensor it requested by name.

    Two FakeMonitors are wrapped in a RunHookAdapterForMonitors hook. Each
    requests one named constant. After ``a_tensor`` is run through a
    _HookedSession, the run returns [0], and each monitor's ``output`` holds
    the value of the tensor it requested.
    """
    with ops.Graph().as_default(), session_lib.Session() as sess:
        # NOTE(review): presumably the monitors adapter needs a global step;
        # confirm against RunHookAdapterForMonitors.
        training_util.create_global_step()
        mock_mon = FakeMonitor()
        mock_mon2 = FakeMonitor()
        hook = learn.monitors.RunHookAdapterForMonitors([mock_mon, mock_mon2])
        hook.begin()
        mon_sess = monitored_session._HookedSession(sess=sess, hooks=[hook])
        a_tensor = constant_op.constant([0], name='a_tensor')
        # Only referenced by name: these are what the monitors request below.
        constant_op.constant([5], name='another_tensor')
        constant_op.constant([10], name='third_tensor')
        # requested_tensors is set after hook.begin(); presumably the adapter
        # reads it on each run -- verify against FakeMonitor.
        mock_mon.requested_tensors = ['another_tensor']
        mock_mon2.requested_tensors = ['third_tensor']
        sess.run(variables.global_variables_initializer())
        output = mon_sess.run(a_tensor)
        self.assertEqual(output, [0])
        self.assertEqual(mock_mon.output['another_tensor'], [5])
        self.assertEqual(mock_mon2.output['third_tensor'], [10])
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:22,代码来源:monitors_test.py
示例18: testGlobalStepIsNotWrappedOnOneGPU
def testGlobalStepIsNotWrappedOnOneGPU(self, use_core_strategy):
    """With one GPU the global step stays a plain ResourceVariable.

    Also checks that the created and fetched steps compare equal and that the
    step reports ``strategy`` as its ``distribute_strategy``.
    """
    strategy, _, _ = create_test_objects(
        num_gpus=1, use_core_strategy=use_core_strategy)
    with ops.Graph().as_default(), strategy.scope():
        created_step = training_util.create_global_step()
        get_step = training_util.get_global_step()
        mismatch_msg = ('created_step %s type %s vs. get_step %s type %s' %
                        (id(created_step), created_step.__class__.__name__,
                         id(get_step), get_step.__class__.__name__))
        self.assertEqual(created_step, get_step, msg=mismatch_msg)
        for step in (created_step, get_step):
            self.assertIs(resource_variable_ops.ResourceVariable, type(step))
        self.assertIs(strategy, created_step.distribute_strategy)
开发者ID:pyjennings,项目名称:tensorflow,代码行数:13,代码来源:parameter_server_strategy_test.py
示例19: create_global_step
def create_global_step(graph=None):
    """Create global step tensor in graph.

    Delegates directly to ``training_util.create_global_step``.

    Args:
      graph: The graph in which to create the global step tensor. If missing,
        use default graph.

    Returns:
      Global step tensor.

    Raises:
      ValueError: if global step tensor is already defined.
    """
    global_step_tensor = training_util.create_global_step(graph=graph)
    return global_step_tensor
开发者ID:Immexxx,项目名称:tensorflow,代码行数:14,代码来源:variables.py
示例20: create_global_step
def create_global_step(graph=None):
    """Create global step tensor in graph.

    This API is deprecated. Use core framework training version instead.
    It delegates directly to ``training_util.create_global_step``.

    Args:
      graph: The graph in which to create the global step tensor. If missing,
        use default graph.

    Returns:
      Global step tensor.

    Raises:
      ValueError: if global step tensor is already defined.
    """
    global_step_tensor = training_util.create_global_step(graph=graph)
    return global_step_tensor
开发者ID:StephenOman,项目名称:tensorflow,代码行数:16,代码来源:variables.py
注：本文中的 tensorflow.python.training.training_util.create_global_step 函数示例由纯净天空整理自 GitHub/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。
请发表评论