本文整理汇总了Python中tensorflow.python.training.saver.latest_checkpoint函数的典型用法代码示例。如果您正苦于以下问题:Python latest_checkpoint函数的具体用法?Python latest_checkpoint怎么用?Python latest_checkpoint使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了latest_checkpoint函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testRecoverSession
def testRecoverSession(self):
  """Checks session recovery from a checkpoint dir and from a checkpoint path.

  A checkpoint holding `v == 1` is written into a fresh directory, then
  `self._test_recovered_variable` is called once with `checkpoint_dir` and
  once with `checkpoint_filename_with_path`. Passing both arguments together
  is expected to raise `ValueError`.
  """
  # Create a checkpoint.
  checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
  try:
    gfile.DeleteRecursively(checkpoint_dir)
  except errors.OpError:
    # Best-effort cleanup; presumably the directory did not exist yet.
    pass
  gfile.MakeDirs(checkpoint_dir)

  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    sm = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables())
    saver = saver_lib.Saver({"v": v})
    sess, initialized = sm.recover_session(
        "", saver=saver, checkpoint_dir=checkpoint_dir)
    # The checkpoint directory was just recreated, so nothing is restored.
    self.assertFalse(initialized)
    sess.run(v.initializer)
    # assertEqual replaces the deprecated assertEquals alias, which was
    # removed from unittest in Python 3.12.
    self.assertEqual(1, sess.run(v))
    saver.save(sess,
               os.path.join(checkpoint_dir, "recover_session_checkpoint"))
  self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
  self._test_recovered_variable(
      checkpoint_filename_with_path=saver_lib.latest_checkpoint(
          checkpoint_dir))
  # Cannot set both checkpoint_dir and checkpoint_filename_with_path.
  with self.assertRaises(ValueError):
    self._test_recovered_variable(
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=saver_lib.latest_checkpoint(
            checkpoint_dir))
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:31,代码来源:session_manager_test.py
示例2: _read_config_files
def _read_config_files(self, run_paths):
  """Loads the projector config of every run that has a usable checkpoint.

  Args:
    run_paths: Mapping of run name to that run's log directory.

  Returns:
    A tuple `(configs, config_fpaths)` of dicts keyed by run name: the parsed
    `ProjectorConfig` and the path of the config file it was read from. Runs
    without a config file, or without a checkpoint that exists on disk, are
    left out.
  """
  configs = {}
  config_fpaths = {}
  for run_name, logdir in run_paths.items():
    config_fpath = os.path.join(logdir, PROJECTOR_FILENAME)
    if not file_io.file_exists(config_fpath):
      # Skip runs that have no config file.
      continue
    # Read the config file.
    file_content = file_io.read_file_to_string(config_fpath).decode('utf-8')
    config = ProjectorConfig()
    text_format.Merge(file_content, config)

    if not config.model_checkpoint_path:
      # See if you can find a checkpoint file in the logdir.
      ckpt_path = latest_checkpoint(logdir)
      if not ckpt_path:
        # Or in the parent of logdir. The parent is logdir + os.pardir;
        # joining '../' in front of logdir would resolve relative to the
        # working directory and is a no-op for an absolute logdir.
        ckpt_path = latest_checkpoint(os.path.join(logdir, os.pardir))
        if not ckpt_path:
          logging.warning('Cannot find model checkpoint in %s', logdir)
          continue
      config.model_checkpoint_path = ckpt_path

    # Sanity check for the checkpoint file.
    if not file_io.file_exists(config.model_checkpoint_path):
      logging.warning('Checkpoint file %s not found',
                      config.model_checkpoint_path)
      continue
    configs[run_name] = config
    config_fpaths[run_name] = config_fpath
  return configs, config_fpaths
开发者ID:KalraA,项目名称:tensorflow,代码行数:32,代码来源:plugin.py
示例3: _find_latest_checkpoint
def _find_latest_checkpoint(dir_path):
  """Returns the latest checkpoint in `dir_path`, else in its parent.

  Args:
    dir_path: Directory to search first.

  Returns:
    Whatever `latest_checkpoint` reports for `dir_path` when that is truthy,
    otherwise its result for the parent directory. `None` is returned when
    the lookup raises `errors.NotFoundError`.
  """
  try:
    found = latest_checkpoint(dir_path)
    if found:
      return found
    # Nothing in dir_path itself: fall back to the parent directory.
    parent_dir = os.path.join(dir_path, os.pardir)
    return latest_checkpoint(parent_dir)
  except errors.NotFoundError:
    return None
开发者ID:chenjun0210,项目名称:tensorflow,代码行数:9,代码来源:projector_plugin.py
示例4: testMultiEvalStepIncrements
def testMultiEvalStepIncrements(self):
  """Checks the final value when the eval ops also advance the eval step."""
  ckpt_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

  # One training step is enough to leave a checkpoint behind.
  self._train_model(ckpt_dir, num_steps=1)
  ckpt_path = saver.latest_checkpoint(ckpt_dir)

  # Build the model so there is something to restore into.
  model_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
  logistic_classifier(model_inputs)

  num_evals = 6

  counter = local_variable(0.0, name='MyVar')
  # Besides incrementing the counter, the eval ops bump the eval step one
  # extra time, which is why only num_evals // 2 updates are expected.
  bump_counter = state_ops.assign_add(counter, 1.0)
  bump_eval_step = state_ops.assign_add(
      evaluation._get_or_create_eval_step(), 1, use_locking=True)
  expected_updates = num_evals // 2

  final_value = array_ops.identity(counter)

  results = evaluation._evaluate_once(
      checkpoint_path=ckpt_path,
      eval_ops=[bump_counter, bump_eval_step],
      final_ops={'value': final_value},
      hooks=[evaluation._StopAfterNEvalsHook(num_evals)])
  self.assertEqual(results['value'], expected_updates)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:28,代码来源:evaluation_test.py
示例5: testEvalOpAndFinalOp
def testEvalOpAndFinalOp(self):
  """Checks the final op value and that the hooks list is left unchanged."""
  ckpt_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

  # One training step is enough to leave a checkpoint behind.
  self._train_model(ckpt_dir, num_steps=1)
  ckpt_path = saver.latest_checkpoint(ckpt_dir)

  # Build the model so there is something to restore into.
  model_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
  logistic_classifier(model_inputs)

  num_evals = 5
  final_increment = 9.0

  counter = local_variable(0.0, name='MyVar')
  increment_counter = state_ops.assign_add(counter, 1.0)
  final_value = array_ops.identity(counter) + final_increment

  hooks = [evaluation._StopAfterNEvalsHook(num_evals)]
  # Snapshot taken before the call, to compare against the list afterwards.
  hooks_before = list(hooks)

  results = evaluation._evaluate_once(
      checkpoint_path=ckpt_path,
      eval_ops=increment_counter,
      final_ops={'value': final_value},
      hooks=hooks)
  self.assertEqual(results['value'], num_evals + final_increment)
  self.assertEqual(hooks_before, hooks)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:27,代码来源:evaluation_test.py
示例6: testEvaluateWithFiniteInputs
def testEvaluateWithFiniteInputs(self):
  """Evaluates a trained model on a single epoch of batched inputs."""
  ckpt_dir = os.path.join(self.get_temp_dir(), 'evaluate_with_finite_inputs')

  # Train a model to completion.
  self._train_model(ckpt_dir, num_steps=300)

  # Feed the evaluation data through an input producer for exactly one epoch.
  feature_tensor = constant_op.constant(self._inputs, dtype=dtypes.float32)
  label_tensor = constant_op.constant(self._labels, dtype=dtypes.float32)
  one_input, one_label = training.slice_input_producer(
      [feature_tensor, label_tensor], num_epochs=1)
  batch_inputs, batch_labels = training.batch(
      [one_input, one_label],
      batch_size=6,
      allow_smaller_final_batch=True)

  logits = logistic_classifier(batch_inputs)
  rounded = math_ops.round(logits)
  accuracy, update_op = metrics.accuracy(
      predictions=rounded, labels=batch_labels)

  ckpt_path = saver.latest_checkpoint(ckpt_dir)

  requested = {
      'accuracy': accuracy,
      'eval_steps': evaluation._get_or_create_eval_step(),
  }
  results = evaluation._evaluate_once(
      checkpoint_path=ckpt_path,
      eval_ops=update_op,
      final_ops=requested,
      hooks=[evaluation._StopAfterNEvalsHook(None)])
  self.assertTrue(results['accuracy'] > .99)
  # Four iterations are expected: two full batches of 6 inputs, a third with
  # the remaining 4 inputs, and a last one that triggers the error which
  # stops evaluation.
  self.assertEqual(results['eval_steps'], 4)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:35,代码来源:evaluation_test.py
示例7: test_recovery
def test_recovery(self):
  """Checks that a restarted MonitoredSession resumes from the checkpoint.

  The global step is advanced to 2 in a first session that has a
  CheckpointSaverHook attached. Two further sessions are then opened, one
  given `checkpoint_dir` and one given `checkpoint_filename_with_path`, and
  each must read a global step of 2.
  """
  logdir = _test_dir(self.get_temp_dir(), 'test_recovery')
  with ops.Graph().as_default():
    gstep = variables_lib.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    scaffold = monitored_session.Scaffold()
    # Use a hook configured with save_steps=1, i.e. to save the model every
    # step. NOTE(review): the original comment also said it saves at the
    # end; that is not visible from this test.
    hooks = [
        basic_session_run_hooks.CheckpointSaverHook(
            logdir, save_steps=1, scaffold=scaffold)
    ]
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold, checkpoint_dir=logdir),
        hooks=hooks) as session:
      self.assertEqual(0, session.run(gstep))
      self.assertEqual(1, session.run(do_step))
      self.assertEqual(2, session.run(do_step))
    # A restart given the checkpoint directory will find the checkpoint and
    # recover automatically.
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold, checkpoint_dir=logdir)) as session:
      self.assertEqual(2, session.run(gstep))
    # A restart given the explicit path of the latest checkpoint recovers the
    # same state.
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold,
            checkpoint_filename_with_path=saver_lib.latest_checkpoint(
                logdir))) as session:
      self.assertEqual(2, session.run(gstep))
开发者ID:kadeng,项目名称:tensorflow,代码行数:31,代码来源:monitored_session_test.py
示例8: export_estimator
def export_estimator(estimator, export_dir, input_fn=_default_input_fn,
                     signature_fn=_generic_signature_fn, default_batch_size=1,
                     exports_to_keep=None):
  """Exports the estimator's inference graph into the given directory.

  Args:
    estimator: Estimator to export.
    export_dir: A string containing a directory to write the exported graph
      and checkpoints.
    input_fn: Function that, given a `Tensor` of `Example` strings, parses it
      into features that are then passed to the model.
    signature_fn: Function that, given a `Tensor` of `Example` strings, a
      `dict` of `Tensor`s for features and a `dict` of `Tensor`s for
      predictions, returns the default and named exporting signatures.
    default_batch_size: Default batch size of the `Example` placeholder.
    exports_to_keep: Number of exports to keep. `None` passes no retention
      filter to the export.
  """
  latest_ckpt = tf_saver.latest_checkpoint(estimator._model_dir)
  with ops.Graph().as_default() as graph:
    contrib_variables.create_global_step(graph)
    example_placeholder = array_ops.placeholder(
        dtype=dtypes.string,
        shape=[default_batch_size],
        name='input_example_tensor')
    features = input_fn(estimator, example_placeholder)
    predictions = estimator._get_predict_ops(features)
    default_signature, named_signatures = signature_fn(
        example_placeholder, features, predictions)

    # Turn the requested count into the filter object handed to the export.
    keep_filter = exports_to_keep
    if keep_filter is not None:
      keep_filter = gc.largest_export_versions(keep_filter)

    _export_graph(
        graph, _get_saver(), latest_ckpt, export_dir,
        default_graph_signature=default_signature,
        named_graph_signatures=named_signatures,
        exports_to_keep=keep_filter)
开发者ID:363158858,项目名称:tensorflow,代码行数:33,代码来源:export.py
示例9: _restore_or_save_initial_ckpt
def _restore_or_save_initial_ckpt(self, session):
  """Restores the latest checkpoint if one exists, otherwise saves one.

  Args:
    session: Session used to restore into, or to read the global step from
      and save with.
  """
  # Why this does not run in after_create_session:
  # There is currently no way to enforce an order between `SessionRunHooks`,
  # so the `_DatasetInitializerHook` could run *after* this hook. Then
  #   1. if a checkpoint exists and this hook restores it, the initializer
  #      hook would override the restored state;
  #   2. if no checkpoint exists, this hook would try to save the iterator
  #      before the initializer hook has run, which results in an exception.
  #      NOTE(review): the original comment said "initialized iterator";
  #      "uninitialized" appears to be what was meant -- confirm.
  #
  # As a temporary fix there is an implicit contract between the two hooks:
  #   1. The _DatasetInitializerHook initializes the iterator in its
  #      after_create_session.
  #   2. This hook saves the iterator on the first call to `before_run()`,
  #      which is guaranteed to happen after `after_create_session()` of all
  #      hooks has run.
  # pylint: disable=protected-access
  saver_hook = self._checkpoint_saver_hook
  existing_ckpt = saver_lib.latest_checkpoint(
      saver_hook._checkpoint_dir,
      latest_filename=self._latest_filename)

  # An existing checkpoint takes precedence: restore from it and stop.
  if existing_ckpt:
    saver_hook._get_saver().restore(session, existing_ckpt)
    return

  # The checkpoint saved here is the state at step "global_step".
  # Note: the GraphDef or MetaGraphDef is not saved here.
  current_step = session.run(saver_hook._global_step_tensor)
  saver_hook._save(session, current_step)
  saver_hook._timer.update_last_triggered_step(current_step)
开发者ID:BhaskarNallani,项目名称:tensorflow,代码行数:33,代码来源:iterator_ops.py
示例10: _evaluate_model
def _evaluate_model(self,
                    input_fn,
                    steps,
                    feed_fn=None,
                    metrics=None,
                    name=''):
  """Evaluates the latest checkpoint in the model directory.

  Args:
    input_fn: Callable returning a `(features, targets)` pair.
    steps: Passed to `evaluate` as `max_steps`.
    feed_fn: Optional feed function forwarded to `evaluate`.
    metrics: Metrics passed to `_get_eval_ops`; when `None`, the result of
      `_get_default_metric_functions()` is used instead.
    name: Optional suffix for the output directory: results go to `eval`
      when empty, otherwise to `eval_<name>`, under the model directory.

  Returns:
    The evaluation results, or `None` when the configured execution mode is
    not one of 'all', 'evaluate' or 'eval_evalset'.
  """
  if self._config.execution_mode not in ('all', 'evaluate', 'eval_evalset'):
    return

  latest_ckpt = saver.latest_checkpoint(self._model_dir)
  subdir = 'eval_' + name if name else 'eval'
  output_dir = os.path.join(self._model_dir, subdir)

  with ops.Graph().as_default() as graph:
    random_seed.set_random_seed(self._config.tf_random_seed)
    step_tensor = contrib_framework.create_global_step(graph)
    features, targets = input_fn()
    self._check_inputs(features, targets)

    if metrics is None:
      metric_fns = self._get_default_metric_functions()
    else:
      metric_fns = metrics
    eval_ops = self._get_eval_ops(features, targets, metric_fns)
    update_op, eval_ops = self._extract_metric_update_ops(eval_ops)

    results, _ = evaluate(
        graph=graph,
        output_dir=output_dir,
        checkpoint_path=latest_ckpt,
        eval_dict=eval_ops,
        update_op=update_op,
        global_step_tensor=step_tensor,
        supervisor_master=self._config.master,
        feed_fn=feed_fn,
        max_steps=steps)
    return results
开发者ID:01-,项目名称:tensorflow,代码行数:31,代码来源:estimator.py
示例11: _infer_model
def _infer_model(
    self, input_fn, feed_fn=None, outputs=None, as_iterable=False):
  """Runs prediction with the latest checkpoint in the model directory.

  Args:
    input_fn: Input function handed to `_get_features_from_input_fn`.
    feed_fn: Optional feed function forwarded to the inference helper.
    outputs: Optional collection of prediction keys to keep; when falsy all
      predictions are run.
    as_iterable: Selects `_infer_model_as_iterable` over
      `_infer_model_single`.

  Returns:
    The result of the selected inference helper.

  Raises:
    NotFittedError: If the model directory holds no checkpoint.
    ValueError: If filtering by `outputs` leaves no prediction to run.
  """
  # Check that the model has been trained.
  latest_ckpt = saver.latest_checkpoint(self._model_dir)
  if not latest_ckpt:
    raise NotFittedError("Couldn't find trained model at %s."
                         % self._model_dir)

  with ops.Graph().as_default() as graph:
    random_seed.set_random_seed(self._config.tf_random_seed)
    contrib_framework.create_global_step(graph)
    features = self._get_features_from_input_fn(input_fn)
    predictions = self._get_predict_ops(features)

    # A single (non-dict) output is wrapped in a dict for uniform handling;
    # return_dict records that the result must not be returned as a dict.
    return_dict = isinstance(predictions, dict)
    if not return_dict:
      predictions = {'predictions': predictions}

    # Restrict the predictions to the requested outputs, if any were given.
    if outputs:
      existing_keys = predictions.keys()
      selected = {}
      for key, value in predictions.items():
        if key in outputs:
          selected[key] = value
      predictions = selected
      if not predictions:
        raise ValueError('Expected to run at least one output from %s, '
                         'provided %s.' % (existing_keys, outputs))

    if as_iterable:
      run_inference = self._infer_model_as_iterable
    else:
      run_inference = self._infer_model_single
    return run_inference(latest_ckpt, predictions, feed_fn, return_dict)
开发者ID:Nishant23,项目名称:tensorflow,代码行数:35,代码来源:estimator.py
示例12: _infer_model
def _infer_model(self,
                 x=None, input_fn=None, feed_fn=None,
                 batch_size=None, axis=None, proba=False):
  """Runs prediction with the latest checkpoint in the model directory.

  Args:
    x: Optional input data; when given, `input_fn` and `feed_fn` are derived
      from it with `_get_predict_input_fn`.
    input_fn: Callable returning a `(features, targets)` pair.
    feed_fn: Optional callable producing feed dicts; iteration stops when it
      raises `StopIteration` or returns `None`.
    batch_size: Batch size used with `x`; `None` is replaced by -1.
    axis: Not used by this method.
    proba: Not used by this method.

  Returns:
    Without `feed_fn`, the direct result of `infer`. Otherwise a dict mapping
    each output key to the per-feed results concatenated along axis 0.
  """
  # Converts inputs into tf.DataFrame / tf.Series.
  if batch_size is None:
    batch_size = -1
  if x is not None:
    input_fn, feed_fn = _get_predict_input_fn(x, batch_size)

  latest_ckpt = saver.latest_checkpoint(self._model_dir)

  with ops.Graph().as_default() as graph:
    random_seed.set_random_seed(self._config.tf_random_seed)
    contrib_framework.create_global_step(graph)
    features, _ = input_fn()
    predictions = self._get_predict_ops(features)
    if not isinstance(predictions, dict):
      predictions = {'predictions': predictions}

    # TODO(ipolosukhin): Support batching
    if feed_fn is None:
      return infer(latest_ckpt, predictions)

    # Collect one result per feed dict, grouped by output key.
    collected = {}
    while True:
      try:
        feed_dict = feed_fn()
      except StopIteration:
        break
      if feed_dict is None:
        break
      batch_outputs = infer(latest_ckpt, predictions, feed_dict=feed_dict)
      for key in batch_outputs:
        collected.setdefault(key, []).append(batch_outputs[key])

    for key in collected:
      collected[key] = np.concatenate(collected[key], axis=0)
    return collected
开发者ID:01-,项目名称:tensorflow,代码行数:35,代码来源:estimator.py
示例13: testUsageGraph
def testUsageGraph(self):
  """Expected usage when graph building.

  Runs three training continuations of ten steps each. Every continuation
  builds a fresh graph, restores from the latest checkpoint in the temp
  directory (if any), trains, saves, and checks the global step and the save
  counter.
  """
  with context.graph_mode():
    num_training_steps = 10
    checkpoint_directory = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
    for training_continuation in range(3):
      with ops.Graph().as_default():
        network = MyNetwork()
        optimizer = adam.AdamOptimizer(0.001)
        root = checkpointable_utils.Checkpoint(
            optimizer=optimizer, network=network,
            global_step=training_util.get_or_create_global_step())
        input_value = constant_op.constant([[3.]])
        train_op = optimizer.minimize(
            network(input_value),
            global_step=root.global_step)
        checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
        with self.test_session(graph=ops.get_default_graph()) as session:
          status = root.restore(save_path=checkpoint_path)
          status.initialize_or_restore(session=session)
          if checkpoint_path is None:
            # No checkpoint is only expected on the very first continuation;
            # with nothing restored, assert_consumed() must fail.
            self.assertEqual(0, training_continuation)
            with self.assertRaises(AssertionError):
              status.assert_consumed()
          else:
            status.assert_consumed()
          for _ in range(num_training_steps):
            session.run(train_op)
          root.save(file_prefix=checkpoint_prefix, session=session)
          # The step count accumulates across continuations because each one
          # resumes from the previous checkpoint.
          self.assertEqual((training_continuation + 1) * num_training_steps,
                           session.run(root.global_step))
          # Exactly one save happens per continuation.
          self.assertEqual(training_continuation + 1,
                           session.run(root.save_counter))
开发者ID:neuroradiology,项目名称:tensorflow,代码行数:34,代码来源:checkpointable_utils_test.py
示例14: create_session
def create_session(self, checkpoint_dir):
  """Creates a MonitoredSession for this predictor.

  Args:
    checkpoint_dir: Directory whose latest checkpoint is handed to the
      session creator.

  Returns:
    A `training.MonitoredSession` built from a `ChiefSessionCreator` that
    uses this predictor's session config.
  """
  latest_ckpt = saver.latest_checkpoint(checkpoint_dir)
  creator = training.ChiefSessionCreator(
      checkpoint_filename_with_path=latest_ckpt,
      config=self._session_config())
  return training.MonitoredSession(session_creator=creator)
开发者ID:ucam-smt,项目名称:sgnmt,代码行数:7,代码来源:tf_nizza.py
示例15: testDeferredRestorationUsageEager
def testDeferredRestorationUsageEager(self):
  """An idiomatic eager execution example.

  Runs three training continuations of ten steps each. Every continuation
  restores from the latest checkpoint in the temp directory together with the
  object graph returned by the previous save, trains, saves again, and checks
  the accumulated global step.
  """
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  latest_object_graph = None  # Will be saved with the checkpoint eventually.
  for training_continuation in range(3):
    with ops.Graph().as_default():
      network = MyNetwork()
      optimizer = CheckpointableAdam(0.001)
      root = Root(optimizer=optimizer, network=network)
      # On the first continuation both the checkpoint path and the object
      # graph are still None.
      checkpointable.restore(
          save_path=core_saver.latest_checkpoint(checkpoint_directory),
          root_checkpointable=root,
          object_graph_proto=latest_object_graph)
      for _ in range(num_training_steps):
        # TODO(allenl): Use a Dataset and serialize/checkpoint it.
        input_value = constant_op.constant([[3.]])
        optimizer.minimize(
            lambda: network(input_value),  # pylint: disable=cell-var-from-loop
            global_step=root.global_step)
      # Keep the object graph from this save for the next restore.
      latest_object_graph, _ = checkpointable.save(
          file_prefix=checkpoint_prefix,
          root_checkpointable=root)
      # The step count accumulates across continuations.
      self.assertEqual((training_continuation + 1) * num_training_steps,
                       root.global_step.numpy())
开发者ID:japrogramer,项目名称:tensorflow,代码行数:26,代码来源:checkpointable_test.py
示例16: predict
def predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None):
  """Yields predictions for the features produced by `input_fn`.

  This is a generator, so nothing below runs (including the checkpoint
  check) until the first item is requested.

  Args:
    input_fn: Input function returning features which is a dictionary of
      string feature name to `Tensor` or `SparseTensor`. If it returns a
      tuple, first item is extracted as features. Prediction continues until
      `input_fn` raises an end-of-input exception (`OutOfRangeError` or
      `StopIteration`).
    predict_keys: list of `str`, name of the keys to predict. It is used if
      the `EstimatorSpec.predictions` is a `dict`. If `predict_keys` is used
      then rest of the predictions will be filtered from the dictionary. If
      `None`, returns all.
    hooks: List of `SessionRunHook` subclass instances. Used for callbacks
      inside the prediction call.
    checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
      latest checkpoint in `model_dir` is used.

  Yields:
    Evaluated values of `predictions` tensors, one example at a time.

  Raises:
    ValueError: Could not find a trained model in model_dir.
    ValueError: if batch length of predictions are not same.
    ValueError: If there is a conflict between `predict_keys` and
      `predictions`. For example if `predict_keys` is not `None` but
      `EstimatorSpec.predictions` is not a `dict`.
  """
  hooks = _check_hooks_type(hooks)

  # Fall back to the newest checkpoint in model_dir; the model must have
  # been trained for either source to yield a path.
  if not checkpoint_path:
    checkpoint_path = saver.latest_checkpoint(self._model_dir)
  if not checkpoint_path:
    raise ValueError('Could not find trained model in model_dir: {}.'.format(
        self._model_dir))

  with ops.Graph().as_default() as graph:
    random_seed.set_random_seed(self._config.tf_random_seed)
    training.create_global_step(graph)
    features = self._get_features_from_input_fn(input_fn)
    spec = self._call_model_fn(features, None,
                               model_fn_lib.ModeKeys.PREDICT)
    predictions = self._extract_keys(spec.predictions, predict_keys)

    creator = training.ChiefSessionCreator(
        checkpoint_filename_with_path=checkpoint_path,
        scaffold=spec.scaffold,
        config=config_pb2.ConfigProto(allow_soft_placement=True))
    with training.MonitoredSession(
        session_creator=creator, hooks=hooks) as mon_sess:
      while not mon_sess.should_stop():
        batch = mon_sess.run(predictions)
        if isinstance(predictions, dict):
          # Split the batched dict into one dict per example.
          for index in range(self._extract_batch_length(batch)):
            yield {
                name: values[index]
                for name, values in six.iteritems(batch)
            }
        else:
          for single_prediction in batch:
            yield single_prediction
开发者ID:LugarkPirog,项目名称:tensorflow,代码行数:60,代码来源:estimator.py
示例17: testAgnosticUsage
def testAgnosticUsage(self):
  """Graph/eager agnostic usage.

  Runs three training continuations of ten steps each, restoring from the
  latest checkpoint in the temp directory at the start of each and saving at
  the end, then checks the global step and the save counter.
  """
  # Does create garbage when executing eagerly due to ops.Graph() creation.
  num_training_steps = 10
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  for training_continuation in range(3):
    with ops.Graph().as_default(), self.test_session(
        graph=ops.get_default_graph()):
      network = MyNetwork()
      optimizer = adam.AdamOptimizer(0.001)
      root = checkpointable_utils.Checkpoint(
          optimizer=optimizer, network=network,
          global_step=training_util.get_or_create_global_step())
      checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
      status = root.restore(save_path=checkpoint_path)
      input_value = constant_op.constant([[3.]])
      train_fn = functools.partial(
          optimizer.minimize,
          functools.partial(network, input_value),
          global_step=root.global_step)
      if context.in_graph_mode():
        # In graph mode, call train_fn() once and wrap its result so that
        # each later train_fn() call evaluates that same result.
        train_fn = functools.partial(self.evaluate, train_fn())
      status.initialize_or_restore()
      for _ in range(num_training_steps):
        train_fn()
      root.save(file_prefix=checkpoint_prefix)
      # The step count accumulates across continuations.
      self.assertEqual((training_continuation + 1) * num_training_steps,
                       self.evaluate(root.global_step))
      # Exactly one save happens per continuation.
      self.assertEqual(training_continuation + 1,
                       self.evaluate(root.save_counter))
开发者ID:neuroradiology,项目名称:tensorflow,代码行数:31,代码来源:checkpointable_utils_test.py
示例18: create_session
def create_session(checkpoint_path, n_cpu_threads=-1):
  """Creates a MonitoredSession.

  Args:
    checkpoint_path (string): Path either to checkpoint directory or
                              directly to a checkpoint file.
    n_cpu_threads (int): Number of CPU threads. If negative, we
                         assume either GPU decoding or that all
                         CPU cores can be used.

  Returns:
    A TensorFlow MonitoredSession.

  Raises:
    AttributeError: If TensorFlow reports a `NotFoundError` while the
      session is being created from the checkpoint.
  """
  try:
    if os.path.isdir(checkpoint_path):
      checkpoint_dir = checkpoint_path
      checkpoint_path = saver.latest_checkpoint(checkpoint_dir)
      if not checkpoint_path:
        # Only warn: the session is still created, as before, but nothing
        # will be restored into it.
        logging.warning("No checkpoint found in directory %s",
                        checkpoint_dir)
    else:
      # Arguments are passed separately so formatting happens only if the
      # message is actually emitted.
      logging.info("%s is not a directory. Interpreting as direct "
                   "path to checkpoint...", checkpoint_path)
    return training.MonitoredSession(
        session_creator=training.ChiefSessionCreator(
            checkpoint_filename_with_path=checkpoint_path,
            config=session_config(n_cpu_threads)))
  except tf.errors.NotFoundError as e:
    # Record the underlying error first; the AttributeError raised below
    # does not carry it.
    logging.error("TensorFlow error while creating the session: %s", e)
    logging.fatal("Could not find all variables of the computation "
                  "graph in the T2T checkpoint file. This means that the "
                  "checkpoint does not correspond to the model specified in "
                  "SGNMT. Please double-check pred_src_vocab_size, "
                  "pred_trg_vocab_size, and all the t2t_* parameters. "
                  "Also make sure that the checkpoint exists and is readable")
    raise AttributeError("Could not initialize TF session.")
开发者ID:ucam-smt,项目名称:sgnmt,代码行数:30,代码来源:tf_utils.py
示例19: every_n_step_end
def every_n_step_end(self, step, outputs):
  """Evaluates the newest checkpoint and applies early stopping.

  Evaluation is skipped when the check interval has not elapsed, when no
  checkpoint exists yet, or when the newest checkpoint was already
  evaluated.

  Args:
    step: Current step, forwarded to the parent class and used for logging
      and early-stopping bookkeeping.
    outputs: Forwarded to the parent class.

  Returns:
    `True` when early stopping triggers, `False` otherwise.

  Raises:
    ValueError: If no estimator was set, or if the early-stopping metric is
      missing from the validation outputs.
  """
  super(ValidationMonitor, self).every_n_step_end(step, outputs)
  # TODO(mdan): The use of step below is probably misleading.
  # The code should probably use the step from the checkpoint, because
  # that's what is being evaluated.
  if self._estimator is None:
    raise ValueError("Missing call to set_estimator.")

  # Rate-limit how often the checkpoint directory is looked at.
  now = time.time()
  if self._check_interval_secs is not None:
    last_check = self._last_checkpoint_check_time
    if (last_check is not None and
        now - last_check <= self._check_interval_secs):
      logging.debug(
          "Skipping evaluation since less than %d seconds have passed since "
          "last check for a new checkpoint.", self._check_interval_secs)
      return False
  self._last_checkpoint_check_time = now

  # Make sure the same checkpoint is not evaluated twice.
  newest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
  if newest_path is None:
    logging.debug("Skipping evaluation since model has not been saved yet "
                  "at step %d.", step)
    return False
  if newest_path == self._latest_path:
    logging.debug("Skipping evaluation due to same checkpoint %s for step %d "
                  "as for step %d.", newest_path, step,
                  self._latest_path_step)
    return False
  self._latest_path = newest_path
  self._latest_path_step = step

  # Run evaluation and log it.
  validation_outputs = self._evaluate_estimator()
  stats = ["%s = %s" % (name, str(validation_outputs[name]))
           for name in validation_outputs]
  logging.info("Validation (step %d): %s", step, ", ".join(stats))

  # Early stopping logic.
  if self.early_stopping_rounds is None:
    return False
  if self.early_stopping_metric not in validation_outputs:
    raise ValueError("Metric %s missing from outputs %s." %
                     (self.early_stopping_metric,
                      set(validation_outputs.keys())))
  current_value = validation_outputs[self.early_stopping_metric]

  # The first value is always the best so far; afterwards the direction of
  # "better" depends on whether the metric is minimized.
  if self._best_value is None:
    improved = True
  elif self.early_stopping_metric_minimize:
    improved = current_value < self._best_value
  else:
    improved = current_value > self._best_value
  if improved:
    self._best_value = current_value
    self._best_metrics = copy.deepcopy(validation_outputs)
    self._best_value_step = step

  if step - self._best_value_step >= self.early_stopping_rounds:
    logging.info("Stopping. Best step: {} with {} = {}.".format(
        self._best_value_step, self.early_stopping_metric,
        self._best_value))
    self._early_stopped = True
    return True
  return False
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:60,代码来源:monitors.py
示例20: export_fn
def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None):
  """Exports the given Estimator as a SavedModel.

  Args:
    estimator: the Estimator to export.
    export_dir_base: A string containing a directory to write the exported
      graph and checkpoints.
    checkpoint_path: The checkpoint path to export. If falsy (e.g. `None`),
      the most recent checkpoint found within the model directory is chosen.
    eval_result: placeholder arg matching the call signature of
      ExportStrategy; it is passed on to the best-model selector.

  Returns:
    The string path to the exported directory, or an empty string when the
    selector yields no checkpoint or no eval result to export.
  """
  if not checkpoint_path:
    # TODO(b/67425018): switch to
    #    checkpoint_path = estimator.latest_checkpoint()
    #  as soon as contrib is cleaned up and we can thus be sure that
    #  estimator is a tf.estimator.Estimator and not a
    #  tf.contrib.learn.Estimator
    checkpoint_path = saver.latest_checkpoint(estimator.model_dir)

  selected_ckpt, selected_eval = best_model_selector.update(
      checkpoint_path, eval_result)

  # Nothing to export unless the selector returned both pieces.
  if not selected_ckpt or selected_eval is None:
    return ''

  target_dir = os.path.join(export_dir_base, os.path.basename(selected_ckpt))
  return best_model_export_strategy.export(
      estimator, target_dir, selected_ckpt, selected_eval)
开发者ID:Lin-jipeng,项目名称:tensorflow,代码行数:31,代码来源:saved_model_export_utils.py
注:本文中的tensorflow.python.training.saver.latest_checkpoint函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论