本文整理汇总了Python中tensorflow.contrib.framework.python.ops.variables.create_global_step函数的典型用法代码示例。如果您正苦于以下问题:Python create_global_step函数的具体用法?Python create_global_step怎么用?Python create_global_step使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了create_global_step函数的17个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testLinearRegression
def testLinearRegression(self):
    """Check that two identically seeded LinearRegressor runs agree.

    Each regressor is built and fitted for one step inside its own fresh
    graph, with `random` and the graph seed reset to the same value. The
    resulting weights, bias and predicted scores are then compared.
    """
    seed = 42
    run_cfg = run_config.RunConfig(tf_random_seed=seed)
    boston = base.load_boston()
    feature_cols = [feature_column.real_valued_column('', dimension=13)]

    def fit_in_fresh_graph():
        # Reset every seed before the model is built so both runs start
        # from the same state.
        with ops.Graph().as_default() as graph:
            random.seed(seed)
            graph.seed = seed
            variables.create_global_step()
            regressor = linear.LinearRegressor(
                optimizer=_NULL_OPTIMIZER,
                feature_columns=feature_cols,
                config=run_cfg)
            regressor.fit(x=boston.data, y=boston.target, steps=1)
        return regressor

    regressor1 = fit_in_fresh_graph()
    regressor2 = fit_in_fresh_graph()

    self.assertAllClose(regressor1.weights_, regressor2.weights_)
    self.assertAllClose(regressor1.bias_, regressor2.bias_)
    scores1 = list(regressor1.predict_scores(boston.data, as_iterable=True))
    scores2 = list(regressor2.predict_scores(boston.data, as_iterable=True))
    self.assertAllClose(scores1, scores2, atol=1e-05)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:32,代码来源:stability_test.py
示例2: export_estimator
def export_estimator(estimator, export_dir, input_fn=_default_input_fn,
                     signature_fn=_generic_signature_fn, default_batch_size=1,
                     exports_to_keep=None):
    """Exports inference graph into given dir.

    Args:
      estimator: Estimator to export.
      export_dir: A string containing a directory to write the exported
        graph and checkpoints.
      input_fn: Function that, given the estimator and a `Tensor` of
        `Example` strings, parses it into features that are then passed to
        the model.
      signature_fn: Function that, given a `Tensor` of `Example` strings, a
        `dict` of `Tensor`s for features and a `dict` of `Tensor`s for
        predictions, returns default and named exporting signatures.
      default_batch_size: Default batch size of the `Example` placeholder.
      exports_to_keep: Number of exports to keep.
    """
    # The checkpoint is resolved before the export graph is created.
    latest_ckpt = tf_saver.latest_checkpoint(estimator._model_dir)
    with ops.Graph().as_default() as graph:
        contrib_variables.create_global_step(graph)
        example_placeholder = array_ops.placeholder(
            dtype=dtypes.string,
            shape=[default_batch_size],
            name='input_example_tensor')
        parsed_features = input_fn(estimator, example_placeholder)
        predict_ops = estimator._get_predict_ops(parsed_features)
        signatures = signature_fn(
            example_placeholder, parsed_features, predict_ops)
        default_signature, named_graph_signatures = signatures

        # Turn the plain count into the retention policy `_export_graph`
        # receives; None is passed through untouched.
        if exports_to_keep is None:
            keep_policy = None
        else:
            keep_policy = gc.largest_export_versions(exports_to_keep)

        _export_graph(
            graph,
            _get_saver(),
            latest_ckpt,
            export_dir,
            default_graph_signature=default_signature,
            named_graph_signatures=named_graph_signatures,
            exports_to_keep=keep_policy)
开发者ID:363158858,项目名称:tensorflow,代码行数:33,代码来源:export.py
示例3: _export_estimator
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep):
    """Build a prediction graph for `estimator` and export it.

    Args:
      estimator: Estimator whose `_model_dir` supplies the latest checkpoint
        and whose `_get_predict_ops` builds the predictions.
      export_dir: Destination handed through to `_export_graph`.
      signature_fn: Callable `(examples, features, predictions)` returning
        `(default_signature, named_graph_signatures)`. When falsy, one is
        picked from the problem type of the estimator's target column; if
        that lookup raises `AttributeError`, `generic_signature_fn` is used.
      input_fn: Callable `(estimator, examples)` returning features. A falsy
        value is replaced by `_default_input_fn`.
      default_batch_size: Length of the string `examples` placeholder.
      exports_to_keep: If not None, converted with
        `gc.largest_export_versions` before being passed to `_export_graph`.

    Raises:
      ValueError: If `signature_fn` is falsy and the target column's problem
        type is not classification, linear regression or logistic
        regression.
    """
    input_fn = input_fn or _default_input_fn
    checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
    with ops.Graph().as_default() as g:
        contrib_variables.create_global_step(g)
        examples = array_ops.placeholder(dtype=dtypes.string,
                                         shape=[default_batch_size],
                                         name='input_example_tensor')
        features = input_fn(estimator, examples)
        predictions = estimator._get_predict_ops(features)

        # Explicit signature_fn takes priority
        if signature_fn:
            default_signature, named_graph_signatures = signature_fn(
                examples, features, predictions)
        else:
            # NOTE(review): the try also covers the signature_fn call, so an
            # AttributeError raised inside it triggers the generic fallback
            # as well.
            try:
                # Some estimators provide a target_column of known type
                target_column = estimator._get_target_column()
                problem_type = target_column.problem_type

                if problem_type == layers.ProblemType.CLASSIFICATION:
                    signature_fn = classification_signature_fn
                elif problem_type == layers.ProblemType.LINEAR_REGRESSION:
                    signature_fn = regression_signature_fn
                elif problem_type == layers.ProblemType.LOGISTIC_REGRESSION:
                    signature_fn = logistic_regression_signature_fn
                else:
                    raise ValueError(
                        'signature_fn must be provided because the TargetColumn is a %s, '
                        'which does not have a standard problem type and so cannot use a '
                        'standard export signature.' % type(target_column).__name__)

                default_signature, named_graph_signatures = (
                    signature_fn(examples, features, predictions))
            except AttributeError:
                # The trailing space after "after" keeps the adjacent
                # literals from fusing into "after2016-08-01".
                logging.warn(
                    'Change warning: `signature_fn` will be required after '
                    '2016-08-01.\n'
                    'Using generic signatures for now. To maintain this behavior, '
                    'pass:\n'
                    ' signature_fn=export.generic_signature_fn\n'
                    'Also consider passing a regression or classification signature; '
                    'see cl/126430915 for an example.')
                default_signature, named_graph_signatures = generic_signature_fn(
                    examples, features, predictions)

        if exports_to_keep is not None:
            exports_to_keep = gc.largest_export_versions(exports_to_keep)

        _export_graph(g, _get_saver(), checkpoint_path, export_dir,
                      default_graph_signature=default_signature,
                      named_graph_signatures=named_graph_signatures,
                      exports_to_keep=exports_to_keep)
开发者ID:JamesFysh,项目名称:tensorflow,代码行数:58,代码来源:export.py
示例4: test_create_global_step
def test_create_global_step(self):
    """Verify global step creation and rejection of duplicate creation."""
    # `assertEqual` replaces the deprecated `assertEquals` alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(None, variables_lib2.get_global_step())
    with ops.Graph().as_default() as g:
        global_step = variables_lib2.create_global_step()
        self._assert_global_step(global_step)
        # A second creation must fail, both for the implicit default graph
        # and when the same graph is passed explicitly.
        self.assertRaisesRegexp(ValueError, 'already exists',
                                variables_lib2.create_global_step)
        self.assertRaisesRegexp(ValueError, 'already exists',
                                variables_lib2.create_global_step, g)
        # A different graph can still get its own global step.
        self._assert_global_step(
            variables_lib2.create_global_step(ops.Graph()))
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:10,代码来源:variables_test.py
示例5: setUp
def setUp(self):
    """Create the fake writer, summary op and train op used by the tests."""
    test.TestCase.setUp(self)

    self.log_dir = 'log/dir'
    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)

    # The summary reports a resource variable after it has been bumped by 1.
    counter = variable_scope.get_variable(
        'var', initializer=0.0, use_resource=True)
    bumped = state_ops.assign_add(counter, 1.0)
    self.summary_op = summary_lib.scalar('my_summary', bumped)

    # The global step is created under the 'foo' scope with resource
    # variables enabled.
    with variable_scope.variable_scope('foo', use_resource=True):
        variables.create_global_step()
    self.train_op = training_util._increment_global_step(1)
开发者ID:Mazecreator,项目名称:tensorflow,代码行数:13,代码来源:basic_session_run_hooks_test.py
示例6: export_estimator
def export_estimator(estimator,
                     export_dir,
                     signature_fn=None,
                     input_fn=_default_input_fn,
                     default_batch_size=1,
                     exports_to_keep=None):
    """Exports inference graph into given dir.

    Args:
      estimator: Estimator to export.
      export_dir: A string containing a directory to write the exported
        graph and checkpoints.
      signature_fn: Function that, given a `Tensor` of `Example` strings, a
        `dict` of `Tensor`s for features and a `dict` of `Tensor`s for
        predictions, returns default and named exporting signatures. When
        falsy, a warning is logged and `generic_signature_fn` is used.
      input_fn: Function that, given the estimator and a `Tensor` of
        `Example` strings, parses it into features that are then passed to
        the model.
      default_batch_size: Default batch size of the `Example` placeholder.
      exports_to_keep: Number of exports to keep.
    """
    latest_ckpt = tf_saver.latest_checkpoint(estimator._model_dir)
    with ops.Graph().as_default() as graph:
        contrib_variables.create_global_step(graph)
        examples = array_ops.placeholder(
            dtype=dtypes.string,
            shape=[default_batch_size],
            name='input_example_tensor')
        features = input_fn(estimator, examples)
        predictions = estimator._get_predict_ops(features)

        # Pick the signature builder first, then call it once.
        if signature_fn:
            make_signatures = signature_fn
        else:
            logging.warn(
                'Change warning: `signature_fn` will be required after 2016-08-01.\n'
                'Using generic signatures for now. To maintain this behavior, '
                'pass:\n'
                ' signature_fn=export.generic_signature_fn\n'
                'Also consider passing a regression or classification signature; see '
                'cl/126430915 for an example.')
            make_signatures = generic_signature_fn
        default_signature, named_graph_signatures = make_signatures(
            examples, features, predictions)

        if exports_to_keep is None:
            keep_policy = None
        else:
            keep_policy = gc.largest_export_versions(exports_to_keep)

        _export_graph(
            graph,
            _get_saver(),
            latest_ckpt,
            export_dir,
            default_graph_signature=default_signature,
            named_graph_signatures=named_graph_signatures,
            exports_to_keep=keep_policy)
开发者ID:10imaging,项目名称:tensorflow,代码行数:48,代码来源:export.py
示例7: test_evaluate_ready_for_local_init
def test_evaluate_ready_for_local_init(self):
    """Run `evaluate` with a READY_FOR_LOCAL_INIT_OP in the graph."""
    with ops.Graph().as_default() as graph, self.test_session(graph):
        variables_lib.create_global_step()
        global_var = variables.Variable(1.0)
        # A local variable whose initial value is derived from the global
        # one above.
        _local_var = variables.Variable(
            global_var + 1,
            collections=[ops.GraphKeys.LOCAL_VARIABLES],
            trainable=False)
        # Register an op reporting uninitialized global variables as the
        # readiness check for local initialization.
        uninitialized_globals = variables.report_uninitialized_variables(
            variables.global_variables())
        ops.add_to_collection(
            ops.GraphKeys.READY_FOR_LOCAL_INIT_OP, uninitialized_globals)
        learn.graph_actions.evaluate(
            graph,
            output_dir=self._output_dir,
            checkpoint_path=None,
            eval_dict={'a': global_var},
            max_steps=1)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:16,代码来源:graph_actions_test.py
示例8: test_train_worker_monitor
def test_train_worker_monitor(self):
    """Train as a non-chief worker and check which monitors were used."""
    # We need to explicitly set device due to check on non-chief workers
    # requiring all variables to have a device assigned.
    with ops.Graph().as_default() as graph, graph.device('/cpu:0'):
        step = variables_lib.create_global_step(graph)
        increment_step = state_ops.assign_add(step, 1)
        constant_loss = constant_op.constant(2.0)
        summary.scalar('loss', constant_loss)

        # Add explicit "local" init op to initialize all variables
        # as there's no chief to init here.
        initializer = variables.global_variables_initializer()
        ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, initializer)

        # Create worker monitors where one should be active on the worker
        # and the other chief exclusive.
        chief_exclusive_monitor = _BaseMonitorWrapper(False)
        all_workers_monitor = _BaseMonitorWrapper(True)

        with self.test_session(graph):
            loss = learn.graph_actions.train(
                graph,
                output_dir=self._output_dir,
                global_step_tensor=step,
                train_op=increment_step,
                loss_op=constant_loss,
                supervisor_is_chief=False,
                steps=1,
                monitors=[chief_exclusive_monitor, all_workers_monitor])

        self.assertEqual(2.0, loss)

        only_worker_monitor_active = (
            not chief_exclusive_monitor.is_active and
            all_workers_monitor.is_active)
        self.assertTrue(
            only_worker_monitor_active,
            'Only non-chief runnable monitor must have been active.')

        only_worker_monitor_stepped = (
            not chief_exclusive_monitor.has_step and
            all_workers_monitor.has_step)
        self.assertTrue(
            only_worker_monitor_stepped,
            'Only non-chief runnable monitor must have a step.')
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:33,代码来源:graph_actions_test.py
示例9: testDNNRegression
def testDNNRegression(self):
    """Check that two identically seeded DNNRegressor runs agree.

    Two regressors are built and fitted for one step, each in its own fresh
    graph with the same seeds. Their hidden-layer and logits weights and
    biases, and their predicted scores, are then compared.
    """
    seed = 42
    run_cfg = run_config.RunConfig(tf_random_seed=seed)
    boston = base.load_boston()
    feature_cols = [feature_column.real_valued_column('', dimension=13)]

    trained = []
    for _ in range(2):
        with ops.Graph().as_default() as graph:
            random.seed(seed)
            graph.seed = seed
            variables.create_global_step()
            regressor = dnn.DNNRegressor(
                hidden_units=[10],
                feature_columns=feature_cols,
                optimizer=_NULL_OPTIMIZER,
                config=run_cfg)
            regressor.fit(x=boston.data, y=boston.target, steps=1)
        trained.append(regressor)
    regressor1, regressor2 = trained

    weight_names = ('dnn/hiddenlayer_0/weights', 'dnn/logits/weights')
    weights1 = [regressor1.get_variable_value(n) for n in weight_names]
    weights2 = [regressor2.get_variable_value(n) for n in weight_names]
    for w1, w2 in zip(weights1, weights2):
        self.assertAllClose(w1, w2)

    bias_names = ('dnn/hiddenlayer_0/biases', 'dnn/logits/biases')
    biases1 = [regressor1.get_variable_value(n) for n in bias_names]
    biases2 = [regressor2.get_variable_value(n) for n in bias_names]
    for b1, b2 in zip(biases1, biases2):
        self.assertAllClose(b1, b2)

    scores1 = list(regressor1.predict_scores(boston.data, as_iterable=True))
    scores2 = list(regressor2.predict_scores(boston.data, as_iterable=True))
    self.assertAllClose(scores1, scores2, atol=1e-05)
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:47,代码来源:stability_test.py
示例10: _build_inference_graph
def _build_inference_graph(self):
    """Build simple inference graph.

    This includes a global step, a regular variable, a local variable, and a
    fake table variable whose initializer is registered under
    `TABLE_INITIALIZERS`.

    Returns:
      Tuple `(in0, in1, out)`: the regular variable, the local variable, and
      the sum of both plus the fake table variable.
    """
    variables_lib.create_global_step()
    regular_in = variables.Variable(1.0)
    local_in = variables_lib.local_variable(2.0)
    table_var = variables.Variable(
        3.0,
        trainable=False,
        collections=['fake_tables'],
        name='fake_table_var')
    # Register the fake table's initializer in the table-initializer
    # collection of the graph that owns the inputs.
    owning_graph = regular_in.graph
    owning_graph.add_to_collections(
        [ops.GraphKeys.TABLE_INITIALIZERS], table_var.initializer)
    total = regular_in + local_in + table_var
    return regular_in, local_in, total
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:20,代码来源:graph_actions_test.py
示例11: test_train_loss
def test_train_loss(self):
    """Train for six steps and check the returned loss and written files."""
    with ops.Graph().as_default() as graph, self.test_session(graph):
        variables_lib.create_global_step()
        loss_var = variables_lib.local_variable(10.0)
        # Each training step advances the global step by one and lowers the
        # loss variable by 1.0.
        bump_step = state_ops.assign_add(variables_lib.get_global_step(), 1)
        lower_loss = state_ops.assign_add(loss_var, -1.0)
        train_op = control_flow_ops.group(bump_step, lower_loss)

        # State of the output directory before training.
        self._assert_summaries(self._output_dir)
        self._assert_ckpt(self._output_dir, False)

        loss = learn.graph_actions.train(
            graph,
            output_dir=self._output_dir,
            train_op=train_op,
            loss_op=loss_var.value(),
            steps=6)

        # TODO(ebrevdo,ptucker,ispir): this meta_graph_def lacks the
        # SaverDef, so we can't add it to the summary assertion test below.
        # meta_graph_def = meta_graph.create_meta_graph_def()
        self.assertEqual(4.0, loss)
        self._assert_summaries(self._output_dir, expected_graphs=[graph])
        self._assert_ckpt(self._output_dir, True)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:21,代码来源:graph_actions_test.py
示例12: test_requests
def test_requests(self):
    """Check that tensors requested by each monitor show up in its output."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
        variables_lib.create_global_step()
        first_monitor = FakeMonitor()
        second_monitor = FakeMonitor()
        adapter = learn.monitors.RunHookAdapterForMonitors(
            [first_monitor, second_monitor])
        adapter.begin()

        hooked_sess = monitored_session._HookedSession(
            sess=sess, hooks=[adapter])

        fetched = constant_op.constant([0], name='a_tensor')
        constant_op.constant([5], name='another_tensor')
        constant_op.constant([10], name='third_tensor')
        # Each monitor asks for a different tensor, by name.
        first_monitor.requested_tensors = ['another_tensor']
        second_monitor.requested_tensors = ['third_tensor']

        sess.run(variables.global_variables_initializer())
        result = hooked_sess.run(fetched)

        self.assertEqual(result, [0])
        self.assertEqual(first_monitor.output['another_tensor'], [5])
        self.assertEqual(second_monitor.output['third_tensor'], [10])
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:22,代码来源:monitors_test.py
示例13: test_calls_and_steps
def test_calls_and_steps(self):
    """Check monitor callback counts and reported steps across two runs."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
        global_step_tensor = variables_lib.create_global_step()
        inc_5 = state_ops.assign_add(global_step_tensor, 5)
        monitors = [FakeMonitor(), FakeMonitor()]

        hook = learn.monitors.RunHookAdapterForMonitors(list(monitors))
        hook.begin()
        for mon in monitors:
            self.assertEqual(mon.call_counter['begin'], 1)

        sess.run(variables.global_variables_initializer())
        sess.run(global_step_tensor.assign(10))

        mon_sess = monitored_session._HookedSession(sess=sess, hooks=[hook])

        def check_after_run(expected_step, expected_calls):
            # Every monitor must report the same step for all three
            # callbacks and have been called `expected_calls` times.
            for mon in monitors:
                self.assertEqual(mon.output, {})
                self.assertEqual(mon.last_begin_step, expected_step)
                self.assertEqual(mon.last_end_step, expected_step)
                self.assertEqual(mon.last_post_step, expected_step)
                self.assertEqual(mon.call_counter['step_end'], expected_calls)
                self.assertEqual(
                    mon.call_counter['step_begin'], expected_calls)
                self.assertEqual(
                    mon.call_counter['post_step'], expected_calls)

        mon_sess.run(inc_5)
        check_after_run(11, 1)

        mon_sess.run(inc_5)
        check_after_run(16, 2)

        hook.end(sess)
        for mon in monitors:
            self.assertEqual(mon.call_counter['end'], 1)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:40,代码来源:monitors_test.py
示例14: _export_estimator
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None,
                      checkpoint_path=None):
    """Build a prediction graph for `estimator` and export it.

    Args:
      estimator: Estimator whose `_get_predict_ops` builds the predictions
        and whose `_model_dir` is the fallback checkpoint location.
      export_dir: Destination handed through to `_export_graph`.
      signature_fn: Callable `(examples, features, predictions)` returning
        `(default_signature, named_graph_signatures)`. When falsy,
        `estimator._create_signature_fn()` is tried; on `AttributeError`
        `generic_signature_fn` is used instead.
      input_fn: With `use_deprecated_input_fn`, a callable
        `(estimator, examples)` returning features (falsy values fall back
        to `_default_input_fn`). Otherwise a required zero-argument callable
        returning a `(features, _)` pair.
      default_batch_size: Length of the string `examples` placeholder
        (deprecated input path only).
      exports_to_keep: If not None, converted with
        `gc.largest_export_versions` before being passed to `_export_graph`.
      input_feature_key: On the non-deprecated path, key popped from
        `features` and used as `examples`.
      use_deprecated_input_fn: Selects between the two `input_fn` styles.
      prediction_key: If not None, key used to index into the predictions.
      checkpoint_path: Checkpoint to export; when falsy the latest checkpoint
        in `estimator._model_dir` is used.

    Returns:
      The result of `_export_graph`.

    Raises:
      ValueError: If `input_fn` is None on the non-deprecated path, or if
        both `features` and `examples` end up undefined.
    """
    if use_deprecated_input_fn:
        input_fn = input_fn or _default_input_fn
    elif input_fn is None:
        raise ValueError('input_fn must be defined.')

    # If checkpoint_path is specified, use the specified checkpoint path.
    checkpoint_path = (checkpoint_path or
                       tf_saver.latest_checkpoint(estimator._model_dir))
    with ops.Graph().as_default() as g:
        contrib_variables.create_global_step(g)

        if use_deprecated_input_fn:
            examples = array_ops.placeholder(dtype=dtypes.string,
                                             shape=[default_batch_size],
                                             name='input_example_tensor')
            features = input_fn(estimator, examples)
        else:
            features, _ = input_fn()
            examples = None
            if input_feature_key is not None:
                examples = features.pop(input_feature_key)

        if (not features) and (examples is None):
            raise ValueError('Either features or examples must be defined.')

        predictions = estimator._get_predict_ops(features).predictions

        if prediction_key is not None:
            predictions = predictions[prediction_key]

        # Explicit signature_fn takes priority
        if signature_fn:
            default_signature, named_graph_signatures = signature_fn(
                examples, features, predictions)
        else:
            try:
                # Some estimators provide a signature function.
                # TODO(zakaria): check if the estimator has this function,
                # raise helpful error if not
                signature_fn = estimator._create_signature_fn()

                default_signature, named_graph_signatures = (
                    signature_fn(examples, features, predictions))
            except AttributeError:
                # The trailing space after "after" keeps the adjacent
                # literals from fusing into "after2016-08-01".
                logging.warn(
                    'Change warning: `signature_fn` will be required after '
                    '2016-08-01.\n'
                    'Using generic signatures for now. To maintain this behavior, '
                    'pass:\n'
                    ' signature_fn=export.generic_signature_fn\n'
                    'Also consider passing a regression or classification signature; '
                    'see cl/126430915 for an example.')
                default_signature, named_graph_signatures = generic_signature_fn(
                    examples, features, predictions)

        if exports_to_keep is not None:
            exports_to_keep = gc.largest_export_versions(exports_to_keep)

        return _export_graph(
            g,
            _get_saver(),
            checkpoint_path,
            export_dir,
            default_graph_signature=default_signature,
            named_graph_signatures=named_graph_signatures,
            exports_to_keep=exports_to_keep)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:75,代码来源:export.py
示例15: export_fn
def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None):
    """Export `estimator` as a SavedModel and mirror it into the job dir.

    The names `args`, `features`, `schema`, `stats`, `keep_target` and
    `assets_extra` are not defined in this function; presumably they come
    from an enclosing scope -- verify against the surrounding code.

    Args:
      estimator: Estimator providing `_call_model_fn` and `_model_dir`.
      export_dir_base: Directory under which a timestamped export directory
        is created.
      checkpoint_path: Checkpoint to restore; when falsy the latest
        checkpoint in `estimator._model_dir` is used.
      eval_result: Not used in this function body.

    Returns:
      The timestamped export directory.

    Raises:
      ValueError: If no checkpoint can be found.
    """
    with ops.Graph().as_default() as g:
        contrib_variables.create_global_step(g)

        # Build the serving inputs and run the model function in INFER mode.
        input_ops = feature_transforms.build_csv_serving_tensors_for_training_step(
            args.analysis, features, schema, stats, keep_target)
        model_fn_ops = estimator._call_model_fn(input_ops.features,
                                                None,
                                                model_fn_lib.ModeKeys.INFER)
        output_fetch_tensors = make_prediction_output_tensors(
            args=args,
            features=features,
            input_ops=input_ops,
            model_fn_ops=model_fn_ops,
            keep_target=keep_target)

        # Don't use signature_def_utils.predict_signature_def as that renames
        # tensor names if there is only 1 input/output tensor!
        signature_inputs = {key: tf.saved_model.utils.build_tensor_info(tensor)
                            for key, tensor in six.iteritems(input_ops.default_inputs)}
        signature_outputs = {key: tf.saved_model.utils.build_tensor_info(tensor)
                             for key, tensor in six.iteritems(output_fetch_tensors)}
        signature_def_map = {
            'serving_default':
                signature_def_utils.build_signature_def(
                    signature_inputs,
                    signature_outputs,
                    tf.saved_model.signature_constants.PREDICT_METHOD_NAME)}

        if not checkpoint_path:
            # Locate the latest checkpoint
            checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
        if not checkpoint_path:
            raise ValueError("Couldn't find trained model at %s."
                             % estimator._model_dir)

        export_dir = saved_model_export_utils.get_timestamped_export_dir(
            export_dir_base)

        # Prefer the saver supplied by the model's scaffold, if there is one.
        if (model_fn_ops.scaffold is not None and
                model_fn_ops.scaffold.saver is not None):
            saver_for_restore = model_fn_ops.scaffold.saver
        else:
            saver_for_restore = saver.Saver(sharded=True)

        with tf_session.Session('') as session:
            saver_for_restore.restore(session, checkpoint_path)
            # Grouped initializers stored with the export as its legacy
            # init op.
            init_op = control_flow_ops.group(
                variables.local_variables_initializer(),
                resources.initialize_resources(resources.shared_resources()),
                tf.tables_initializer())

            # Perform the export
            builder = saved_model_builder.SavedModelBuilder(export_dir)
            builder.add_meta_graph_and_variables(
                session, [tag_constants.SERVING],
                signature_def_map=signature_def_map,
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS),
                legacy_init_op=init_op)
            builder.save(False)

        # Add the extra assets
        if assets_extra:
            assets_extra_path = os.path.join(compat.as_bytes(export_dir),
                                             compat.as_bytes('assets.extra'))
            for dest_relative, source in assets_extra.items():
                dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                             compat.as_bytes(dest_relative))
                dest_path = os.path.dirname(dest_absolute)
                file_io.recursive_create_dir(dest_path)
                file_io.copy(source, dest_absolute)

    # only keep the last 3 models
    saved_model_export_utils.garbage_collect_exports(
        export_dir_base,
        exports_to_keep=3)

    # save the last model to the model folder.
    # export_dir_base = A/B/intermediate_models/
    if keep_target:
        final_dir = os.path.join(args.job_dir, 'evaluation_model')
    else:
        final_dir = os.path.join(args.job_dir, 'model')
    # Replace any previous copy so final_dir only holds the newest export.
    if file_io.is_directory(final_dir):
        file_io.delete_recursively(final_dir)
    file_io.recursive_create_dir(final_dir)
    recursive_copy(export_dir, final_dir)

    return export_dir
开发者ID:javiervicho,项目名称:pydatalab,代码行数:90,代码来源:task.py
示例16: export_fn
def export_fn(estimator, export_dir_base, checkpoint_path=None, eval_result=None):
    """Export `estimator` as a SavedModel and mirror it into the job dir.

    The names `train_config`, `args`, `keep_target` and `assets_extra` are
    not defined in this function; presumably they come from an enclosing
    scope -- verify against the surrounding code.

    Args:
      estimator: Estimator providing `_call_model_fn` and `_model_dir`.
      export_dir_base: Directory under which a timestamped export directory
        is created.
      checkpoint_path: Checkpoint to restore; when falsy the latest
        checkpoint in `estimator._model_dir` is used.
      eval_result: Not used in this function body.

    Returns:
      The timestamped export directory.

    Raises:
      NotFittedError: If no checkpoint can be found.
    """
    with ops.Graph().as_default() as g:
        contrib_variables.create_global_step(g)

        # Build the serving inputs and run the model function in INFER mode.
        input_ops = serving_from_csv_input(train_config, args, keep_target)
        model_fn_ops = estimator._call_model_fn(input_ops.features,
                                                None,
                                                model_fn_lib.ModeKeys.INFER)
        output_fetch_tensors = make_output_tensors(
            train_config=train_config,
            args=args,
            input_ops=input_ops,
            model_fn_ops=model_fn_ops,
            keep_target=keep_target)

        signature_def_map = {
            'serving_default': signature_def_utils.predict_signature_def(input_ops.default_inputs,
                                                                         output_fetch_tensors)
        }

        if not checkpoint_path:
            # Locate the latest checkpoint
            checkpoint_path = saver.latest_checkpoint(estimator._model_dir)
        if not checkpoint_path:
            raise NotFittedError("Couldn't find trained model at %s."
                                 % estimator._model_dir)

        export_dir = saved_model_export_utils.get_timestamped_export_dir(
            export_dir_base)

        with tf_session.Session('') as session:
            # variables.initialize_local_variables()
            # NOTE(review): the two initializer ops below are created but
            # neither run nor stored here; they look redundant with the
            # `init_op` group further down -- confirm before removing.
            variables.local_variables_initializer()
            data_flow_ops.tables_initializer()
            saver_for_restore = saver.Saver(
                variables.global_variables(),
                sharded=True)
            saver_for_restore.restore(session, checkpoint_path)

            # Grouped initializers stored with the export as its legacy
            # init op.
            init_op = control_flow_ops.group(
                variables.local_variables_initializer(),
                data_flow_ops.tables_initializer())

            # Perform the export
            builder = saved_model_builder.SavedModelBuilder(export_dir)
            builder.add_meta_graph_and_variables(
                session, [tag_constants.SERVING],
                signature_def_map=signature_def_map,
                assets_collection=ops.get_collection(
                    ops.GraphKeys.ASSET_FILEPATHS),
                legacy_init_op=init_op)
            builder.save(False)

        # Add the extra assets
        if assets_extra:
            assets_extra_path = os.path.join(compat.as_bytes(export_dir),
                                             compat.as_bytes('assets.extra'))
            for dest_relative, source in assets_extra.items():
                dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
                                             compat.as_bytes(dest_relative))
                dest_path = os.path.dirname(dest_absolute)
                gfile.MakeDirs(dest_path)
                gfile.Copy(source, dest_absolute)

    # only keep the last 3 models
    saved_model_export_utils.garbage_collect_exports(
        python_portable_string(export_dir_base),
        exports_to_keep=3)

    # save the last model to the model folder.
    # export_dir_base = A/B/intermediate_models/
    if keep_target:
        final_dir = os.path.join(args.job_dir, 'evaluation_model')
    else:
        final_dir = os.path.join(args.job_dir, 'model')
    # Replace any previous copy so final_dir only holds the newest export.
    if file_io.is_directory(final_dir):
        file_io.delete_recursively(final_dir)
    file_io.recursive_create_dir(final_dir)
    _recursive_copy(export_dir, final_dir)

    return export_dir
开发者ID:parthea,项目名称:pydatalab,代码行数:81,代码来源:util.py
示例17: _export_estimator
def _export_estimator(estimator,
                      export_dir,
                      signature_fn,
                      input_fn,
                      default_batch_size,
                      exports_to_keep,
                      input_feature_key=None,
                      use_deprecated_input_fn=True,
                      prediction_key=None):
    """Build a prediction graph for `estimator` and export it.

    Args:
      estimator: Estimator whose `_get_predict_ops` builds the predictions
        and whose `_model_dir` supplies the latest checkpoint.
      export_dir: Destination handed through to `_export_graph`.
      signature_fn: Callable `(examples, features, predictions)` returning
        `(default_signature, named_graph_signatures)`. When falsy,
        `estimator._create_signature_fn()` is tried; on `AttributeError`
        `generic_signature_fn` is used instead.
      input_fn: With `use_deprecated_input_fn`, a callable
        `(estimator, examples)` returning features (falsy values fall back
        to `_default_input_fn`). Otherwise a required zero-argument callable
        returning a `(features, _)` pair.
      default_batch_size: Length of the string `examples` placeholder
        (deprecated input path only).
      exports_to_keep: If not None, converted with
        `gc.largest_export_versions` before being passed to `_export_graph`.
      input_feature_key: On the non-deprecated path, key popped from
        `features` and used as `examples`.
      use_deprecated_input_fn: Selects between the two `input_fn` styles.
      prediction_key: If not None, key used to index into the predictions.

    Returns:
      The result of `_export_graph`.

    Raises:
      ValueError: If `input_fn` is None on the non-deprecated path, or if
        both `features` and `examples` end up undefined.
    """
    if use_deprecated_input_fn:
        input_fn = input_fn or _default_input_fn
    elif input_fn is None:
        raise ValueError('input_fn must be defined.')

    checkpoint_path = tf_saver.latest_checkpoint(estimator._model_dir)
    with ops.Graph().as_default() as g:
        contrib_variables.create_global_step(g)

        if use_deprecated_input_fn:
            examples = array_ops.placeholder(dtype=dtypes.string,
                                             shape=[default_batch_size],
                                             name='input_example_tensor')
            features = input_fn(estimator, examples)
        else:
            features, _ = input_fn()
            examples = None
            if input_feature_key is not None:
                examples = features.pop(input_feature_key)

        if (not features) and (examples is None):
            raise ValueError('Either features or examples must be defined.')

        # The default return type of _get_predict_ops is ModelFnOps. But there
        # are some subclasses of tf.contrib.learn.Estimator which override
        # this method and use the legacy signature, namely _get_predict_ops
        # returns a `predictions` Tensor or dict of Tensors. The following
        # else-statement code covers these cases, but will soon be deleted
        # after the subclasses are updated.
        # TODO(b/32664904): Update subclasses and delete the else-statement.
        infer_ops = estimator._get_predict_ops(features)
        if isinstance(infer_ops, model_fn.ModelFnOps):  # Default signature
            predictions = infer_ops.predictions
        else:  # Legacy signature
            predictions = infer_ops

        if prediction_key is not None:
            predictions = predictions[prediction_key]

        # Explicit signature_fn takes priority
        if signature_fn:
            default_signature, named_graph_signatures = signature_fn(
                examples, features, predictions)
        else:
            try:
                # Some estimators provide a signature function.
                # TODO(zakaria): check if the estimator has this function,
                # raise helpful error if not
                signature_fn = estimator._create_signature_fn()

                default_signature, named_graph_signatures = (
                    signature_fn(examples, features, predictions))
            except AttributeError:
                # The trailing space after "after" keeps the adjacent
                # literals from fusing into "after2016-08-01".
                logging.warn(
                    'Change warning: `signature_fn` will be required after '
                    '2016-08-01.\n'
                    'Using generic signatures for now. To maintain this behavior, '
                    'pass:\n'
                    ' signature_fn=export.generic_signature_fn\n'
                    'Also consider passing a regression or classification signature; '
                    'see cl/126430915 for an example.')
                default_signature, named_graph_signatures = generic_signature_fn(
                    examples, features, predictions)

        if exports_to_keep is not None:
            exports_to_keep = gc.largest_export_versions(exports_to_keep)

        return _export_graph(
            g,
            _get_saver(),
            checkpoint_path,
            export_dir,
            default_graph_signature=default_signature,
            named_graph_signatures=named_graph_signatures,
            exports_to_keep=exports_to_keep)
开发者ID:Y-owen,项目名称:tensorflow,代码行数:83,代码来源:export.py
注:本文中的tensorflow.contrib.framework.python.ops.variables.create_global_step函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论