本文整理汇总了Python中tensorflow.python.platform.logging.warning函数的典型用法代码示例。如果您正苦于以下问题:Python warning函数的具体用法?Python warning怎么用?Python warning使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了warning函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _MaybeDeleteOldCheckpoints
def _MaybeDeleteOldCheckpoints(self, latest_save_path):
"""Deletes old checkpoints if necessary.
Always keep the last max_to_keep checkpoints. If
keep_checkpoint_every_n_hours was specified, keep an additional checkpoint
every N hours. For example, if N is 0.5, an additional checkpoint is kept
for every 0.5 hours of training; if N is 10, an additional checkpoint is
kept for every 10 hours of training.
Args:
latest_save_path: Name including path of checkpoint file to save.
"""
if not self._max_to_keep:
return
# Remove first from list if the same name was used before.
for p in self._last_checkpoints:
if latest_save_path == self._CheckpointFilename(p):
self._last_checkpoints.remove(p)
# Append new path to list
self._last_checkpoints.append((latest_save_path, time.time()))
# If more than max_to_keep, remove oldest.
if len(self._last_checkpoints) > self._max_to_keep:
p = self._last_checkpoints.pop(0)
# Do not delete the file if we keep_checkpoint_every_n_hours is set and we
# have reached N hours of training.
should_keep = p[1] > self._next_checkpoint_time
if should_keep:
self._next_checkpoint_time += self._keep_checkpoint_every_n_hours * 3600
return
# Otherwise delete the files.
for f in gfile.Glob(self._CheckpointFilename(p)):
try:
gfile.Remove(f)
except gfile.GOSError as e:
logging.warning("Ignoring: %s", str(e))
开发者ID:adeelzaman,项目名称:tensorflow,代码行数:35,代码来源:saver.py
示例2: get_checkpoint_state
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
    """Returns CheckpointState proto from the "checkpoint" file.

    If the "checkpoint" file contains a valid CheckpointState
    proto, returns it.

    Args:
      checkpoint_dir: The directory of checkpoints.
      latest_filename: Optional name of the checkpoint file. Default to
        'checkpoint'.

    Returns:
      A CheckpointState if the state was available, None
      otherwise (file missing, unreadable, or not parseable).
    """
    ckpt = None
    coord_checkpoint_filename = _GetCheckpointFilename(
        checkpoint_dir, latest_filename)
    f = None
    try:
        # Check that the file exists before opening it to avoid
        # many lines of errors from colossus in the logs.
        if gfile.Exists(coord_checkpoint_filename):
            f = gfile.FastGFile(coord_checkpoint_filename, mode="r")
            ckpt = CheckpointState()
            text_format.Merge(f.read(), ckpt)
    except gfile.FileError:
        # It's ok if the file cannot be read
        return None
    except text_format.ParseError as e:
        logging.warning(str(e))
        logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
        return None
    finally:
        # Always release the handle, including on the early-return paths.
        if f:
            f.close()
    return ckpt
开发者ID:ange3,项目名称:deepcode,代码行数:33,代码来源:saver.py
示例3: _default_global_step_tensor
def _default_global_step_tensor(self):
    """Looks up the tensor named "global_step:0" in the default graph.

    Returns:
      The tensor when it exists and its base dtype is int32 or int64.
      `None` when the lookup raises `KeyError`, or when the tensor has
      another dtype (a warning is logged in that case).
    """
    try:
        step_tensor = ops.get_default_graph().get_tensor_by_name(
            "global_step:0")
        if step_tensor.dtype.base_dtype not in [dtypes.int32, dtypes.int64]:
            logging.warning("Found 'global_step' is not an int type: %s",
                            step_tensor.dtype)
            return None
        return step_tensor
    except KeyError:
        return None
开发者ID:Anandnitrate,项目名称:tensorflow,代码行数:10,代码来源:supervisor.py
示例4: start_standard_services
def start_standard_services(self, sess):
    """Start the standard services for 'sess'.

    This starts services in the background. The services started depend
    on the parameters to the constructor and may include:

      - A Summary thread computing summaries every save_summaries_secs.
      - A Checkpoint thread saving the model every save_model_secs.
      - A StepCounter thread measure step time.

    Args:
      sess: A Session.

    Returns:
      A list of threads that are running the standard services. You can use
      the Supervisor's Coordinator to join these threads with:
        sv.coord.Join(<list of threads>)
      If no `logdir` is set, a warning is logged and `None` is returned
      without starting anything.

    Raises:
      RuntimeError: If called with a non-chief Supervisor.
    """
    if not self._is_chief:
        raise RuntimeError("Only chief supervisor can start standard services. "
                           "Because only chief supervisors can write events.")

    if not self._logdir:
        logging.warning("Standard services need a 'logdir' "
                        "passed to the SessionManager")
        return

    if self._global_step is not None and self._summary_writer:
        # Only add the session log if we keep track of global step.
        # TensorBoard cannot use START message for purging expired events
        # if there is no step value.
        step_now = training_util.global_step(sess, self._global_step)
        self._summary_writer.add_session_log(
            SessionLog(status=SessionLog.START), step_now)

    services = []
    summaries_enabled = self._save_summaries_secs and self._summary_writer
    if summaries_enabled and self._summary_op is not None:
        services.append(SVSummaryThread(self, sess))
    if summaries_enabled and self._global_step is not None:
        services.append(SVStepCounterThread(self, sess))
    if self.saver and self._save_model_secs:
        services.append(SVTimerCheckpointThread(self, sess))

    for service in services:
        service.start()
    self._started_threads.extend(services)
    return services
开发者ID:2er0,项目名称:tensorflow,代码行数:54,代码来源:supervisor.py
示例5: main
def main(unused_argv=None):
    """Validates flags, starts event loading, and serves TensorBoard.

    Args:
      unused_argv: Ignored.

    Returns:
      -1 when no `logdir` flag was given, -2 when creating the HTTP server
      raised `socket.error`. Otherwise the call ends in `serve_forever()`.
    """
    if FLAGS.debug:
        logging.set_verbosity(logging.DEBUG)
        logging.info('TensorBoard is in debug mode.')

    if not FLAGS.logdir:
        logging.error('A logdir must be specified. Run `tensorboard --help` for '
                      'details and examples.')
        return -1

    logging.info('Starting TensorBoard in directory %s', os.getcwd())

    path_to_run = ParseEventFilesFlag(FLAGS.logdir)
    logging.info('TensorBoard path_to_run is: %s', path_to_run)
    multiplexer = event_multiplexer.EventMultiplexer(
        size_guidance=TENSORBOARD_SIZE_GUIDANCE)
    # Ensure the Multiplexer initializes in a loaded state before it adds runs
    # So it can handle HTTP requests while runs are loading
    multiplexer.Reload()

    def _schedule_load(delay_secs):
        # Run _Load on a daemon timer thread after `delay_secs` seconds.
        timer = threading.Timer(delay_secs, _Load)
        timer.daemon = True
        timer.start()

    def _Load():
        # Add all runs, reload, log the duration, then re-arm the timer.
        began = time.time()
        for (path, name) in six.iteritems(path_to_run):
            multiplexer.AddRunsFromDirectory(path, name)
        multiplexer.Reload()
        elapsed = time.time() - began
        logging.info('Multiplexer done loading. Load took %0.1f secs', elapsed)
        _schedule_load(LOAD_INTERVAL)

    _schedule_load(0)

    handler_factory = functools.partial(
        tensorboard_handler.TensorboardHandler, multiplexer)
    try:
        http_server = ThreadedHTTPServer((FLAGS.host, FLAGS.port),
                                         handler_factory)
    except socket.error:
        logging.error('Tried to connect to port %d, but that address is in use.',
                      FLAGS.port)
        return -2

    try:
        tag = resource_loader.load_resource('tensorboard/TAG').strip()
        logging.info('TensorBoard is tag: %s', tag)
    except IOError:
        logging.warning('Unable to read TensorBoard tag')
        tag = ''

    status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
    print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
    print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
    http_server.serve_forever()
开发者ID:chaabni,项目名称:tensorflow,代码行数:53,代码来源:tensorboard.py
示例6: update_checkpoint_state
def update_checkpoint_state(save_dir,
                            model_checkpoint_path,
                            all_model_checkpoint_paths=None,
                            latest_filename=None):
    """Updates the content of the 'checkpoint' file.

    This updates the checkpoint file containing a CheckpointState
    proto.

    Args:
      save_dir: Directory where the model was saved.
      model_checkpoint_path: The checkpoint file.
      all_model_checkpoint_paths: list of strings.  Paths to all not-yet-deleted
        checkpoints, sorted from oldest to newest.  If this is a non-empty list,
        the last element must be equal to model_checkpoint_path.  These paths
        are also saved in the CheckpointState proto. The caller's list is
        not modified.
      latest_filename: Optional name of the checkpoint file.  Default to
        'checkpoint'.

    Raises:
      RuntimeError: If the save paths conflict.
    """
    # Work on a private copy so the caller's list is never mutated.
    all_model_checkpoint_paths = list(all_model_checkpoint_paths or [])
    if (all_model_checkpoint_paths and
            all_model_checkpoint_paths[-1] != model_checkpoint_path):
        logging.warning(
            "%s is not in all_model_checkpoint_paths! Manually adding it.",
            model_checkpoint_path)
        all_model_checkpoint_paths.append(model_checkpoint_path)
    # Writes the "checkpoint" file for the coordinator for later restoration.
    coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
    # Relative paths need to be rewritten to be relative to the "save_dir".
    if not os.path.isabs(model_checkpoint_path):
        model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    # Absolute paths are kept as they are; filtering them out here would
    # silently drop those checkpoints from the saved state.
    all_model_checkpoint_paths = [
        p if os.path.isabs(p) else os.path.relpath(p, save_dir)
        for p in all_model_checkpoint_paths
    ]
    if coord_checkpoint_filename == model_checkpoint_path:
        raise RuntimeError("Save path '%s' conflicts with path used for "
                           "checkpoint state.  Please use a different save path." %
                           model_checkpoint_path)
    coord_checkpoint_proto = CheckpointState(
        model_checkpoint_path=model_checkpoint_path,
        all_model_checkpoint_paths=all_model_checkpoint_paths)
    f = gfile.FastGFile(coord_checkpoint_filename, mode="w")
    try:
        f.write(text_format.MessageToString(coord_checkpoint_proto))
    finally:
        # Close the handle even if serialization or the write fails.
        f.close()
开发者ID:JesseLivezey,项目名称:tensorflow,代码行数:52,代码来源:saver.py
示例7: add_graph
def add_graph(self, graph, global_step=None, graph_def=None):
    """Adds a `Graph` to the event file.

    The graph described by the protocol buffer will be displayed by
    TensorBoard. Most users pass a graph in the constructor instead.

    Args:
      graph: A `Graph` object, such as `sess.graph`.
      global_step: Number. Optional global step counter to record with the
        graph.
      graph_def: DEPRECATED. Use the `graph` parameter instead.

    Raises:
      ValueError: If both graph and graph_def are passed to the method.
      TypeError: If the value passed is neither a `Graph` nor a `GraphDef`.
    """
    if graph is not None and graph_def is not None:
        raise ValueError("Please pass only graph, or graph_def (deprecated), "
                         "but not both.")

    # After the check above at most one of the two arguments is set, so the
    # value to inspect is whichever one the caller supplied.
    via_graph_def_arg = graph is None
    supplied = graph_def if via_graph_def_arg else graph

    if isinstance(supplied, ops.Graph):
        # The user passed a `Graph`.
        if via_graph_def_arg:
            logging.warning("When passing a `Graph` object, please use the `graph`"
                            " named argument instead of `graph_def`.")
        # Serialize the graph with additional info.
        true_graph_def = supplied.as_graph_def(add_shapes=True)
    elif isinstance(supplied, graph_pb2.GraphDef):
        # The user passed a `GraphDef`.
        logging.warning("Passing a `GraphDef` to the SummaryWriter is deprecated."
                        " Pass a `Graph` object instead, such as `sess.graph`.")
        true_graph_def = supplied
    else:
        # The user passed neither `Graph`, nor `GraphDef`.
        raise TypeError("The passed graph must be an instance of `Graph` "
                        "or the deprecated `GraphDef`")
    # Finally, add the graph_def to the summary writer.
    self._add_graph_def(true_graph_def, global_step)
开发者ID:MPesin,项目名称:tensorflow,代码行数:51,代码来源:summary_io.py
示例8: _default_global_step_tensor
def _default_global_step_tensor(self):
    """Fetches the integer "global_step:0" tensor from the default graph.

    Returns:
      The global step `Tensor`, or `None` when the graph has no such tensor
      (`KeyError`) or the tensor's base dtype is not int32/int64.
    """
    integer_dtypes = [dtypes.int32, dtypes.int64]
    try:
        candidate = ops.get_default_graph().get_tensor_by_name("global_step:0")
        is_integer = candidate.dtype.base_dtype in integer_dtypes
        if not is_integer:
            logging.warning("Found 'global_step' is not an int type: %s",
                            candidate.dtype)
        return candidate if is_integer else None
    except KeyError:
        return None
开发者ID:2er0,项目名称:tensorflow,代码行数:15,代码来源:supervisor.py
示例9: _MakeShape
def _MakeShape(v, arg_name):
    """Convert v into a TensorShapeProto.

    Args:
      v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
      arg_name: String, for error messages. (Not referenced in this body.)

    Returns:
      A TensorShapeProto. An input that already is a TensorShapeProto is
      returned unchanged, after a warning if any dimension carries a name.
    """
    if not isinstance(v, tensor_shape_pb2.TensorShapeProto):
        return tensor_shape.as_shape(v).as_proto()
    # One warning is enough, however many dimensions are named.
    if any(dim.name for dim in v.dim):
        logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                        str(v))
    return v
开发者ID:2er0,项目名称:tensorflow,代码行数:16,代码来源:op_def_library.py
示例10: get_checkpoint_state
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
    """Reads the CheckpointState proto stored in the "checkpoint" file.

    When `checkpoint_dir` is a relative path, every relative path inside
    the proto is prefixed with `checkpoint_dir` before it is returned.

    Args:
      checkpoint_dir: The directory of checkpoints.
      latest_filename: Optional name of the checkpoint file.  Default to
        'checkpoint'.

    Returns:
      A CheckpointState if the state was available, None
      otherwise.
    """
    state = None
    state_filename = _GetCheckpointFilename(checkpoint_dir, latest_filename)
    handle = None
    try:
        # Check that the file exists before opening it to avoid
        # many lines of errors from colossus in the logs.
        if gfile.Exists(state_filename):
            handle = gfile.FastGFile(state_filename, mode="r")
            state = CheckpointState()
            text_format.Merge(handle.read(), state)
            # For relative model_checkpoint_path and
            # all_model_checkpoint_paths, prepend checkpoint_dir.
            if not os.path.isabs(checkpoint_dir):
                if not os.path.isabs(state.model_checkpoint_path):
                    state.model_checkpoint_path = os.path.join(
                        checkpoint_dir, state.model_checkpoint_path)
                for idx, path in enumerate(state.all_model_checkpoint_paths):
                    if not os.path.isabs(path):
                        state.all_model_checkpoint_paths[idx] = os.path.join(
                            checkpoint_dir, path)
    except IOError:
        # It's ok if the file cannot be read
        return None
    except text_format.ParseError as e:
        logging.warning(str(e))
        logging.warning("%s: Checkpoint ignored", state_filename)
        return None
    finally:
        if handle:
            handle.close()
    return state
开发者ID:hessenh,项目名称:Human-Activity-Recognition,代码行数:47,代码来源:saver.py
示例11: main
def main(unused_argv=None):
    """Validates flags, builds the TensorBoard server, and serves forever.

    Args:
      unused_argv: Ignored.

    Returns:
      -1 when no `logdir` flag was given; -2 when building the server raised
      `socket.error` with port 0; -3 when it raised with any other port.
      Otherwise the call ends in `serve_forever()`.
    """
    def _abort(message, exit_code):
        # Report a fatal start-up problem to both the log and stdout.
        logging.error(message)
        print(message)
        return exit_code

    if FLAGS.debug:
        logging.set_verbosity(logging.DEBUG)
        logging.info('TensorBoard is in debug mode.')

    if not FLAGS.logdir:
        return _abort('A logdir must be specified. Run `tensorboard --help` for '
                      'details and examples.', -1)

    logging.info('Starting TensorBoard in directory %s', os.getcwd())
    run_paths = server.ParseEventFilesSpec(FLAGS.logdir)
    logging.info('TensorBoard path_to_run is: %s', run_paths)

    multiplexer = event_multiplexer.EventMultiplexer(
        size_guidance=server.TENSORBOARD_SIZE_GUIDANCE,
        purge_orphaned_data=FLAGS.purge_orphaned_data)
    server.StartMultiplexerReloadingThread(multiplexer, run_paths,
                                           FLAGS.reload_interval)

    try:
        tb_server = server.BuildServer(multiplexer, FLAGS.host, FLAGS.port)
    except socket.error:
        if FLAGS.port == 0:
            return _abort('Unable to find any open ports.', -2)
        return _abort(
            'Tried to connect to port %d, but address is in use.' % FLAGS.port,
            -3)

    try:
        tag = resource_loader.load_resource('tensorboard/TAG').strip()
        logging.info('TensorBoard is tag: %s', tag)
    except IOError:
        logging.warning('Unable to read TensorBoard tag')
        tag = ''

    status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
    print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
    print('(You can navigate to http://%s:%d)' % (FLAGS.host, FLAGS.port))
    tb_server.serve_forever()
开发者ID:2er0,项目名称:tensorflow,代码行数:46,代码来源:tensorboard.py
示例12: _add_collection_def
def _add_collection_def(meta_graph_def, key):
    """Serializes one graph collection into a MetaGraphDef protocol buffer.

    Collections with a non-string key, and empty collections, are skipped.
    If serialization raises, a warning is logged and any partially written
    entry for `key` is removed from `meta_graph_def.collection_def`.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
    """
    if not (isinstance(key, six.string_types) or isinstance(key, bytes)):
        logging.warning("Only collections with string type keys will be "
                        "serialized. This key has %s" % type(key))
        return

    items = ops.get_collection(key)
    if not items:
        return

    try:
        col_def = meta_graph_def.collection_def[key]
        to_proto = ops.get_to_proto_function(key)
        proto_type = ops.get_collection_proto_type(key)
        if to_proto:
            for item in items:
                # Additional type check to make sure the returned proto is
                # indeed what we expect.
                proto = to_proto(item)
                assert isinstance(proto, proto_type)
                col_def.bytes_list.value.append(proto.SerializeToString())
        else:
            kind = _get_kind_name(items[0])
            destination = getattr(col_def, kind).value
            if kind == "node_list":
                destination.extend([item.name for item in items])
            elif kind == "bytes_list":
                # NOTE(opensource): This force conversion is to work around
                # the fact that Python3 distinguishes between bytes and
                # strings.
                destination.extend([compat.as_bytes(item) for item in items])
            else:
                destination.extend(list(items))
    except Exception as e:  # pylint: disable=broad-except
        logging.warning("Error encountered when serializing %s.\n"
                        "Type is unsupported, or the types of the items don't "
                        "match field type in CollectionDef.\n%s" % (key, str(e)))
        if key in meta_graph_def.collection_def:
            del meta_graph_def.collection_def[key]
        return
开发者ID:AboorvaDevarajan,项目名称:tensorflow,代码行数:44,代码来源:saver.py
示例13: _show_compute
def _show_compute(self, show_dataflow):
"""Visualize the computation activity."""
for dev_stats in self._step_stats.dev_stats:
device_pid = self._device_pids[dev_stats.device]
for node_stats in dev_stats.node_stats:
tid = node_stats.thread_id
start_time = node_stats.all_start_micros
end_time = node_stats.all_start_micros + node_stats.all_end_rel_micros
_, _, inputs = self._parse_op_label(node_stats.timeline_label)
self._emit_op(node_stats, device_pid)
for input_name in inputs:
if input_name not in self._tensors:
# This can happen when partitioning has inserted a Send/Recv.
# We remove the numeric suffix so that the dataflow appears to
# come from the original node. Ideally, the StepStats would
# contain logging for the Send and Recv nodes.
index = input_name.rfind('/_')
if index > 0:
input_name = input_name[:index]
if input_name in self._tensors:
tensor = self._tensors[input_name]
tensor.add_ref(start_time)
tensor.add_unref(end_time - 1)
if show_dataflow:
# We use a different flow ID for every graph edge.
create_time, create_pid, create_tid = self._flow_starts[
input_name]
# Don't add flows when producer and consumer ops are on the same
# pid/tid since the horizontal arrows clutter the visualization.
if create_pid != device_pid or create_tid != tid:
flow_id = self._alloc_flow_id()
self._chrome_trace.emit_flow_start(input_name, create_time,
create_pid, create_tid,
flow_id)
self._chrome_trace.emit_flow_end(input_name, start_time,
device_pid, tid, flow_id)
else:
logging.warning('Can\'t find tensor %s', input_name)
开发者ID:6779660,项目名称:tensorflow,代码行数:43,代码来源:timeline.py
示例14: load_resource
def load_resource(path):
    """Load the resource at given path, where path is relative to tensorflow/.

    Args:
      path: a string resource path relative to tensorflow/.

    Returns:
      The contents of that resource, as read in binary mode.

    Raises:
      IOError: If the path is not found, or the resource can't be opened.
    """
    path = os.path.join('tensorflow', path)
    path = os.path.abspath(path)
    try:
        with open(path, 'rb') as f:
            return f.read()
    except IOError as e:
        logging.warning('IOError %s on path %s', e, path)
        # Propagate the failure as documented. Swallowing it would return
        # None, and callers that guard this call with `except IOError`
        # would then fail on the missing contents instead.
        raise
开发者ID:adeelzaman,项目名称:tensorflow,代码行数:19,代码来源:_resource_loader.py
示例15: AddRun
def AddRun(self, path, name=None):
    """Add a run to the multiplexer.

    If the name is not specified, it is the same as the path.

    If a run by that name exists, and we are already watching the right path,
    do nothing. If we are watching a different path, replace the event
    accumulator.

    If `AutoUpdate` or `Reload` have been called, it will `AutoUpdate` or
    `Reload` the newly created accumulators. This maintains the invariant that
    once the Multiplexer was activated, all of its accumulators are active.

    Args:
      path: Path to the event files (or event directory) for given run.
      name: Name of the run to add. If not provided (None or empty string),
        is set to path.

    Returns:
      The `EventMultiplexer`.
    """
    # Compare by value: `name is ''` relied on string interning and is not a
    # reliable emptiness test.
    if name is None or name == '':
        name = path
    accumulator = None
    with self._accumulators_mutex:
        if name not in self._accumulators or self._paths[name] != path:
            if name in self._paths and self._paths[name] != path:
                # TODO(danmane) - Make it impossible to overwrite an old path
                # with a new path (just give the new path a distinct name)
                logging.warning('Conflict for name %s: old path %s, new path %s',
                                name, self._paths[name], path)
            logging.info('Constructing EventAccumulator for %s', path)
            accumulator = event_accumulator.EventAccumulator(path,
                                                             self._size_guidance)
            self._accumulators[name] = accumulator
            self._paths[name] = path
    # Activate a newly created accumulator outside the lock.
    if accumulator:
        if self._reload_called:
            accumulator.Reload()
        if self._autoupdate_called:
            accumulator.AutoUpdate(self._autoupdate_interval)
    return self
开发者ID:adam-erickson,项目名称:tensorflow,代码行数:41,代码来源:event_multiplexer.py
示例16: _model_not_ready
def _model_not_ready(self, sess):
"""Checks if the model is ready or not.
Args:
sess: A `Session`.
Returns:
`None` if the model is ready, a `String` with the reason why it is not
ready otherwise.
"""
if self._ready_op is None:
return None
else:
try:
sess.run(self._ready_op)
return None
except errors.FailedPreconditionError as e:
if "uninitialized" not in str(e):
logging.warning("Model not ready raised: %s", str(e))
raise e
return str(e)
开发者ID:2php,项目名称:tensorflow,代码行数:21,代码来源:session_manager.py
示例17: main
def main(unused_argv=None):
    """Validates flags, builds an auto-loading multiplexer, and serves.

    Args:
      unused_argv: Ignored.

    Returns:
      -1 when no `logdir` flag was given, -2 when creating the HTTP server
      raised `socket.error`. Otherwise the call ends in `serve_forever()`.
    """
    if FLAGS.debug:
        logging.set_verbosity(logging.DEBUG)
        logging.info('TensorBoard is in debug mode.')

    if not FLAGS.logdir:
        logging.error('A logdir must be specified. Run `tensorboard --help` for '
                      'details and examples.')
        return -1

    if FLAGS.debug:
        logging.info('Starting TensorBoard in directory %s', os.getcwd())

    run_paths = ParseEventFilesFlag(FLAGS.logdir)
    multiplexer = event_multiplexer.AutoloadingMultiplexer(
        path_to_run=run_paths,
        interval_secs=60,
        size_guidance=TENSORBOARD_SIZE_GUIDANCE)
    multiplexer.AutoUpdate(interval=30)

    handler_factory = functools.partial(
        tensorboard_handler.TensorboardHandler, multiplexer)
    try:
        http_server = ThreadedHTTPServer((FLAGS.host, FLAGS.port),
                                         handler_factory)
    except socket.error:
        logging.error('Tried to connect to port %d, but that address is in use.',
                      FLAGS.port)
        return -2

    try:
        tag = resource_loader.load_resource('tensorboard/TAG').strip()
        logging.info('TensorBoard is tag: %s', tag)
    except IOError:
        logging.warning('Unable to read TensorBoard tag')
        tag = ''

    status_bar.SetupStatusBarInsideGoogle('TensorBoard %s' % tag, FLAGS.port)
    print('Starting TensorBoard %s on port %d' % (tag, FLAGS.port))
    print('(You can navigate to http://localhost:%d)' % FLAGS.port)
    http_server.serve_forever()
开发者ID:bgyss,项目名称:tensorflow,代码行数:39,代码来源:tensorboard.py
示例18: _MaybeDeleteOldCheckpoints
def _MaybeDeleteOldCheckpoints(self, latest_save_path,
meta_graph_suffix="meta"):
"""Deletes old checkpoints if necessary.
Always keep the last `max_to_keep` checkpoints. If
`keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
kept for every 0.5 hours of training; if `N` is 10, an additional
checkpoint is kept for every 10 hours of training.
Args:
latest_save_path: Name including path of checkpoint file to save.
meta_graph_suffix: Suffix for MetaGraphDef file. Defaults to 'meta'.
"""
if not self.saver_def.max_to_keep:
return
# Remove first from list if the same name was used before.
for p in self._last_checkpoints:
if latest_save_path == self._CheckpointFilename(p):
self._last_checkpoints.remove(p)
# Append new path to list
self._last_checkpoints.append((latest_save_path, time.time()))
# If more than max_to_keep, remove oldest.
if len(self._last_checkpoints) > self.saver_def.max_to_keep:
p = self._last_checkpoints.pop(0)
# Do not delete the file if we keep_checkpoint_every_n_hours is set and we
# have reached N hours of training.
should_keep = p[1] > self._next_checkpoint_time
if should_keep:
self._next_checkpoint_time += (
self.saver_def.keep_checkpoint_every_n_hours * 3600)
return
# Otherwise delete the files.
for f in gfile.Glob(self._CheckpointFilename(p)):
try:
gfile.Remove(f)
gfile.Remove(".".join([f, meta_graph_suffix]))
except OSError as e:
logging.warning("Ignoring: %s", str(e))
开发者ID:hdzz,项目名称:tensorflow,代码行数:39,代码来源:saver.py
示例19: _add_collection_def
def _add_collection_def(meta_graph_def, key):
"""Adds a collection to MetaGraphDef protocol buffer.
Args:
meta_graph_def: MetaGraphDef protocol buffer.
key: One of the GraphKeys or user-defined string.
"""
if not isinstance(key, (str, bytes, unicode)):
logging.warning("Only collections with string type keys will be "
"serialized. This key has %s" % type(key))
return
collection_list = ops.get_collection(key)
if not collection_list:
return
try:
col_def = meta_graph_def.collection_def[key]
to_proto = ops.get_to_proto_function(key)
proto_type = ops.get_collection_proto_type(key)
if to_proto:
kind = "bytes_list"
for x in collection_list:
# Additional type check to make sure the returned proto is indeed
# what we expect.
proto = to_proto(x)
assert isinstance(proto, proto_type)
getattr(col_def, kind).value.append(proto.SerializeToString())
else:
kind = _get_kind_name(collection_list[0])
if kind == "node_list":
getattr(col_def, kind).value.extend([x.name for x in collection_list])
else:
getattr(col_def, kind).value.extend([x for x in collection_list])
except Exception, e: # pylint: disable=broad-except
logging.warning("Type is unsupported, or the types of the items don't "
"match field type in CollectionDef.\n%s" % str(e))
if key in meta_graph_def.collection_def:
del meta_graph_def.collection_def[key]
return
开发者ID:hdzz,项目名称:tensorflow,代码行数:38,代码来源:saver.py
示例20: replica_device_setter
def replica_device_setter(ps_tasks=0, ps_device="/job:ps",
                          worker_device="/job:worker", merge_devices=True,
                          cluster=None, ps_ops=None):
    """Return a `device function` to use when building a Graph for replicas.

    Device Functions are used in `with tf.device(device_function):` statement to
    automatically assign devices to `Operation` objects as they are constructed,
    Device constraints are added from the inner-most context first, working
    outwards. The merging behavior adds constraints to fields that are yet unset
    by a more inner context. Currently the fields are (job, task, cpu/gpu).

    If `cluster` is `None`, and `ps_tasks` is 0, the returned function is a no-op.

    For example,

    ```python
    # To build a cluster with two ps jobs on hosts ps0 and ps1, and 3 worker
    # jobs on hosts worker0, worker1 and worker2.
    cluster_spec = {
        "ps": ["ps0:2222", "ps1:2222"],
        "worker": ["worker0:2222", "worker1:2222", "worker2:2222"]}
    with tf.device(tf.replica_device_setter(cluster=cluster_spec)):
      # Build your graph
      v1 = tf.Variable(...)  # assigned to /job:ps/task:0
      v2 = tf.Variable(...)  # assigned to /job:ps/task:1
      v3 = tf.Variable(...)  # assigned to /job:ps/task:0
    # Run compute
    ```

    Args:
      ps_tasks: Number of tasks in the `ps` job. Overridden by the size of
        the ps job in `cluster` when `cluster` is given.
      ps_device: String.  Device of the `ps` job.  If empty no `ps` job is used.
        Defaults to `ps`.
      worker_device: String.  Device of the `worker` job.  If empty no `worker`
        job is used.
      merge_devices: `Boolean`. If `True`, merges or only sets a device if the
        device constraint is completely unset. merges device specification rather
        than overriding them.
      cluster: `ClusterDef` proto or `ClusterSpec`.
      ps_ops: List of `Operation` objects that need to be placed on `ps` devices.

    Returns:
      A function to pass to `tf.device()`, or `None` when there are no ps
      tasks (or the ps job is absent from `cluster`).

    Raises:
      TypeError if `cluster` is not a dictionary or `ClusterDef` protocol buffer
      (presumably raised by the `ClusterSpec` constructor; this function does
      not raise it directly).
    """
    if cluster is not None:
        if isinstance(cluster, server_lib.ClusterSpec):
            cluster_spec = cluster.as_cluster_spec()
        else:
            cluster_spec = server_lib.ClusterSpec(cluster).as_cluster_spec()
        # Get ps_job_name from ps_device by stripping the "/job:" prefix.
        # str.lstrip() cannot be used for this: it strips a *set of
        # characters*, so "/job:boss" would become "ss".
        job_prefix = "/job:"
        if ps_device.startswith(job_prefix):
            ps_job_name = ps_device[len(job_prefix):]
        else:
            ps_job_name = ps_device
        if ps_job_name not in cluster_spec or cluster_spec[ps_job_name] is None:
            return None
        ps_tasks = len(cluster_spec[ps_job_name])

    if ps_tasks == 0:
        return None

    if not merge_devices:
        logging.warning(
            "DEPRECATION: It is recommended to set merge_devices=true in "
            "replica_device_setter")
    chooser = _ReplicaDeviceChooser(
        ps_tasks, ps_device, worker_device, merge_devices, ps_ops)
    return chooser.device_function
开发者ID:4chin,项目名称:tensorflow,代码行数:68,代码来源:device_setter.py
注:本文中的tensorflow.python.platform.logging.warning函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论