本文整理汇总了 Python 中 tensorflow.python.framework.importer.import_graph_def 函数的典型用法代码示例。如果您想了解 import_graph_def 函数的具体用法、调用方式或使用示例，这里精选的函数代码示例或许能为您提供帮助。
在下文中一共展示了import_graph_def函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: main
def main(_):
  """Prints a cost report for the graph selected by the command-line flags.

  The input is either ``FLAGS.metagraphdef`` (a MetaGraphDef) or
  ``FLAGS.graphdef`` (a GraphDef). Files whose name ends in ``.pbtxt`` are
  parsed as text protos; all other files are parsed as binary protos.

  For a GraphDef input, the graph is imported into the default graph. The op
  named by ``FLAGS.fetch`` is added to the "train_op" collection, and the
  graph is then exported as a MetaGraphDef.

  If ``FLAGS.rewriter_config`` is set, it is parsed as a text
  RewriterConfig. The metagraph's graph_def is then replaced by the output of
  ``tf_optimizer.OptimizeGraph`` before the cost report is generated.
  """
  if FLAGS.metagraphdef:
    metagraph = meta_graph_pb2.MetaGraphDef()
    # Binary protos must be read in "rb": GFile's default text mode yields
    # str, which ParseFromString rejects under Python 3.
    if FLAGS.metagraphdef.endswith(".pbtxt"):
      with gfile.GFile(FLAGS.metagraphdef, "r") as meta_file:
        text_format.Merge(meta_file.read(), metagraph)
    else:
      with gfile.GFile(FLAGS.metagraphdef, "rb") as meta_file:
        metagraph.ParseFromString(meta_file.read())
  else:
    graph_def = graph_pb2.GraphDef()
    if FLAGS.graphdef.endswith(".pbtxt"):
      with gfile.GFile(FLAGS.graphdef, "r") as graph_file:
        text_format.Merge(graph_file.read(), graph_def)
    else:
      with gfile.GFile(FLAGS.graphdef, "rb") as graph_file:
        graph_def.ParseFromString(graph_file.read())
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    fetch = graph.get_operation_by_name(FLAGS.fetch)
    graph.add_to_collection("train_op", fetch)
    metagraph = saver.export_meta_graph(
        graph_def=graph.as_graph_def(), graph=graph)

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:27,代码来源:cost_analyzer_tool.py
示例2: testDefaultAttrsRemoved
def testDefaultAttrsRemoved(self):
  """Checks how import_graph_def treats attrs known only to the producer.

  ``producer_op_list`` declares ``default_int`` (default 456) on
  ``OpWithFutureDefaultAttr``. Per the inline comments, the attr exists only
  in the producer's op list:
  - an imported node whose value equals that default loses the attr;
  - a node with any other value keeps it.
  """
  producer_op_list = op_def_pb2.OpList()
  text_format.Merge("""
op {
name: 'OpWithFutureDefaultAttr'
attr { name: 'default_int' type: 'int' default_value { i: 456 } }
}
""", producer_op_list)
  # Attr only in producer_op_list with default value gets removed.
  with ops.Graph().as_default():
    a = importer.import_graph_def(
        self._MakeGraphDef("""
node { name: 'A' op: 'OpWithFutureDefaultAttr'
attr { key: 'default_int' value { i: 456 } } }
"""),
        return_elements=["A"],
        producer_op_list=producer_op_list)
    # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, "No attr named 'default_int'"):
      a[0].get_attr("default_int")
  # Attr only in producer_op_list with non-default value is preserved.
  with ops.Graph().as_default():
    a = importer.import_graph_def(
        self._MakeGraphDef("""
node { name: 'A' op: 'OpWithFutureDefaultAttr'
attr { key: 'default_int' value { i: 987 } } }
"""),
        return_elements=["A"],
        producer_op_list=producer_op_list)
    self.assertEqual(987, a[0].get_attr("default_int"))
开发者ID:pcm17,项目名称:tensorflow,代码行数:30,代码来源:importer_test.py
示例3: get_metagraph
def get_metagraph():
  """Constructs and returns a MetaGraphDef from the input file.

  The input is ``FLAGS.metagraphdef`` if set, otherwise ``FLAGS.graphdef``.
  Files ending in ``.pbtxt`` are parsed as text protos; all others are
  parsed as binary protos.

  ``FLAGS.fetch``, when not None, is a comma-separated list of op names.
  Surrounding whitespace is stripped and empty entries are ignored.
  - MetaGraphDef input: the names replace the "train_op" collection.
  - GraphDef input: the graph is imported into the default graph, the named
    ops are added to "train_op", and the graph is exported as a
    MetaGraphDef.

  Returns:
    The resulting MetaGraphDef.
  """
  fetch_names = []
  if FLAGS.fetch is not None:
    fetch_names = [name.strip() for name in FLAGS.fetch.split(",")]
    # Tolerate trailing commas / blank entries instead of looking up "".
    fetch_names = [name for name in fetch_names if name]

  if FLAGS.metagraphdef:
    metagraph = meta_graph_pb2.MetaGraphDef()
    # Binary protos must be read in "rb": GFile's default text mode yields
    # str, which ParseFromString rejects under Python 3.
    if FLAGS.metagraphdef.endswith(".pbtxt"):
      with gfile.GFile(FLAGS.metagraphdef, "r") as meta_file:
        text_format.Merge(meta_file.read(), metagraph)
    else:
      with gfile.GFile(FLAGS.metagraphdef, "rb") as meta_file:
        metagraph.ParseFromString(meta_file.read())
    if FLAGS.fetch is not None:
      fetch_collection = meta_graph_pb2.CollectionDef()
      fetch_collection.node_list.value.extend(fetch_names)
      metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
  else:
    graph_def = graph_pb2.GraphDef()
    if FLAGS.graphdef.endswith(".pbtxt"):
      with gfile.GFile(FLAGS.graphdef, "r") as graph_file:
        text_format.Merge(graph_file.read(), graph_def)
    else:
      with gfile.GFile(FLAGS.graphdef, "rb") as graph_file:
        graph_def.ParseFromString(graph_file.read())
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    # Previously FLAGS.fetch=None crashed here with AttributeError on split.
    for fetch in fetch_names:
      fetch_op = graph.get_operation_by_name(fetch)
      graph.add_to_collection("train_op", fetch_op)
    metagraph = saver.export_meta_graph(
        graph_def=graph.as_graph_def(), graph=graph)
  return metagraph
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:29,代码来源:cost_analyzer_tool.py
示例4: testWithDeviceFunctionDependingOnInputs
def testWithDeviceFunctionDependingOnInputs(self):
  """A device function sees each imported op's inputs during import.

  The source graph has an add, a subtract (two inputs each) and an identity
  (one input). Importing it under a device function that records two-input
  ops should record exactly two ops.
  """
  if ops._USE_C_API:
    # TODO(skyewm): make this work with C API
    # Report a skip rather than silently passing.
    self.skipTest("Not yet supported with the C API.")
  with ops.Graph().as_default() as g:
    with ops.device("/job:ps"):
      v1 = constant_op.constant(1.0)
      v2 = constant_op.constant(1.0)
    _ = v1 + v2
    _ = v1 - v2
    _ = array_ops.identity(v1)
  gdef = g.as_graph_def()
  # We'll use the following device function to observe ops with two inputs.
  ops_with_two_inputs = []

  def InputCounter(op):
    if len(op.inputs) == 2:
      ops_with_two_inputs.append(op)
    # Empty string: do not assign any device.
    return ""

  with ops.Graph().as_default():
    with ops.device(InputCounter):
      importer.import_graph_def(gdef)
  # We expect to see the add and subtract, but not identity.
  self.assertEqual(2, len(ops_with_two_inputs))
开发者ID:dansbecker,项目名称:tensorflow,代码行数:26,代码来源:importer_test.py
示例5: run_graph_def
def run_graph_def(graph_def, input_map, outputs):
  """Imports `graph_def` into a fresh graph and evaluates `outputs`.

  Args:
    graph_def: GraphDef to import (with no name prefix).
    input_map: despite its name, this is used as the session `feed_dict`, not
      as import_graph_def's `input_map` (the import uses an empty map).
    outputs: fetches passed to `Session.run`.

  Returns:
    Whatever `Session.run(outputs, feed_dict=input_map)` returns.
  """
  fresh_graph = ops_lib.Graph()
  with fresh_graph.as_default():
    importer.import_graph_def(graph_def, input_map={}, name="")
  with session.Session(graph=fresh_graph) as sess:
    return sess.run(outputs, feed_dict=input_map)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:7,代码来源:quantize_graph_test.py
示例6: testInvalidInputForInputMap
def testInvalidInputForInputMap(self):
  """Exercises import_graph_def's validation of the `input_map` argument.

  Three cases:
  1. A list instead of a dict raises TypeError with a fixed message.
  2. A non-Tensor value (a Variable) combined with name="" raises ValueError.
  3. A valid Tensor mapping replaces placeholder 'a', so the imported
     Identity evaluates to the mapped constant.
  """
  with ops.Graph().as_default():
    with self.assertRaises(TypeError) as type_ctx:
      importer.import_graph_def(
          self._MakeGraphDef(""), input_map=[constant_op.constant(5.0)])
    expected_type_msg = ("input_map must be a dictionary mapping strings to "
                         "Tensor objects.")
    self.assertEqual(expected_type_msg, str(type_ctx.exception))

  placeholder_then_identity = self._MakeGraphDef("""
node { name: 'a' op: 'Placeholder'
attr { key: 'dtype' value { type: DT_FLOAT } }}
node { name: 'id' op: 'Identity' input: 'a:0'
attr { key: 'T' value { type: DT_FLOAT } }}""")

  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as value_ctx:
      importer.import_graph_def(
          placeholder_then_identity,
          input_map={"a:0": variables.Variable(5.0)},
          name="")
    expected_prefix = ("tf.import_graph_def() requires a non-empty `name` "
                       "if `input_map` contains non-Tensor values.")
    self.assertStartsWith(str(value_ctx.exception), expected_prefix)

  with ops.Graph().as_default():
    [identity_out] = importer.import_graph_def(
        placeholder_then_identity,
        input_map={"a:0": constant_op.constant(5.0)},
        name="",
        return_elements=["id:0"])
    with self.test_session():
      self.assertEqual(5.0, identity_out.eval())
开发者ID:pcm17,项目名称:tensorflow,代码行数:29,代码来源:importer_test.py
示例7: testImportGraphWithFunctionTwice
def testImportGraphWithFunctionTwice(self):
  """A GraphDef containing a Defun can be imported twice side by side.

  The source graph computes z = Add2(x, y) through a Defun. It is imported
  under the "first" and "second" name scopes with the same input_map, and
  both imported "z:0" tensors must agree when evaluated in one run.
  """
  source_graph = ops.Graph()
  with source_graph.as_default():
    @function.Defun()
    def Add2(x, y):
      return math_ops.add(x, y)
    x_placeholder = array_ops.placeholder(dtype=dtypes.float32, name="x")
    y_placeholder = array_ops.placeholder(dtype=dtypes.float32, name="y")
    _ = Add2(x_placeholder, y_placeholder, name="z")  # pylint: disable=unexpected-keyword-arg
  gdef = source_graph.as_graph_def()

  # Both imports read from the same pair of random scalars.
  shared_inputs = {
      "x:0": random_ops.random_uniform(dtype=dtypes.float32, shape=()),
      "y:0": random_ops.random_uniform(dtype=dtypes.float32, shape=()),
  }
  imported_z = []
  for scope_name in ("first", "second"):
    with ops.name_scope(scope_name):
      imported_z.append(
          importer.import_graph_def(
              gdef, return_elements=["z:0"], input_map=shared_inputs)[0])
  z1, z2 = imported_z

  with self.test_session() as sess:
    z1_val, z2_val = sess.run((z1, z2))
    self.assertAllEqual(z1_val, z2_val)
开发者ID:clsung,项目名称:tensorflow,代码行数:29,代码来源:importer_test.py
示例8: testNamePrefixColocationAttrsMultipleImport
def testNamePrefixColocationAttrsMultipleImport(self):
  """Re-importing with name="" uniquifies names and colocation references.

  B is colocated with A via ``_class: 'loc:@A'``. After the second import,
  the copies are expected as A_1/B_1, with B_1's colocation attr rewritten
  to ``loc:@A_1``.
  """
  if ops._USE_C_API:
    # TODO(skyewm): set uniquify_names
    # Report a skip rather than silently passing.
    self.skipTest("Not yet supported with the C API.")
  original_graph_def = self._MakeGraphDef("""
node { name: 'A' op: 'None' }
node { name: 'B' op: 'None' attr {
key: '_class'
value { list { s: 'loc:@A' } }
} }""")
  with ops.Graph().as_default():
    b, = importer.import_graph_def(
        original_graph_def, return_elements=["B"], name="")
    _, = importer.import_graph_def(
        original_graph_def, return_elements=["B"], name="")
    self.assertProtoEqualsVersion("""
node { name: 'A' op: 'None' }
node { name: 'B' op: 'None' attr {
key: '_class'
value { list { s: 'loc:@A' } }
} }
node { name: 'A_1' op: 'None' }
node { name: 'B_1' op: 'None' attr {
key: '_class'
value { list { s: 'loc:@A_1' } }
} }""", b.graph.as_graph_def())
开发者ID:dansbecker,项目名称:tensorflow,代码行数:26,代码来源:importer_test.py
示例9: testMissingInputOpInGraphDef
def testMissingInputOpInGraphDef(self):
  """Importing a node whose data input names an absent node raises."""
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as e:
      importer.import_graph_def(
          self._MakeGraphDef("""
node { name: 'B' op: 'If' input: 'A:0' }
"""))
    # Checked outside the assertRaises context, after the exception is caught.
    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn("Input tensor 'A:0' not found", str(e.exception))
开发者ID:pcm17,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py
示例10: testInvalidTensorNameInGraphDef
def testInvalidTensorNameInGraphDef(self):
  """A malformed input tensor name ('A:B:0') is rejected as unknown."""
  with ops.Graph().as_default():
    # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError,
                                "Node 'B': Unknown input node 'A:B:0'"):
      importer.import_graph_def(
          self._MakeGraphDef("""
node { name: 'B' op: 'None' input: 'A:B:0' }
"""))
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py
示例11: testMissingControlInputInGraphDef
def testMissingControlInputInGraphDef(self):
  """A control input ('^A') naming an absent node is rejected."""
  with ops.Graph().as_default():
    # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError,
                                r"Node 'B': Unknown input node '\^A'"):
      importer.import_graph_def(
          self._MakeGraphDef("""
node { name: 'B' op: 'None' input: '^A' }
"""))
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py
示例12: testMissingInputOpInGraphDef
def testMissingInputOpInGraphDef(self):
  """A data input ('A:0') naming an absent node is rejected."""
  with ops.Graph().as_default():
    # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError,
                                "Node 'B': Unknown input node 'A:0'"):
      importer.import_graph_def(
          self._MakeGraphDef("""
node { name: 'B' op: 'FloatInput' input: 'A:0' }
"""))
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py
示例13: testVersionHigh
def testVersionHigh(self):
  """A GraphDef whose min_consumer exceeds GRAPH_DEF_VERSION is rejected."""
  with ops.Graph().as_default():
    # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(
        ValueError,
        r"GraphDef min consumer version %d above current version %d "
        r"for TensorFlow \S+\. Please upgrade TensorFlow\.$" %
        (1 << 30, versions.GRAPH_DEF_VERSION)):
      importer.import_graph_def(self._MakeGraphDef("", min_consumer=1 << 30))
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py
示例14: testVersionLow
def testVersionLow(self):
  """A GraphDef with producer version below the supported minimum fails."""
  with ops.Graph().as_default():
    # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
    # The final period is escaped (as in testVersionHigh) so it matches a
    # literal '.' rather than any character.
    with self.assertRaisesRegex(
        Exception,
        r"GraphDef producer version -1 below min producer %d supported "
        r"by TensorFlow \S+\. Please regenerate your graph\.$" %
        versions.GRAPH_DEF_VERSION_MIN_PRODUCER):
      importer.import_graph_def(self._MakeGraphDef("", producer=-1))
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py
示例15: testMissingControlInputInGraphDef
def testMissingControlInputInGraphDef(self):
  """Importing a node whose control input names an absent node raises."""
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as e:
      importer.import_graph_def(
          self._MakeGraphDef("""
node { name: 'B' op: 'None' input: '^A' }
"""))
    # Checked outside the assertRaises context, after the exception is caught.
    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn("Control input '^A' not found", str(e.exception))
开发者ID:pcm17,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py
示例16: testDuplicateOperationNames
def testDuplicateOperationNames(self):
  """A GraphDef containing two nodes named 'A' is rejected."""
  # Use a fresh graph, like the sibling tests, so nothing leaks into (or is
  # affected by) the ambient default graph.
  with ops.Graph().as_default():
    # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(ValueError, "Node 'A' is not unique"):
      importer.import_graph_def(
          self._MakeGraphDef("""
node { name: 'A' op: 'IntOutput' }
node { name: 'B' op: 'IntOutput' }
node { name: 'A' op: 'IntOutput' }
"""))
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py
示例17: SetProducerVersion
def SetProducerVersion(self, graph, producer_version):
  """Sets `graph`'s GraphDef producer version by importing an empty GraphDef.

  Args:
    graph: the Graph whose `graph_def_versions.producer` should change.
    producer_version: the producer version to apply.

  Raises:
    AssertionError (via assertEqual): if the graph's producer version does
      not end up equal to `producer_version`.
  """
  # The C API doesn't expose altering GraphDefVersions. We can indirectly set
  # it via import_graph_def though.
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.producer = producer_version
  with graph.as_default():
    importer.import_graph_def(graph_def)
  # The original `assert producer, producer_version` only checked truthiness
  # (producer_version was the assert *message*), so it never verified the
  # value; assertEqual also survives `python -O`.
  self.assertEqual(producer_version, graph.graph_def_versions.producer)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:nn_batchnorm_test.py
示例18: _TestTrtGraphConverter
def _TestTrtGraphConverter(self,
                           input_saved_model_dir=None,
                           output_saved_model_dir=None,
                           need_calibration=False,
                           is_dynamic_op=False):
  """General method to test trt_convert.TrtGraphConverter().

  Converts the graph via `self._ConvertGraph`. It verifies the resulting
  GraphDef, plus the one saved under `output_saved_model_dir` when given:
  - It must contain a "TRTEngineOp_0" node of op type TRTEngineOp. In graph
    mode the full node set must be exactly {input, TRTEngineOp_0, output}.
  - When `need_calibration` is set, every TRTEngineOp must carry non-empty
    calibration_data.
  - The calibrated graph must then compute (x + 1)**2 for inputs 0..9.
  """
  output_graph_def = self._ConvertGraph(
      input_saved_model_dir=input_saved_model_dir,
      output_saved_model_dir=output_saved_model_dir,
      need_calibration=need_calibration,
      is_dynamic_op=is_dynamic_op,
      use_function_backup=need_calibration)
  graph_defs_to_verify = [output_graph_def]
  if output_saved_model_dir:
    if context.executing_eagerly():
      root = load.load(output_saved_model_dir)
      saved_model_graph_def = root.signatures[
          signature_constants
          .DEFAULT_SERVING_SIGNATURE_DEF_KEY].graph.as_graph_def()
    else:
      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
          output_saved_model_dir, tag_constants.SERVING).graph_def
    # assertIsInstance reports the actual type on failure.
    self.assertIsInstance(saved_model_graph_def, graph_pb2.GraphDef)
    graph_defs_to_verify.append(saved_model_graph_def)
  for graph_def in graph_defs_to_verify:
    node_name_to_op = {node.name: node.op for node in graph_def.node}
    if context.executing_eagerly():
      # In V2 the actual graph could be inside a function.
      for func in graph_def.library.function:
        node_name_to_op.update({node.name: node.op for node in func.node_def})
      self.assertIn("TRTEngineOp_0", node_name_to_op)
      self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"])
    else:
      self.assertEqual({
          "input": "Placeholder",
          "TRTEngineOp_0": "TRTEngineOp",
          "output": "Identity"
      }, node_name_to_op)
    if need_calibration:
      trt_engine_nodes = [
          node for node in graph_def.node if node.op == "TRTEngineOp"
      ]
      self.assertNotEmpty(trt_engine_nodes)
      for node in trt_engine_nodes:
        self.assertNotEmpty(node.attr["calibration_data"].s)
      # Run the calibrated graph.
      # TODO(laigd): consider having some input where the answer is different.
      with ops.Graph().as_default():
        importer.import_graph_def(graph_def, name="")
        with self.session(config=self._GetConfigProto()) as sess:
          for test_data in range(10):
            self.assertEqual((test_data + 1.0)**2,
                             sess.run(
                                 "output:0",
                                 feed_dict={"input:0": [[[test_data]]]}))
开发者ID:perfmjs,项目名称:tensorflow,代码行数:58,代码来源:trt_convert_test.py
示例19: testMissingReturnOperation
def testMissingReturnOperation(self):
  """Requesting a return element absent from the GraphDef raises."""
  with ops.Graph().as_default():
    # assertRaisesRegexp is a deprecated alias removed in Python 3.12.
    with self.assertRaisesRegex(
        ValueError, "Requested return node 'B' not found in graph def"):
      importer.import_graph_def(
          self._MakeGraphDef("""
node { name: 'A' op: 'None' }
"""),
          return_elements=["B"])
开发者ID:clsung,项目名称:tensorflow,代码行数:9,代码来源:importer_test.py
示例20: testMissingInputMap
def testMissingInputMap(self):
  """An input_map key that matches no tensor in the GraphDef raises."""
  with ops.Graph().as_default():
    with self.assertRaises(ValueError) as e:
      importer.import_graph_def(
          self._MakeGraphDef("""
node { name: 'A' op: 'None' }
"""),
          input_map={"B:0": constant_op.constant(5.0)})
    # Checked outside the assertRaises context, after the exception is caught.
    # assertIn reports both strings on failure, unlike assertTrue(x in y).
    self.assertIn("not found in graph_def: [B:0]", str(e.exception))
开发者ID:pcm17,项目名称:tensorflow,代码行数:9,代码来源:importer_test.py
注：本文中的 tensorflow.python.framework.importer.import_graph_def 函数示例由纯净天空整理自 Github/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。
请发表评论