本文整理汇总了Python中tensorflow.python.tools.saved_model_utils.get_meta_graph_def函数的典型用法代码示例。如果您正苦于以下问题：Python get_meta_graph_def函数的具体用法？Python get_meta_graph_def怎么用？Python get_meta_graph_def使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_meta_graph_def函数的10个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _TestCreateInferenceGraph
def _TestCreateInferenceGraph(self,
                              input_saved_model_dir=None,
                              output_saved_model_dir=None):
  """General method to test trt_convert.create_inference_graph().

  Converts either the GraphDef returned by `self._GetGraphDef()` (when no
  input SavedModel directory is given) or the SavedModel at
  `input_saved_model_dir`. It then asserts that every produced GraphDef
  contains exactly the nodes input/Placeholder, TRTEngineOp_0/TRTEngineOp
  and output/Identity.

  Args:
    input_saved_model_dir: Optional SavedModel directory to convert. When it
      is set, no GraphDef is passed to the converter.
    output_saved_model_dir: Optional directory the converter writes a
      SavedModel to. When it is set, the GraphDef read back from it (SERVING
      tag) is verified as well.
  """
  # The converter is given either an in-memory GraphDef or a SavedModel dir.
  input_graph_def = None if input_saved_model_dir else self._GetGraphDef()
  output_graph_def = trt_convert.create_inference_graph(
      input_graph_def, ["output"],
      max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
      input_saved_model_dir=input_saved_model_dir,
      output_saved_model_dir=output_saved_model_dir,
      session_config=self._GetConfigProto())
  graph_defs_to_verify = [output_graph_def]
  if output_saved_model_dir is not None:
    saved_model_graph_def = saved_model_utils.get_meta_graph_def(
        output_saved_model_dir, tag_constants.SERVING).graph_def
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)), which only reports "False is not true".
    self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
    graph_defs_to_verify.append(saved_model_graph_def)
  expected_node_name_to_op = {
      "input": "Placeholder",
      "TRTEngineOp_0": "TRTEngineOp",
      "output": "Identity"
  }
  for graph_def in graph_defs_to_verify:
    node_name_to_op = {node.name: node.op for node in graph_def.node}
    self.assertEqual(expected_node_name_to_op, node_name_to_op)
开发者ID:ziky90,项目名称:tensorflow,代码行数:25,代码来源:trt_convert_test.py
示例2: _show_inputs_outputs
def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key):
  """Prints the input and output TensorInfos of one SignatureDef.

  Loads the MetaGraphDef selected by `tag_set`. It prints every input and
  then every output TensorInfo of the SignatureDef stored under
  `signature_def_key`, each group sorted by key. Finally it prints that
  SignatureDef's method name.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect.
    tag_set: Group of tag(s) of the MetaGraphDef, in string format, separated
      by ','. If the tag-set contains multiple tags, all of them must be
      passed in.
    signature_def_key: A SignatureDef key string.
  """
  meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir,
                                                        tag_set)
  # Both lookups happen before anything is printed.
  inputs_info = _get_inputs_tensor_info_from_meta_graph_def(
      meta_graph_def, signature_def_key)
  outputs_info = _get_outputs_tensor_info_from_meta_graph_def(
      meta_graph_def, signature_def_key)

  # (section header, per-entry line template, TensorInfo map)
  sections = (
      ('The given SavedModel SignatureDef contains the following input(s):',
       'inputs[\'%s\'] tensor_info:', inputs_info),
      ('The given SavedModel SignatureDef contains the following output(s):',
       'outputs[\'%s\'] tensor_info:', outputs_info),
  )
  for header, entry_template, tensor_infos in sections:
    print(header)
    for key, tensor_info in sorted(tensor_infos.items()):
      print(entry_template % key)
      _print_tensor_info(tensor_info)

  signature_def = meta_graph_def.signature_def[signature_def_key]
  print('Method name is: %s' % signature_def.method_name)
开发者ID:Lin-jipeng,项目名称:tensorflow,代码行数:31,代码来源:saved_model_cli.py
示例3: freeze_graph
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: A `GraphDef` file to load. It is ignored when
      `input_saved_model_dir` is given.
    input_saver: A TensorFlow Saver file (optional).
    input_binary: A Bool. True means input files are binary protos, False
      means text protos.
    input_checkpoint: The checkpoint prefix to restore variable values from.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Forwarded to `freeze_graph_with_def_protos`.
    filename_tensor_name: Forwarded to `freeze_graph_with_def_protos`.
    output_graph: Where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run
      before freezing.
    variable_names_whitelist: The set of variable names to convert
      (optional).
    variable_names_blacklist: The set of variable names to omit converting
      to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to a SavedModel directory (optional). When
      given, the graph is taken from its MetaGraphDef.
    saved_model_tags: Comma separated tag(s) of the MetaGraphDef to load.
      Spaces around the tags are ignored.

  Returns:
    Whatever `freeze_graph_with_def_protos` returns.
  """
  # Drop spaces so that e.g. "serve, gpu" yields ["serve", "gpu"] instead of
  # ["serve", " gpu"], which would not match the tags stored in the model.
  saved_model_tag_list = saved_model_tags.replace(" ", "").split(",")

  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, ",".join(saved_model_tag_list)).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)

  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)

  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)

  return freeze_graph_with_def_protos(
      input_graph_def, input_saver_def, input_checkpoint, output_node_names,
      restore_op_name, filename_tensor_name, output_graph, clear_devices,
      initializer_nodes, variable_names_whitelist, variable_names_blacklist,
      input_meta_graph_def, input_saved_model_dir, saved_model_tag_list)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:34,代码来源:freeze_graph.py
示例4: _TestTrtGraphConverter
def _TestTrtGraphConverter(self,
                           input_saved_model_dir=None,
                           output_saved_model_dir=None,
                           need_calibration=False,
                           is_dynamic_op=False):
  """General method to test trt_convert.TrtGraphConverter().

  Converts a graph with `self._ConvertGraph` and verifies the result. When
  `output_saved_model_dir` is given, the graph read back from it is verified
  too.
  - In eager mode, the SavedModel is loaded with `load.load` and the default
    serving signature's graph is used. Its function library nodes are
    included in the check, and only TRTEngineOp_0 is required.
  - Otherwise, the MetaGraphDef with the SERVING tag is read, and the graph
    must contain exactly input/TRTEngineOp_0/output.

  When `need_calibration` is set, every TRTEngineOp node must carry
  non-empty calibration data. The graph is then run on inputs 0..9,
  expecting (x + 1)**2.

  Args:
    input_saved_model_dir: Forwarded to `self._ConvertGraph`.
    output_saved_model_dir: Forwarded to `self._ConvertGraph`. It is also
      the location the converted SavedModel is read back from.
    need_calibration: Forwarded to `self._ConvertGraph` (also as
      `use_function_backup`). It enables the calibration checks.
    is_dynamic_op: Forwarded to `self._ConvertGraph`.
  """
  output_graph_def = self._ConvertGraph(
      input_saved_model_dir=input_saved_model_dir,
      output_saved_model_dir=output_saved_model_dir,
      need_calibration=need_calibration,
      is_dynamic_op=is_dynamic_op,
      use_function_backup=need_calibration)
  graph_defs_to_verify = [output_graph_def]
  if output_saved_model_dir:
    if context.executing_eagerly():
      root = load.load(output_saved_model_dir)
      saved_model_graph_def = root.signatures[
          signature_constants
          .DEFAULT_SERVING_SIGNATURE_DEF_KEY].graph.as_graph_def()
    else:
      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)).
    self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
    graph_defs_to_verify.append(saved_model_graph_def)

  for graph_def in graph_defs_to_verify:
    node_name_to_op = {node.name: node.op for node in graph_def.node}
    if context.executing_eagerly():
      # In V2 the actual graph could be inside a function.
      for func in graph_def.library.function:
        node_name_to_op.update({node.name: node.op for node in func.node_def})
      self.assertIn("TRTEngineOp_0", node_name_to_op)
      self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"])
    else:
      self.assertEqual({
          "input": "Placeholder",
          "TRTEngineOp_0": "TRTEngineOp",
          "output": "Identity"
      }, node_name_to_op)

    if need_calibration:
      trt_engine_nodes = [
          node for node in graph_def.node if node.op == "TRTEngineOp"
      ]
      self.assertNotEmpty(trt_engine_nodes)
      for node in trt_engine_nodes:
        self.assertTrue(len(node.attr["calibration_data"].s))
      # Run the calibrated graph.
      # TODO(laigd): consider having some input where the answer is different.
      with ops.Graph().as_default():
        importer.import_graph_def(graph_def, name="")
        with self.session(config=self._GetConfigProto()) as sess:
          for test_data in range(10):
            self.assertEqual((test_data + 1.0)**2,
                             sess.run(
                                 "output:0",
                                 feed_dict={"input:0": [[[test_data]]]}))
开发者ID:perfmjs,项目名称:tensorflow,代码行数:58,代码来源:trt_convert_test.py
示例5: scan
def scan(args):
  """Handles the `scan` command.

  If `args.tag_set` is given, only the MetaGraphDef matching that tag-set
  is scanned. Otherwise every MetaGraphDef in the SavedModel at `args.dir`
  is scanned.

  Args:
    args: A namespace parsed from command line.
  """
  if not args.tag_set:
    # No tag-set: walk all MetaGraphDefs stored in the SavedModel.
    for meta_graph_def in reader.read_saved_model(args.dir).meta_graphs:
      scan_meta_graph_def(meta_graph_def)
    return
  scan_meta_graph_def(
      saved_model_utils.get_meta_graph_def(args.dir, args.tag_set))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:13,代码来源:saved_model_cli.py
示例6: get_signature_def_map
def get_signature_def_map(saved_model_dir, tag_set):
  """Returns the SignatureDef map of one MetaGraphDef in a SavedModel.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect or
      execute.
    tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map,
      in string format, separated by ','. If the tag-set contains multiple
      tags, all of them must be passed in.

  Returns:
    The `signature_def` field of the matching MetaGraphDef, which maps
    string keys to SignatureDefs.
  """
  return saved_model_utils.get_meta_graph_def(
      saved_model_dir, tag_set).signature_def
开发者ID:DILASSS,项目名称:tensorflow,代码行数:17,代码来源:saved_model_cli.py
示例7: _TestTrtGraphConverter
def _TestTrtGraphConverter(self,
                           input_saved_model_dir=None,
                           output_saved_model_dir=None,
                           need_calibration=False,
                           is_dynamic_op=False):
  """General method to test trt_convert.TrtGraphConverter().

  Converts a graph with `self._ConvertGraph`. It then asserts that the
  result, plus the SERVING-tag graph read back from `output_saved_model_dir`
  when given, contains exactly input/TRTEngineOp_0/output. When
  `need_calibration` is set, every TRTEngineOp node must carry non-empty
  calibration data, and the graph is run on inputs 0..9, expecting
  (x + 1)**2.

  Args:
    input_saved_model_dir: Forwarded to `self._ConvertGraph`.
    output_saved_model_dir: Forwarded to `self._ConvertGraph`. It is also
      the location the converted SavedModel is read back from.
    need_calibration: Forwarded to `self._ConvertGraph` (also as
      `use_function_backup`). It enables the calibration checks.
    is_dynamic_op: Forwarded to `self._ConvertGraph`.
  """
  converted_graph_def = self._ConvertGraph(
      input_saved_model_dir=input_saved_model_dir,
      output_saved_model_dir=output_saved_model_dir,
      need_calibration=need_calibration,
      is_dynamic_op=is_dynamic_op,
      use_function_backup=need_calibration)
  candidates = [converted_graph_def]
  if output_saved_model_dir:
    serving_meta_graph = saved_model_utils.get_meta_graph_def(
        output_saved_model_dir, tag_constants.SERVING)
    reloaded_graph_def = serving_meta_graph.graph_def
    self.assertIsInstance(reloaded_graph_def, graph_pb2.GraphDef)
    candidates.append(reloaded_graph_def)

  expected_ops = {
      "input": "Placeholder",
      "TRTEngineOp_0": "TRTEngineOp",
      "output": "Identity"
  }
  for candidate in candidates:
    actual_ops = {node.name: node.op for node in candidate.node}
    self.assertEqual(expected_ops, actual_ops)
    if not need_calibration:
      continue

    engine_nodes = [
        node for node in candidate.node if node.op == "TRTEngineOp"
    ]
    self.assertNotEmpty(engine_nodes)
    for engine_node in engine_nodes:
      self.assertTrue(len(engine_node.attr["calibration_data"].s))

    # Run the calibrated graph.
    # TODO(laigd): consider having some input where the answer is different.
    with ops.Graph().as_default():
      importer.import_graph_def(candidate, name="")
      with self.session(config=self._GetConfigProto()) as sess:
        for value in range(10):
          expected_value = (value + 1.0)**2
          actual_value = sess.run(
              "output:0", feed_dict={"input:0": [[[value]]]})
          self.assertEqual(expected_value, actual_value)
开发者ID:aritratony,项目名称:tensorflow,代码行数:45,代码来源:trt_convert_test.py
示例8: get_meta_graph_def
def get_meta_graph_def(saved_model_dir, tag_set):
  """DEPRECATED: Use saved_model_utils.get_meta_graph_def instead.

  Returns the MetaGraphDef stored under the given tag-set in the SavedModel
  directory. This wrapper only forwards its arguments.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect or
      execute.
    tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
      separated by ','. If the tag-set contains multiple tags, all of them
      must be passed in.

  Raises:
    RuntimeError: An error when the given tag-set does not exist in the
      SavedModel.

  Returns:
    A MetaGraphDef corresponding to the tag-set.
  """
  # Kept only for backward compatibility with existing callers.
  delegate = saved_model_utils.get_meta_graph_def
  return delegate(saved_model_dir, tag_set)
开发者ID:DILASSS,项目名称:tensorflow,代码行数:20,代码来源:saved_model_cli.py
示例9: run_saved_model_with_feed_dict
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
                                   input_tensor_key_feed_dict, outdir,
                                   overwrite_flag, tf_debug=False):
  """Runs a SavedModel and fetches all outputs.

  Feeds the input dictionary through the MetaGraphDef selected by `tag_set`
  and `signature_def_key`, then prints every output. When `outdir` is set,
  each output is also saved as `<outdir>/<output key>.npy`.

  Args:
    saved_model_dir: Directory containing the SavedModel to execute.
    tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map,
      in string format, separated by ','. If the tag-set contains multiple
      tags, all of them must be passed in.
    signature_def_key: A SignatureDef key string.
    input_tensor_key_feed_dict: A dictionary that maps input keys to numpy
      ndarrays.
    outdir: A directory to save the outputs to. If the directory doesn't
      exist, it will be created.
    overwrite_flag: A boolean flag that allows overwriting an output file
      when a file with the same name already exists.
    tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe
      the intermediate Tensor values and runtime GraphDefs while running the
      SavedModel.

  Raises:
    ValueError: When any of the input tensor keys is not valid.
    RuntimeError: An error when output file already exists and overwrite is
      not enabled.
  """
  meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir,
                                                        tag_set)

  # Re-create feed_dict keyed by tensor name instead of signature key, since
  # session.run expects tensor names.
  inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
      meta_graph_def, signature_def_key)

  # Check if input tensor keys are valid.
  for input_key_name in input_tensor_key_feed_dict:
    if input_key_name not in inputs_tensor_info:
      raise ValueError(
          '"%s" is not a valid input key. Please choose from %s, or use '
          '--show option.' %
          (input_key_name, '"' + '", "'.join(inputs_tensor_info.keys()) + '"'))

  inputs_feed_dict = {
      inputs_tensor_info[key].name: tensor
      for key, tensor in input_tensor_key_feed_dict.items()
  }

  outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
      meta_graph_def, signature_def_key)
  # Sort so that fetched values can be mapped back to their keys in order.
  output_tensor_keys_sorted = sorted(outputs_tensor_info.keys())
  output_tensor_names_sorted = [
      outputs_tensor_info[tensor_key].name
      for tensor_key in output_tensor_keys_sorted
  ]

  with session.Session(graph=ops_lib.Graph()) as sess:
    loader.load(sess, tag_set.split(','), saved_model_dir)

    if tf_debug:
      sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess)

    outputs = sess.run(output_tensor_names_sorted, feed_dict=inputs_feed_dict)

    # Create the output directory once, before the loop, rather than
    # re-checking it for every output. Skip it when nothing will be saved.
    if outdir and outputs and not os.path.isdir(outdir):
      os.makedirs(outdir)

    for output_tensor_key, output in zip(output_tensor_keys_sorted, outputs):
      print('Result for output key %s:\n%s' % (output_tensor_key, output))

      # Only save if outdir is specified.
      if outdir:
        output_full_path = os.path.join(outdir, output_tensor_key + '.npy')

        # If overwrite not enabled and file already exist, error out
        if not overwrite_flag and os.path.exists(output_full_path):
          raise RuntimeError(
              'Output file %s already exists. Add \"--overwrite\" to overwrite'
              ' the existing output files.' % output_full_path)

        np.save(output_full_path, output)
        print('Output %s is saved to %s' % (output_tensor_key,
                                            output_full_path))
开发者ID:DILASSS,项目名称:tensorflow,代码行数:88,代码来源:saved_model_cli.py
示例10: freeze_graph
def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: A `GraphDef` file to load. It is ignored when
      `input_saved_model_dir` is given.
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means input_graph is .pb, False indicates
      .pbtxt.
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority. Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run
      before freezing.
    variable_names_whitelist: The set of variable names to convert (optional,
      by default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
      to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
      and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
      load, in string format. Spaces around the tags are ignored.
    checkpoint_version: Tensorflow variable file format
      (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2).

  Returns:
    Whatever `freeze_graph_with_def_protos` returns. The original contract
    describes this as the location of the frozen GraphDef.
    NOTE(review): confirm against `freeze_graph_with_def_protos`.
  """
  # Normalize once so the MetaGraphDef lookup and the freeze step use the
  # same tags. For example, "serve, gpu" -> ["serve", "gpu"].
  saved_model_tag_list = saved_model_tags.replace(" ", "").split(",")

  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, ",".join(saved_model_tag_list)).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)

  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)

  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)

  # Return the result so the documented return value actually reaches the
  # caller; previously the function always returned None.
  return freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_blacklist,
      input_meta_graph_def,
      input_saved_model_dir,
      saved_model_tag_list,
      checkpoint_version=checkpoint_version)
开发者ID:AnishShah,项目名称:tensorflow,代码行数:76,代码来源:freeze_graph.py
注：本文中的tensorflow.python.tools.saved_model_utils.get_meta_graph_def函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论