本文整理汇总了Python中tensorflow.python.training.saver.import_meta_graph函数的典型用法代码示例。如果您正苦于以下问题：Python import_meta_graph函数的具体用法是什么？Python import_meta_graph怎么用？有没有Python import_meta_graph的使用示例？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了import_meta_graph函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _read_vars
def _read_vars(self, model_dir):
  """Restores the newest checkpoint in `model_dir` and evaluates 'my_vars'.

  The meta graph stored next to the latest checkpoint is imported into a
  fresh graph, a new `Saver` restores the variable values, and every entry
  of the 'my_vars' collection is evaluated.

  Args:
    model_dir: Directory holding the checkpoint files.

  Returns:
    The evaluated 'my_vars' collection (a list). Callers treat it as
    (global_step, latest_feature) — presumably the order in which those were
    added to the collection; confirm at the call site.
  """
  with ops.Graph().as_default() as graph:
    checkpoint_prefix = checkpoint_management.latest_checkpoint(model_dir)
    saver_lib.import_meta_graph(checkpoint_prefix + '.meta')
    restorer = saver_lib.Saver()
    with self.test_session(graph=graph) as sess:
      restorer.restore(sess, checkpoint_prefix)
      tracked_vars = ops.get_collection('my_vars')
      return sess.run(tracked_vars)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:10,代码来源:iterator_ops_test.py
示例2: testMetaGraphSaveLoad
def testMetaGraphSaveLoad(self):
  """Round-trips a partitioned variable through a checkpoint and meta graph.

  Saves a 5-way partitioned (10, 10) variable, imports the exported meta
  graph into a fresh graph, restores the checkpoint, and checks that the
  tensor named `v0.name + ":0"` evaluates to the value seen before saving.
  """
  save_prefix = os.path.join(self.get_temp_dir(), "ckpt")
  save_graph = ops.Graph()
  with save_graph.as_default(), self.test_session(
      graph=save_graph) as session:
    partitioner = partitioned_variables.fixed_size_partitioner(5, axis=0)
    with variable_scope.variable_scope("root", partitioner=partitioner):
      v0 = variable_scope.get_variable(
          "v0", dtype=dtypes.float32, shape=(10, 10))
      v0_list = v0._get_variable_list()
      v0_part = v0._get_partitions()
      # 5 partitions along axis 0, none along axis 1.
      self.assertEqual(len(v0_list), 5)
      self.assertAllEqual(v0_part, (5, 1))
      variables.global_variables_initializer().run()
    save_graph.get_collection_ref("partvar").append(v0)
    saver = saver_lib.Saver()
    # Freeze the graph so that nothing below can add ops to it.
    save_graph.finalize()
    save_path = saver.save(sess=session, save_path=save_prefix)
    previous_value = session.run(
        save_graph.get_tensor_by_name(v0.name + ":0"))
  restore_graph = ops.Graph()
  with restore_graph.as_default(), self.test_session(
      graph=restore_graph) as session:
    saver = saver_lib.import_meta_graph(save_path + ".meta")
    saver.restore(sess=session, save_path=save_path)
    # NOTE(review): the collection is read from `save_graph`, not
    # `restore_graph`, so the isinstance check below does not exercise the
    # imported collection; only `v0.name` is used to look up the restored
    # tensor in `restore_graph`. Confirm whether this is intentional.
    v0, = save_graph.get_collection_ref("partvar")
    self.assertIsInstance(v0, variables.PartitionedVariable)
    self.assertAllEqual(
        previous_value,
        session.run(restore_graph.get_tensor_by_name(v0.name + ":0")))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:32,代码来源:partitioned_variables_test.py
示例3: graph_def_from_checkpoint
def graph_def_from_checkpoint(checkpoint_dir, output_node_names):
  """Freezes the newest checkpoint in `checkpoint_dir` into a GraphDef.

  The meta graph belonging to the latest checkpoint is imported (with device
  placements cleared) into the current default graph, its variables are
  restored in a new session, and every variable is then replaced by a
  constant holding its restored value.

  Args:
    checkpoint_dir: Directory that holds the checkpoints.
    output_node_names: List of result-node names, forwarded to
      `graph_util.convert_variables_to_constants`.

  Returns:
    The frozen `GraphDef`.

  Raises:
    ValueError: If no checkpoint is found in `checkpoint_dir`.
  """
  latest = saver_lib.latest_checkpoint(checkpoint_dir)
  if latest is None:
    raise ValueError('Could not find a checkpoint at: {0}.'
                     .format(checkpoint_dir))
  meta_path = latest + '.meta'
  restorer = saver_lib.import_meta_graph(meta_path, clear_devices=True)
  with session.Session() as sess:
    restorer.restore(sess, latest)
    frozen = graph_util.convert_variables_to_constants(
        sess, ops.get_default_graph().as_graph_def(), output_node_names)
  return frozen
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:30,代码来源:strip_pruning_vars_lib.py
示例4: _ExportAndImportGraph
def _ExportAndImportGraph(self, graph):
  """Round-trips `graph` through a MetaGraphDef into a brand-new graph.

  Every collection key of `graph` is passed to the export so those
  collections can be recreated in the copy.

  Args:
    graph: The `tf.Graph` to copy.

  Returns:
    A new `tf.Graph` populated from the exported meta graph.
  """
  all_keys = graph.get_all_collection_keys()
  exported = saver_lib.export_meta_graph(graph=graph,
                                         collection_list=all_keys)
  imported_graph = ops.Graph()
  with imported_graph.as_default():
    # The Saver (if any) returned by the import is not needed here.
    saver_lib.import_meta_graph(exported)
  return imported_graph
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:8,代码来源:moving_averages_test.py
示例5: _CopyGraph
def _CopyGraph(self, graph):
  """Duplicates `graph` via an export/import of its MetaGraphDef.

  Args:
    graph: Source `tf.Graph`; all of its collection keys are exported.

  Returns:
    A fresh `tf.Graph` holding the imported duplicate.
  """
  duplicate = ops.Graph()
  serialized = saver_lib.export_meta_graph(
      graph=graph, collection_list=graph.get_all_collection_keys())
  with duplicate.as_default():
    saver_lib.import_meta_graph(serialized)
  return duplicate
开发者ID:Eagle732,项目名称:tensorflow,代码行数:8,代码来源:fold_batch_norms_test.py
示例6: load
def load(sess, tags, export_dir):
  """Loads the model from a SavedModel as specified by tags.

  Imports the MetaGraphDef whose tag set equals `tags`, restores its
  variables (when it has any) into `sess`, and then runs the main op or,
  failing that, the legacy init op, feeding the asset tensors.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
  """
  # Build the SavedModel protocol buffer and find the requested meta graph def.
  saved_model = _parse_saved_model(export_dir)
  wanted_tags = set(tags)
  for meta_graph_def in saved_model.meta_graphs:
    if set(meta_graph_def.meta_info_def.tags) == wanted_tags:
      meta_graph_def_to_load = meta_graph_def
      break
  else:
    raise RuntimeError("MetaGraphDef associated with tags " + str(tags).strip(
        "[]") + " could not be found in SavedModel")

  # Build a saver by importing the meta graph def to load.
  saver = tf_saver.import_meta_graph(meta_graph_def_to_load)

  # import_meta_graph returns None when the graph has no variables. In that
  # case there is nothing to restore, and calling `restore` on None would
  # raise AttributeError.
  if saver is not None:
    # Build the checkpoint path where the variables are located.
    variables_path = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.VARIABLES_DIRECTORY),
        compat.as_bytes(constants.VARIABLES_FILENAME))
    # Restore the variables using the built saver in the provided session.
    saver.restore(sess, variables_path)

  # Get asset tensors, if any.
  asset_tensors_dictionary = _get_asset_tensors(export_dir,
                                                meta_graph_def_to_load)

  # Prefer the main op; fall back to the legacy init op when there is none.
  main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load)
  if main_op_tensor is not None:
    sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
  else:
    legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
    if legacy_init_op_tensor is not None:
      sess.run(fetches=[legacy_init_op_tensor],
               feed_dict=asset_tensors_dictionary)

  return meta_graph_def_to_load
开发者ID:curtiszimmerman,项目名称:tensorflow,代码行数:57,代码来源:loader.py
示例7: testMetagraph
def testMetagraph(self):
  """Checks that a resource-variable training graph survives a meta graph trip.

  Builds a graph with a resource variable minimized by Momentum, exports it,
  imports the MetaGraphDef into a fresh graph with `import_scope=""`, and
  re-exports the imported graph.
  """
  with ops.Graph().as_default():
    with variable_scope.variable_scope("foo", use_resource=True):
      a = variable_scope.get_variable("a", initializer=10.0)
    momentum.MomentumOptimizer(
        learning_rate=0.001, momentum=0.1).minimize(
            a,
            colocate_gradients_with_ops=True,
            global_step=training_util.get_or_create_global_step())
    graph = ops.get_default_graph()
    meta_graph_def = saver.export_meta_graph(graph=graph)

  with ops.Graph().as_default() as imported_graph:
    saver.import_meta_graph(meta_graph_def, import_scope="")
    # Re-export the *imported* graph. The previous version exported `graph`
    # (the original) a second time, so the comparison was between two exports
    # of the same graph and could never detect a faulty import.
    meta_graph_two = saver.export_meta_graph(graph=imported_graph)

  original_names = {node.name for node in meta_graph_def.graph_def.node}
  imported_names = {node.name for node in meta_graph_two.graph_def.node}
  # With an empty import scope every original node should come back under its
  # original, unprefixed name. Only containment is checked because the import
  # may add nodes of its own (e.g. for the Saver it returns) — NOTE(review):
  # confirm against the TF version in use.
  self.assertEqual(set(), original_names - imported_names)
开发者ID:aeverall,项目名称:tensorflow,代码行数:18,代码来源:resource_variable_ops_test.py
示例8: testGradientOfDeserializedCond
def testGradientOfDeserializedCond(self):
  """Checks first and second gradients through a cond_v2 after import.

  Builds cond(pred, x**3, x) with x = 3.0, exports the graph to a
  MetaGraphDef, imports it into a new graph, recovers `x`, `pred` and the
  cond outputs from collections, and differentiates the imported cond for
  both values of `pred`.
  """
  with ops.Graph().as_default():
    pred = array_ops.placeholder(dtypes.bool, name="pred")
    x = constant_op.constant(3.0, name="x")
    ops.add_to_collection("x", x)

    def true_fn():
      return math_ops.pow(x, 3)

    def false_fn():
      return x

    ops.add_to_collection("pred", pred)
    cond = cond_v2.cond_v2(pred, true_fn, false_fn, name="cond")
    # The cond result is iterated, so it is presumably a sequence of output
    # tensors; each is recorded so it can be found again after import.
    for c in cond:
      ops.add_to_collection("cond", c)
    meta_graph = saver.export_meta_graph()

  with ops.Graph().as_default() as g:
    with self.test_session(graph=g) as sess:
      saver.import_meta_graph(meta_graph)
      # Rebind the names to the tensors of the imported graph.
      x = ops.get_collection("x")[0]
      pred = ops.get_collection("pred")[0]
      cond = ops.get_collection("cond")
      cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad")
      cond_grad_grad = gradients_impl.gradients(
          cond_grad, [x], name="cond_grad_grad")
      # d[x^3]/dx = 3x^2
      true_val = sess.run(cond_grad, {pred: True})
      self.assertEqual(true_val, [27.0])
      # d[x]/dx = 1
      false_val = sess.run(cond_grad, {pred: False})
      self.assertEqual(false_val, [1.0])
      true_val = sess.run(cond_grad_grad, {pred: True})
      # d2[x^3]/dx2 = 6x
      self.assertEqual(true_val, [18.0])
      false_val = sess.run(cond_grad_grad, {pred: False})
      # d2[x]/dx2 = 0
      self.assertEqual(false_val, [0.0])
开发者ID:clsung,项目名称:tensorflow,代码行数:40,代码来源:cond_v2_test.py
示例9: _get_default_signature
def _get_default_signature(self, export_meta_filename):
  """Gets the default signature from the export.meta file.

  Imports the meta graph, re-exports it, and unpacks the single `Signatures`
  message stored in its 'serving_signatures' collection.

  Args:
    export_meta_filename: Path to the exported meta graph file.

  Returns:
    The `default_signature` field of the unpacked `Signatures` message.
  """
  with session.Session():
    save = saver.import_meta_graph(export_meta_filename)
    meta_graph_def = save.export_meta_graph()
    collection_def = meta_graph_def.collection_def
    signatures_any = collection_def['serving_signatures'].any_list.value
    # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
    self.assertEqual(len(signatures_any), 1)
    signatures = manifest_pb2.Signatures()
    signatures_any[0].Unpack(signatures)
    default_signature = signatures.default_signature
    return default_signature
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:13,代码来源:export_test.py
示例10: testNoVariables
def testNoVariables(self):
  """Round-trips a variable-free graph through export/import_meta_graph.

  Checks that importing a meta graph without variables yields no Saver, that
  re-exporting the imported graph reproduces the original MetaGraphDef, and
  that tensors recovered from collections compute the same result.
  """
  test_dir = _TestDir("no_variables")
  filename = os.path.join(test_dir, "metafile")
  input_feed_value = -10  # Arbitrary input value for feed_dict.
  orig_graph = tf.Graph()
  with self.test_session(graph=orig_graph) as sess:
    # Create a minimal graph with zero variables.
    input_tensor = tf.placeholder(tf.float32, shape=[], name="input")
    offset = tf.constant(42, dtype=tf.float32, name="offset")
    output_tensor = tf.add(input_tensor, offset, name="add_offset")
    # Add input and output tensors to graph collections.
    tf.add_to_collection("input_tensor", input_tensor)
    tf.add_to_collection("output_tensor", output_tensor)
    # -10 + 42 == 32.
    output_value = sess.run(output_tensor, {input_tensor: input_feed_value})
    self.assertEqual(output_value, 32)
    # Generates MetaGraphDef.
    #
    # Note that this is calling the saver *module-level* export_meta_graph and
    # not the Saver.export_meta_graph instance-level method.
    meta_graph_def = saver_module.export_meta_graph(
        filename=filename,
        graph_def=tf.get_default_graph().as_graph_def(),
        collection_list=["input_tensor", "output_tensor"],
        saver_def=None,
    )
  # Create a clean graph and import the MetaGraphDef nodes.
  new_graph = tf.Graph()
  with self.test_session(graph=new_graph) as sess:
    # Import the previously exported meta graph.
    saver_instance = saver_module.import_meta_graph(filename)
    # The saver instance should be None since there are no graph variables
    # to be restored in this case.
    self.assertIsNone(saver_instance)
    # Re-exports the current graph state for comparison to the original.
    new_meta_graph_def = saver_module.export_meta_graph(filename + "_new")
    self.assertProtoEquals(meta_graph_def, new_meta_graph_def)
    # Ensures that we can still get a reference to our graph collections.
    new_input_tensor = tf.get_collection("input_tensor")[0]
    new_output_tensor = tf.get_collection("output_tensor")[0]
    # Verifies that the new graph computes the same result as the original.
    new_output_value = sess.run(
        new_output_tensor, {new_input_tensor: input_feed_value})
    self.assertEqual(new_output_value, output_value)
开发者ID:2er0,项目名称:tensorflow,代码行数:51,代码来源:saver_test.py
示例11: testRestoreFromMetaGraph
def testRestoreFromMetaGraph(self):
  """Checks that a Supervisor can restore using a Saver from a meta graph.

  A first Supervisor saves a graph holding `v0 = 1`. In a new graph, the saved
  meta graph is imported and the resulting Saver is handed to a second
  Supervisor, which must recover `v0` and be able to save again.
  """
  logdir = self._test_dir("restore_from_meta_graph")
  with ops.Graph().as_default():
    variables.VariableV1(1, name="v0")
    sv = supervisor.Supervisor(logdir=logdir)
    sess = sv.prepare_or_wait_for_session("")
    filename = sv.saver.save(sess, sv.save_path)
    sv.stop()
  # Create a new Graph and Supervisor and recover.
  with ops.Graph().as_default():
    new_saver = saver_lib.import_meta_graph(filename + ".meta")
    self.assertIsNotNone(new_saver)
    sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver)
    sess = sv2.prepare_or_wait_for_session("")
    # assertEqual, not the deprecated assertEquals alias (removed in 3.12).
    self.assertEqual(1, sess.run("v0:0"))
    sv2.saver.save(sess, sv2.save_path)
    sv2.stop()
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:17,代码来源:supervisor_test.py
示例12: _testSaveRestoreUtility
def _testSaveRestoreUtility(self, start, break_range, stop):
  """Checks that the concatenated dataset's iterator survives save/restore.

  Reads elements [start, break_range) and saves a checkpoint. It then imports
  the saved meta graph into a fresh graph, restores the checkpoint, and reads
  elements [break_range, stop). After that the iterator must raise
  OutOfRangeError.

  Args:
    start: Index of the first element read before saving.
    break_range: Index at which the iterator state is saved and restored.
    stop: One past the index of the last element read; the iterator must be
      exhausted afterwards.
  """
  path = self._iterator_checkpoint_prefix()
  step = 0
  meta_filename = path + "-%d.meta" % step
  input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile(
      np.array([[12], [13], [14], [15]]), 4))
  to_concatenate_components = (np.tile(
      np.array([[5], [6], [7], [8], [9]]), 20), np.tile(
          np.array([[16], [17], [18], [19], [20]]), 15))
  # Element i is expected from `input_components` while i is below its row
  # count (4), and from `to_concatenate_components` after that.
  num_input_elements = len(input_components[0])

  def _check_element(i, result):
    # Asserts that `result` matches element `i` of the concatenated dataset.
    if i < num_input_elements:
      expected_components, offset = input_components, 0
    else:
      expected_components, offset = (to_concatenate_components,
                                     num_input_elements)
    for component, result_component in zip(expected_components, result):
      self.assertAllEqual(component[i - offset], result_component)

  with ops.Graph().as_default() as g:
    init_op, get_next = self._build_graph(input_components,
                                          to_concatenate_components)
    saver = saver_lib.Saver()
    with self.test_session(graph=g) as sess:
      sess.run(init_op)
      for i in range(start, break_range):
        _check_element(i, sess.run(get_next))
      saver.save(sess, path, step)

  with ops.Graph().as_default() as g:
    saver = saver_lib.import_meta_graph(meta_filename)
    with self.test_session(graph=g) as sess:
      get_next = nest.pack_sequence_as(("a", "b"),
                                       ops.get_collection("get_next"))
      saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
      for i in range(break_range, stop):
        _check_element(i, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:45,代码来源:concatenate_dataset_op_test.py
示例13: testSaveRestoreUsingSaverFromMetaGraph
def testSaveRestoreUsingSaverFromMetaGraph(self):
  """Checks that iterator state is restored by a Saver from import_meta_graph.

  The iterator is registered as a saveable object, so the Saver built while
  importing the meta graph picks it up. Restoring the checkpoint must then
  resume iteration at `break_point`.
  """

  def _build_graph(start, stop):
    # Builds a range(start, stop) iterator and records its initializer and
    # get_next ops in the "iterator_ops" collection so they can be found again
    # after import_meta_graph.
    iterator = dataset_ops.Dataset.range(start,
                                         stop).make_initializable_iterator()
    init_op = iterator.initializer
    get_next = iterator.get_next()
    ops.add_to_collection("iterator_ops", init_op)
    ops.add_to_collection("iterator_ops", get_next)
    saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator)
    # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection
    # so that it can be automatically picked up by the Saver.
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
    saver = saver_lib.Saver()
    return init_op, get_next, saver

  start = 2
  stop = 10
  break_point = 5
  path = self._iterator_checkpoint_prefix()
  meta_filename = path + ".meta"

  # Execute input pipeline for a few steps and save iterator state.
  with ops.Graph().as_default() as g:
    init_op, get_next, saver = _build_graph(start, stop)
    with self.test_session(graph=g) as sess:
      sess.run(variables.global_variables_initializer())
      sess.run(init_op)
      for i in range(start, break_point):
        self.assertEqual(i, sess.run(get_next))
      saver.save(sess, path)

  # Build the saver from the MetaGraph using import_meta_graph and
  # check that the iterator state is restored.
  with ops.Graph().as_default() as g:
    saver = saver_lib.import_meta_graph(meta_filename)
    init_op, get_next = ops.get_collection("iterator_ops")
    with self.test_session(graph=g) as sess:
      # NOTE: `init_op` is fetched but never run here; the values asserted
      # below come from the iterator state restored from the checkpoint.
      saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir()))
      for i in range(break_point, stop):
        self.assertEqual(i, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
开发者ID:SylChan,项目名称:tensorflow,代码行数:43,代码来源:range_dataset_op_test.py
示例14: freeze_graph_with_def_protos
def freeze_graph_with_def_protos(input_graph_def,
input_saver_def,
input_checkpoint,
output_node_names,
restore_op_name,
filename_tensor_name,
output_graph,
clear_devices,
initializer_nodes,
variable_names_whitelist="",
variable_names_blacklist="",
input_meta_graph_def=None,
input_saved_model_dir=None,
saved_model_tags=None,
checkpoint_version=saver_pb2.SaverDef.V2):
"""Converts all variables in a graph and checkpoint into constants."""
del restore_op_name, filename_tensor_name # Unused by updated loading code.
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
if (not input_saved_model_dir and
not checkpoint_management.checkpoint_exists(input_checkpoint)):
print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
return -1
if not output_node_names:
print("You need to supply the name of a node to --output_node_names.")
return -1
# Remove all the explicit device specifications for this node. This helps to
# make the graph more portable.
if clear_devices:
if input_meta_graph_def:
for node in input_meta_graph_def.graph_def.node:
node.device = ""
elif input_graph_def:
for node in input_graph_def.node:
node.device = ""
if input_graph_def:
_ = importer.import_graph_def(input_graph_def, name="")
with session.Session() as sess:
if input_saver_def:
saver = saver_lib.Saver(
saver_def=input_saver_def, write_version=checkpoint_version)
saver.restore(sess, input_checkpoint)
elif input_meta_graph_def:
restorer = saver_lib.import_meta_graph(
input_meta_graph_def, clear_devices=True)
restorer.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes.replace(" ", "").split(","))
elif input_saved_model_dir:
if saved_model_tags is None:
saved_model_tags = []
loader.load(sess, saved_model_tags, input_saved_model_dir)
else:
var_list = {}
reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
var_to_shape_map = reader.get_variable_to_shape_map()
# List of all partition variables. Because the condition is heuristic
# based, the list could include false positives.
all_parition_variable_names = [
tensor.name.split(":")[0]
for op in sess.graph.get_operations()
for tensor in op.values()
if re.search(r"/part_\d+/", tensor.name)
]
has_partition_var = False
for key in var_to_shape_map:
try:
tensor = sess.graph.get_tensor_by_name(key + ":0")
if any(key in name for name in all_parition_variable_names):
has_partition_var = True
except KeyError:
# This tensor doesn't exist in the graph (for example it's
# 'global_step' or a similar housekeeping element) so skip it.
continue
var_list[key] = tensor
try:
saver = saver_lib.Saver(
var_list=var_list, write_version=checkpoint_version)
except TypeError as e:
# `var_list` is required to be a map of variable names to Variable
# tensors. Partition variables are Identity tensors that cannot be
# handled by Saver.
if has_partition_var:
print("Models containing partition variables cannot be converted "
"from checkpoint files. Please pass in a SavedModel using "
"the flag --input_saved_model_dir.")
return -1
else:
raise e
saver.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes.replace(" ", "").split(","))
#.........这里部分代码省略.........
开发者ID:sonnyhu,项目名称:tensorflow,代码行数:101,代码来源:freeze_graph.py
示例15: _import_meta_graph
def _import_meta_graph(self):
  """Imports the meta graph stored next to `self._ckpt_path()`.

  Returns:
    Whatever `saver_lib.import_meta_graph` returns for that file: the Saver
    for the imported graph, or None when the graph has no variables.
  """
  checkpoint_prefix = self._ckpt_path()
  return saver_lib.import_meta_graph(checkpoint_prefix + ".meta")
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:3,代码来源:dataset_serialization_test_base.py
示例16: importer
def importer():
  """Imports the meta graph saved at `save_prefix` and returns 'output:0'.

  `save_prefix` is not defined here; it is read from the enclosing scope.
  """
  meta_path = save_prefix + '.meta'
  saver_lib.import_meta_graph(meta_path)
  default_graph = ops.get_default_graph()
  return default_graph.as_graph_element('output:0')
开发者ID:aritratony,项目名称:tensorflow,代码行数:3,代码来源:wrap_function_test.py
示例17: load_session_bundle_from_path
def load_session_bundle_from_path(export_dir,
target="",
config=None,
meta_graph_def=None):
"""Load session bundle from the given path.
The function reads input from the export_dir, constructs the graph data to the
default graph and restores the parameters for the session created.
Args:
export_dir: the directory that contains files exported by exporter.
target: The execution engine to connect to. See target in
tf.compat.v1.Session()
config: A ConfigProto proto with configuration options. See config in
tf.compat.v1.Session()
meta_graph_def: optional object of type MetaGraphDef. If this object is
present, then it is used instead of parsing MetaGraphDef from export_dir.
Returns:
session: a tensorflow session created from the variable files.
meta_graph: a meta graph proto saved in the exporter directory.
Raises:
RuntimeError: if the required files are missing or contain unrecognizable
fields, i.e. the exported model is invalid.
"""
if not meta_graph_def:
meta_graph_filename = os.path.join(export_dir,
constants.META_GRAPH_DEF_FILENAME)
if not file_io.file_exists(meta_graph_filename):
raise RuntimeError("Expected meta graph file missing %s" %
meta_graph_filename)
# Reads meta graph file.
meta_graph_def = meta_graph_pb2.MetaGraphDef()
meta_graph_def.ParseFromString(
file_io.read_file_to_string(meta_graph_filename, binary_mode=True))
variables_filename = ""
variables_filename_list = []
checkpoint_sharded = False
variables_index_filename = os.path.join(export_dir,
constants.VARIABLES_INDEX_FILENAME_V2)
checkpoint_v2 = file_io.file_exists(variables_index_filename)
# Find matching checkpoint files.
if checkpoint_v2:
# The checkpoint is in v2 format.
variables_filename_pattern = os.path.join(
export_dir, constants.VARIABLES_FILENAME_PATTERN_V2)
variables_filename_list = file_io.get_matching_files(
variables_filename_pattern)
checkpoint_sharded = True
else:
variables_filename = os.path.join(export_dir, constants.VARIABLES_FILENAME)
if file_io.file_exists(variables_filename):
variables_filename_list = [variables_filename]
else:
variables_filename = os.path.join(export_dir,
constants.VARIABLES_FILENAME_PATTERN)
variables_filename_list = file_io.get_matching_files(variables_filename)
checkpoint_sharded = True
# Prepare the files to restore a session.
if not variables_filename_list:
restore_files = ""
elif checkpoint_v2 or not checkpoint_sharded:
# For checkpoint v2 or v1 with non-sharded files, use "export" to restore
# the session.
restore_files = constants.VARIABLES_FILENAME
else:
restore_files = constants.VARIABLES_FILENAME_PATTERN
assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)
collection_def = meta_graph_def.collection_def
graph_def = graph_pb2.GraphDef()
if constants.GRAPH_KEY in collection_def:
# Use serving graph_def in MetaGraphDef collection_def if exists
graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
if len(graph_def_any) != 1:
raise RuntimeError("Expected exactly one serving GraphDef in : %s" %
meta_graph_def)
else:
graph_def_any[0].Unpack(graph_def)
# Replace the graph def in meta graph proto.
meta_graph_def.graph_def.CopyFrom(graph_def)
ops.reset_default_graph()
sess = session.Session(target, graph=None, config=config)
# Import the graph.
saver = saver_lib.import_meta_graph(meta_graph_def)
# Restore the session.
if restore_files:
saver.restore(sess, os.path.join(export_dir, restore_files))
init_op_tensor = None
if constants.INIT_OP_KEY in collection_def:
init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
if len(init_ops) != 1:
#.........这里部分代码省略.........
开发者ID:ahmedsaiduk,项目名称:tensorflow,代码行数:101,代码来源:session_bundle.py
示例18: freeze_graph_with_def_protos
def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Variables are restored by the first applicable route: a `SaverDef`, a
  `MetaGraphDef`, a SavedModel directory, or (fallback) a `Saver` built from
  the graph tensors whose names appear in the checkpoint.

  Args:
    input_graph_def: A `GraphDef`, imported into the default graph if given.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: Checkpoint path or prefix to restore from.
    output_node_names: Comma-separated output node names; spaces are ignored.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: Path to write the frozen `GraphDef` to (optional).
    clear_devices: Whether to strip device specifications from nodes.
    initializer_nodes: Comma-separated node names to run after restoring;
      spaces are ignored.
    variable_names_whitelist: Comma-separated variable names to convert
      (optional).
    variable_names_blacklist: Comma-separated variable names to leave as
      variables (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional).
    input_saved_model_dir: Path to a SavedModel directory (optional).
    saved_model_tags: Tags of the MetaGraphDef to load from the SavedModel.
    checkpoint_version: `saver_pb2.SaverDef.V1` or `saver_pb2.SaverDef.V2`.

  Returns:
    The frozen `GraphDef`. Returns -1 if the checkpoint does not exist or no
    output node names were supplied.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  def _split_names(names):
    # Comma-separated lists such as "a, b" may contain spaces; strip them so
    # each entry is a valid node/variable name (matches the other
    # freeze_graph versions, which strip spaces from initializer_nodes).
    return names.replace(" ", "").split(",")

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not saver_lib.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(saver_def=input_saver_def,
                              write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_names(initializer_nodes))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(var_list=var_list,
                              write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_names(initializer_nodes))

    variable_names_whitelist = (_split_names(variable_names_whitelist)
                                if variable_names_whitelist else None)
    variable_names_blacklist = (_split_names(variable_names_blacklist)
                                if variable_names_blacklist else None)

    # When a MetaGraphDef is given, its embedded GraphDef is the one frozen.
    graph_def_to_freeze = (input_meta_graph_def.graph_def
                           if input_meta_graph_def else input_graph_def)
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def_to_freeze,
        _split_names(output_node_names),
        variable_names_whitelist=variable_names_whitelist,
        variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:99,代码来源:freeze_graph.py
示例19: freeze_graph_with_def_protos
def freeze_graph_with_def_protos(input_graph_def,
input_saver_def,
input_checkpoint,
output_node_names,
restore_op_name,
filename_tensor_name,
output_graph,
clear_devices,
initializer_nodes,
variable_names_whitelist="",
variable_names_blacklist="",
input_meta_graph_def=None,
input_saved_model_dir=None,
saved_model_tags=None,
checkpoint_version=saver_pb2.SaverDef.V2):
"""Converts all variables in a graph and checkpoint into constants.
Args:
input_graph_def: A `GraphDef`.
input_saver_def: A `SaverDef` (optional).
input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
priority. Typically the result of `Saver.save()` or that of
`tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
V1/V2.
output_node_names: The name(s) of the output nodes, comma separated.
restore_op_name: Unused.
filename_tensor_name: Unused.
output_graph: String where to write the frozen `GraphDef`.
clear_devices: A Bool whether to remove device specifications.
initializer_nodes: Comma separated string of initializer nodes to run before
freezing.
variable_names_whitelist: The set of variable names to convert (optional, by
default, all variables are converted).
variable_names_blacklist: The set of variable names to omit converting
to constants (optional).
input_meta_graph_def: A `MetaGraphDef` (optional),
input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
and variables (optional).
saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
load, in string format (optional).
checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
or saver_pb2.SaverDef.V2)
Returns:
Location of the output_graph_def.
"""
del restore_op_name, filename_tensor_name # Unused by updated loading code.
# 'input_checkpoint' may be a prefix if we're using Saver V2 format
if (not input_saved_model_dir and
not checkpoint_management.checkpoint_exists(input_checkpoint)):
print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
return -1
if not output_node_names:
print("You need to supply the name of a node to --output_node_names.")
return -1
# Remove all the explicit device specifications for this node. This helps to
# make the graph more portable.
if clear_devices:
if input_meta_graph_def:
for node in input_meta_graph_def.graph_def.node:
node.device = ""
elif input_graph_def:
for node in input_graph_def.node:
node.device = ""
if input_graph_def:
_ = importer.import_graph_def(input_graph_def, name="")
with session.Session() as sess:
if input_saver_def:
saver = saver_lib.Saver(
saver_def=input_saver_def, write_version=checkpoint_version)
saver.restore(sess, input_checkpoint)
elif input_meta_graph_def:
restorer = saver_lib.import_meta_graph(
input_meta_graph_def, clear_devices=True)
restorer.restore(sess, input_checkpoint)
if initializer_nodes:
sess.run(initializer_nodes.replace(" ", "").split(","))
elif input_saved_model_dir:
if saved_model_tags is None:
saved_model_tags = []
loader.load(sess, saved_model_tags, input_saved_model_dir)
else:
var_list = {}
reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
var_to_shape_map = reader.get_variable_to_shape_map()
# List of all partition variables. Because the condition is heuristic
# based, the list could include false positives.
all_parition_variable_names = [
tensor.name.split(":")[0]
for op in sess.graph.get_operations()
for tensor in op.values()
if re.search(r"/part_\d+/", tensor.name)
]
has_partition_var = False
#.........这里部分代码省略.........
开发者ID:AnishShah,项目名称:tensorflow,代码行数:101,代码来源:freeze_graph.py
示例20: doBasicsOneExportPath
def doBasicsOneExportPath(self,
export_path,
clear_devices=False,
global_step=GLOBAL_STEP,
sharded=True,
export_count=1):
# Build a graph with 2 parameter nodes on different devices.
ops.reset_default_graph()
with session.Session(
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
# v2 is an unsaved variable derived from v0 and v1. It is used to
# exercise the ability to run an init op when restoring a graph.
with sess.graph.device("/cpu:0"):
v0 = variables.VariableV1(10, name="v0")
with sess.graph.device("/cpu:1"):
v1 = variables.VariableV1(20, name="v1")
v2 = variables.VariableV1(1, name="v2", trainable=False, collections=[])
assign_v2 = state_ops.assign(v2, math_ops.add(v0, v1))
init_op = control_flow_ops.group(assign_v2, name="init_op")
ops.add_to_collection("v", v0)
ops.add_to_collection("v", v1)
ops.add_to_collection("v", v2)
named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1}
signatures = {
"foo":
exporter.regression_signature(
input_tensor=v0, output_tensor=v1),
"generic":
exporter.generic_signature(named_tensor_bindings)
}
asset_filepath_orig = os.path.join(test.get_temp_dir(), "hello42.txt")
asset_file = constant_op.constant(asset_filepath_orig, name="filename42")
ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset_file)
with gfile.FastGFile(asset_filepath_orig, "w") as f:
f.write("your data here")
assets_col
|
请发表评论