• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python importer.import_graph_def函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了 Python 中 tensorflow.python.framework.importer.import_graph_def 函数的典型用法代码示例。如果您正苦于以下问题:Python import_graph_def 函数的具体用法?Python import_graph_def 怎么用?Python import_graph_def 使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了 import_graph_def 函数的 20 个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。

示例1: main

def main(_):
  """Builds a MetaGraphDef from the command-line flags and prints its cost.

  The MetaGraphDef is either parsed directly from `FLAGS.metagraphdef`, or
  derived from the GraphDef in `FLAGS.graphdef` (text format when the path
  ends in ".pbtxt", binary otherwise) with the op named by `FLAGS.fetch`
  registered in the "train_op" collection. If `FLAGS.rewriter_config` is set,
  the graph is optimized with it before the report is generated.

  Args:
    _: Unused positional argument supplied by the app runner.
  """
  if FLAGS.metagraphdef:
    # Serialized protos are raw bytes: read them in binary mode so that
    # ParseFromString receives bytes rather than decoded text.
    with gfile.GFile(FLAGS.metagraphdef, "rb") as meta_file:
      metagraph = meta_graph_pb2.MetaGraphDef()
      metagraph.ParseFromString(meta_file.read())
  else:
    is_text_proto = FLAGS.graphdef.endswith(".pbtxt")
    # Text protos are read as text; binary protos must be read as bytes.
    read_mode = "r" if is_text_proto else "rb"
    with gfile.GFile(FLAGS.graphdef, read_mode) as graph_file:
      graph_def = graph_pb2.GraphDef()
      if is_text_proto:
        text_format.Merge(graph_file.read(), graph_def)
      else:
        graph_def.ParseFromString(graph_file.read())
      importer.import_graph_def(graph_def, name="")
      graph = ops.get_default_graph()
      fetch = graph.get_operation_by_name(FLAGS.fetch)
      graph.add_to_collection("train_op", fetch)
      metagraph = saver.export_meta_graph(
          graph_def=graph.as_graph_def(), graph=graph)

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    # Replace the graph in place so the report covers the optimized graph.
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:27,代码来源:cost_analyzer_tool.py


示例2: testDefaultAttrsRemoved

  def testDefaultAttrsRemoved(self):
    """Checks import of an attr that appears only in the producer op list.

    When the node carries the producer's default value the attr is expected
    to be absent after import; a non-default value is expected to survive.
    """
    producer_ops = op_def_pb2.OpList()
    text_format.Merge("""
      op {
        name: 'OpWithFutureDefaultAttr'
        attr { name: 'default_int' type: 'int' default_value { i: 456 } }
      }
    """, producer_ops)

    # Case 1: the node uses the default value, so the attr gets removed.
    with ops.Graph().as_default():
      default_valued_def = self._MakeGraphDef("""
          node { name: 'A' op: 'OpWithFutureDefaultAttr'
                 attr { key: 'default_int' value { i: 456 } } }
          """)
      imported = importer.import_graph_def(
          default_valued_def,
          return_elements=["A"],
          producer_op_list=producer_ops)
      with self.assertRaisesRegexp(ValueError, "No attr named 'default_int'"):
        imported[0].get_attr("default_int")

    # Case 2: the node uses a non-default value, so the attr is preserved.
    with ops.Graph().as_default():
      custom_valued_def = self._MakeGraphDef("""
          node { name: 'A' op: 'OpWithFutureDefaultAttr'
                 attr { key: 'default_int' value { i: 987 } } }
          """)
      imported = importer.import_graph_def(
          custom_valued_def,
          return_elements=["A"],
          producer_op_list=producer_ops)
      self.assertEqual(987, imported[0].get_attr("default_int"))
开发者ID:pcm17,项目名称:tensorflow,代码行数:30,代码来源:importer_test.py


示例3: get_metagraph

def get_metagraph():
  """Constructs and returns a MetaGraphDef from the input file.

  If `FLAGS.metagraphdef` is set, the MetaGraphDef is read from that file
  (text format when the path ends in ".pbtxt", binary otherwise) and, when
  `FLAGS.fetch` is given, its "train_op" collection is replaced by the
  comma-separated fetch names. Otherwise the GraphDef in `FLAGS.graphdef` is
  imported into the default graph, every fetch op is added to the "train_op"
  collection, and a MetaGraphDef is exported from that graph.

  Returns:
    The resulting MetaGraphDef.

  Raises:
    ValueError: If a GraphDef is used as input but `FLAGS.fetch` is unset.
  """
  if FLAGS.metagraphdef:
    is_text_proto = FLAGS.metagraphdef.endswith(".pbtxt")
    # Binary protos must be read as bytes for ParseFromString.
    read_mode = "r" if is_text_proto else "rb"
    with gfile.GFile(FLAGS.metagraphdef, read_mode) as meta_file:
      metagraph = meta_graph_pb2.MetaGraphDef()
      if is_text_proto:
        text_format.Merge(meta_file.read(), metagraph)
      else:
        metagraph.ParseFromString(meta_file.read())
    if FLAGS.fetch is not None:
      fetch_collection = meta_graph_pb2.CollectionDef()
      for fetch in FLAGS.fetch.split(","):
        fetch_collection.node_list.value.append(fetch)
      metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
  else:
    # Fail with an explicit message instead of an AttributeError on
    # `None.split` further down.
    if FLAGS.fetch is None:
      raise ValueError(
          "FLAGS.fetch must be set when the input is FLAGS.graphdef.")
    is_text_proto = FLAGS.graphdef.endswith(".pbtxt")
    read_mode = "r" if is_text_proto else "rb"
    with gfile.GFile(FLAGS.graphdef, read_mode) as graph_file:
      graph_def = graph_pb2.GraphDef()
      if is_text_proto:
        text_format.Merge(graph_file.read(), graph_def)
      else:
        graph_def.ParseFromString(graph_file.read())
      importer.import_graph_def(graph_def, name="")
      graph = ops.get_default_graph()
      for fetch in FLAGS.fetch.split(","):
        fetch_op = graph.get_operation_by_name(fetch)
        graph.add_to_collection("train_op", fetch_op)
      metagraph = saver.export_meta_graph(
          graph_def=graph.as_graph_def(), graph=graph)
  return metagraph
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:29,代码来源:cost_analyzer_tool.py


示例4: testWithDeviceFunctionDependingOnInputs

  def testWithDeviceFunctionDependingOnInputs(self):
    """Checks that a device function can inspect op inputs during import."""
    # TODO(skyewm): make this work with C API
    if ops._USE_C_API:
      return

    # Build a small graph: two constants, an add, a subtract and an identity.
    with ops.Graph().as_default() as source_graph:
      with ops.device("/job:ps"):
        lhs = constant_op.constant(1.0)
        rhs = constant_op.constant(1.0)
      _ = lhs + rhs
      _ = lhs - rhs
      _ = array_ops.identity(lhs)
    gdef = source_graph.as_graph_def()

    # The device function below records every op it sees with two inputs.
    two_input_ops = []

    def InputCounter(op):
      if len(op.inputs) == 2:
        two_input_ops.append(op)
      return ""

    with ops.Graph().as_default():
      with ops.device(InputCounter):
        importer.import_graph_def(gdef)

    # Only the add and the subtract have two inputs; identity has one.
    self.assertEqual(2, len(two_input_ops))
开发者ID:dansbecker,项目名称:tensorflow,代码行数:26,代码来源:importer_test.py


示例5: run_graph_def

def run_graph_def(graph_def, input_map, outputs):
  """Imports `graph_def` into a fresh graph and evaluates `outputs`.

  Args:
    graph_def: GraphDef to import, using an empty name prefix.
    input_map: Values handed to the session run as `feed_dict`. The import
      itself is always given an empty input map.
    outputs: Fetches passed to the session run.

  Returns:
    Whatever the session run returns for `outputs`.
  """
  fresh_graph = ops_lib.Graph()
  with fresh_graph.as_default():
    importer.import_graph_def(graph_def, input_map={}, name="")
  with session.Session(graph=fresh_graph) as sess:
    return sess.run(outputs, feed_dict=input_map)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:7,代码来源:quantize_graph_test.py


示例6: testInvalidInputForInputMap

 def testInvalidInputForInputMap(self):
   """Exercises rejected and accepted `input_map` arguments."""
   # A list instead of a dict is rejected with a TypeError.
   with ops.Graph().as_default():
     with self.assertRaises(TypeError) as type_error:
       importer.import_graph_def(
           self._MakeGraphDef(""), input_map=[constant_op.constant(5.0)])
     self.assertEqual("input_map must be a dictionary mapping strings to "
                      "Tensor objects.", str(type_error.exception))

   graph_def = self._MakeGraphDef("""
        node { name: 'a' op: 'Placeholder'
               attr { key: 'dtype' value { type: DT_FLOAT } }}
        node { name: 'id' op: 'Identity' input: 'a:0'
               attr { key: 'T' value { type: DT_FLOAT } }}""")

   # A Variable value combined with an empty `name` is rejected.
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as value_error:
       variable_map = {"a:0": variables.Variable(5.0)}
       importer.import_graph_def(graph_def, input_map=variable_map, name="")
     self.assertStartsWith(str(value_error.exception),
                           "tf.import_graph_def() requires a non-empty `name` "
                           "if `input_map` contains non-Tensor values.")

   # A constant tensor is accepted and flows through the identity node.
   with ops.Graph().as_default():
     constant_map = {"a:0": constant_op.constant(5.0)}
     (identity_out,) = importer.import_graph_def(
         graph_def,
         input_map=constant_map,
         name="",
         return_elements=["id:0"])
     with self.test_session():
       self.assertEqual(5.0, identity_out.eval())
开发者ID:pcm17,项目名称:tensorflow,代码行数:29,代码来源:importer_test.py


示例7: testImportGraphWithFunctionTwice

  def testImportGraphWithFunctionTwice(self):
    """Imports the same function-bearing GraphDef twice and compares outputs."""
    template_graph = ops.Graph()
    with template_graph.as_default():

      @function.Defun()
      def Add2(x, y):
        return math_ops.add(x, y)

      placeholder_x = array_ops.placeholder(dtype=dtypes.float32, name="x")
      placeholder_y = array_ops.placeholder(dtype=dtypes.float32, name="y")
      # pylint: disable=unexpected-keyword-arg
      _ = Add2(placeholder_x, placeholder_y, name="z")
      # pylint: enable=unexpected-keyword-arg

    gdef = template_graph.as_graph_def()

    # Both imports are fed the same random tensors, so their results agree.
    random_x = random_ops.random_uniform(dtype=dtypes.float32, shape=())
    random_y = random_ops.random_uniform(dtype=dtypes.float32, shape=())
    input_map = {"x:0": random_x, "y:0": random_y}

    with ops.name_scope("first"):
      first_outputs = importer.import_graph_def(
          gdef, return_elements=["z:0"], input_map=input_map)
      z1 = first_outputs[0]

    with ops.name_scope("second"):
      second_outputs = importer.import_graph_def(
          gdef, return_elements=["z:0"], input_map=input_map)
      z2 = second_outputs[0]

    with self.test_session() as sess:
      z1_val, z2_val = sess.run((z1, z2))
      self.assertAllEqual(z1_val, z2_val)
开发者ID:clsung,项目名称:tensorflow,代码行数:29,代码来源:importer_test.py


示例8: testNamePrefixColocationAttrsMultipleImport

  def testNamePrefixColocationAttrsMultipleImport(self):
    """Imports one GraphDef twice with an empty prefix and checks the result.

    The second import is expected to yield renamed nodes ('A_1', 'B_1') whose
    colocation attr points at the renamed node.
    """
    # TODO(skyewm): set uniquify_names
    if ops._USE_C_API:
      return

    original_graph_def = self._MakeGraphDef("""
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }""")

    with ops.Graph().as_default():
      (b,) = importer.import_graph_def(
          original_graph_def, return_elements=["B"], name="")
      (_,) = importer.import_graph_def(
          original_graph_def, return_elements=["B"], name="")
      expected_graph = """
          node { name: 'A' op: 'None' }
          node { name: 'B' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A' } }
          } }
          node { name: 'A_1' op: 'None' }
          node { name: 'B_1' op: 'None'  attr {
            key: '_class'
            value { list { s: 'loc:@A_1' } }
          } }"""
      self.assertProtoEqualsVersion(expected_graph, b.graph.as_graph_def())
开发者ID:dansbecker,项目名称:tensorflow,代码行数:26,代码来源:importer_test.py


示例9: testMissingInputOpInGraphDef

 def testMissingInputOpInGraphDef(self):
   """Importing a node whose data input does not exist raises ValueError."""
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as e:
       importer.import_graph_def(
           self._MakeGraphDef("""
           node { name: 'B' op: 'If' input: 'A:0' }
           """))
     # assertIn shows both operands on failure, unlike assertTrue(x in y).
     self.assertIn("Input tensor 'A:0' not found", str(e.exception))
开发者ID:pcm17,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py


示例10: testInvalidTensorNameInGraphDef

 def testInvalidTensorNameInGraphDef(self):
   """Importing a node with input 'A:B:0' must raise ValueError."""
   expected_error = "Node 'B': Unknown input node 'A:B:0'"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       malformed_def = self._MakeGraphDef("""
           node { name: 'B' op: 'None' input: 'A:B:0' }
           """)
       importer.import_graph_def(malformed_def)
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py


示例11: testMissingControlInputInGraphDef

 def testMissingControlInputInGraphDef(self):
   """Importing a node with an undefined control input raises ValueError."""
   # The caret is escaped because the expected message is used as a regex.
   expected_error = r"Node 'B': Unknown input node '\^A'"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       dangling_control_def = self._MakeGraphDef("""
           node { name: 'B' op: 'None' input: '^A' }
           """)
       importer.import_graph_def(dangling_control_def)
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py


示例12: testMissingInputOpInGraphDef

 def testMissingInputOpInGraphDef(self):
   """Importing a node whose data input does not exist raises ValueError."""
   expected_error = "Node 'B': Unknown input node 'A:0'"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       dangling_input_def = self._MakeGraphDef("""
           node { name: 'B' op: 'FloatInput' input: 'A:0' }
           """)
       importer.import_graph_def(dangling_input_def)
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py


示例13: testVersionHigh

 def testVersionHigh(self):
   """A GraphDef requiring a newer consumer version must be rejected."""
   with ops.Graph().as_default():
     expected_error = (
         r"GraphDef min consumer version %d above current version %d "
         r"for TensorFlow \S+\.  Please upgrade TensorFlow\.$" %
         (1 << 30, versions.GRAPH_DEF_VERSION))
     with self.assertRaisesRegexp(ValueError, expected_error):
       too_new_def = self._MakeGraphDef("", min_consumer=1 << 30)
       importer.import_graph_def(too_new_def)
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py


示例14: testVersionLow

 def testVersionLow(self):
   """A GraphDef with producer version -1 must be rejected."""
   with ops.Graph().as_default():
     expected_error = (
         r"GraphDef producer version -1 below min producer %d supported "
         r"by TensorFlow \S+\.  Please regenerate your graph.$" %
         versions.GRAPH_DEF_VERSION_MIN_PRODUCER)
     with self.assertRaisesRegexp(Exception, expected_error):
       too_old_def = self._MakeGraphDef("", producer=-1)
       importer.import_graph_def(too_old_def)
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py


示例15: testMissingControlInputInGraphDef

 def testMissingControlInputInGraphDef(self):
   """Importing a node with an undefined control input raises ValueError."""
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as e:
       importer.import_graph_def(
           self._MakeGraphDef("""
           node { name: 'B' op: 'None' input: '^A' }
           """))
     # assertIn shows both operands on failure, unlike assertTrue(x in y).
     self.assertIn("Control input '^A' not found", str(e.exception))
开发者ID:pcm17,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py


示例16: testDuplicateOperationNames

 def testDuplicateOperationNames(self):
   """A GraphDef that defines node 'A' twice must be rejected."""
   expected_error = "Node 'A' is not unique"
   with self.assertRaisesRegexp(ValueError, expected_error):
     duplicated_def = self._MakeGraphDef("""
         node { name: 'A' op: 'IntOutput' }
         node { name: 'B' op: 'IntOutput' }
         node { name: 'A' op: 'IntOutput' }
         """)
     importer.import_graph_def(duplicated_def)
开发者ID:clsung,项目名称:tensorflow,代码行数:8,代码来源:importer_test.py


示例17: SetProducerVersion

 def SetProducerVersion(self, graph, producer_version):
   """Sets the GraphDef producer version recorded on `graph`.

   The C API doesn't expose altering GraphDefVersions, so the version is set
   indirectly by importing an otherwise empty GraphDef that carries it.

   Args:
     graph: Graph to import the versioned, empty GraphDef into.
     producer_version: Producer version to place on the imported GraphDef.

   Raises:
     AssertionError: If the graph does not report `producer_version` after
       the import.
   """
   graph_def = graph_pb2.GraphDef()
   graph_def.versions.producer = producer_version
   with graph.as_default():
     importer.import_graph_def(graph_def)
   # Compare explicitly: `assert value, other` treats `other` as the failure
   # message and only checks that `value` is truthy.
   actual_version = graph.graph_def_versions.producer
   assert actual_version == producer_version, (
       "Expected producer version %r, got %r" %
       (producer_version, actual_version))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:nn_batchnorm_test.py


示例18: _TestTrtGraphConverter

  def _TestTrtGraphConverter(self,
                             input_saved_model_dir=None,
                             output_saved_model_dir=None,
                             need_calibration=False,
                             is_dynamic_op=False):
    """General method to test trt_convert.TrtGraphConverter().

    Converts a graph through `self._ConvertGraph`, then verifies the node
    layout of the returned GraphDef and, when `output_saved_model_dir` is
    given, of the GraphDef stored in the saved model. With `need_calibration`
    the calibrated graph is also executed on ten inputs.

    Args:
      input_saved_model_dir: Forwarded to `self._ConvertGraph`.
      output_saved_model_dir: Forwarded to `self._ConvertGraph`; when truthy
        the saved model written there is loaded and verified as well.
      need_calibration: Forwarded to `self._ConvertGraph` (also as
        `use_function_backup`); enables the calibration checks.
      is_dynamic_op: Forwarded to `self._ConvertGraph`.
    """
    converted_graph_def = self._ConvertGraph(
        input_saved_model_dir=input_saved_model_dir,
        output_saved_model_dir=output_saved_model_dir,
        need_calibration=need_calibration,
        is_dynamic_op=is_dynamic_op,
        use_function_backup=need_calibration)
    graph_defs_to_verify = [converted_graph_def]

    if output_saved_model_dir:
      if context.executing_eagerly():
        root = load.load(output_saved_model_dir)
        serving_signature = root.signatures[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        saved_model_graph_def = serving_signature.graph.as_graph_def()
      else:
        meta_graph_def = saved_model_utils.get_meta_graph_def(
            output_saved_model_dir, tag_constants.SERVING)
        saved_model_graph_def = meta_graph_def.graph_def
      self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef))
      graph_defs_to_verify.append(saved_model_graph_def)

    for graph_def in graph_defs_to_verify:
      node_name_to_op = {}
      for node in graph_def.node:
        node_name_to_op[node.name] = node.op

      if context.executing_eagerly():
        # In V2 the actual graph could be inside a function.
        for func in graph_def.library.function:
          for func_node in func.node_def:
            node_name_to_op[func_node.name] = func_node.op
        self.assertIn("TRTEngineOp_0", node_name_to_op)
        self.assertEqual("TRTEngineOp", node_name_to_op["TRTEngineOp_0"])
      else:
        expected_nodes = {
            "input": "Placeholder",
            "TRTEngineOp_0": "TRTEngineOp",
            "output": "Identity"
        }
        self.assertEqual(expected_nodes, node_name_to_op)

      if not need_calibration:
        continue

      trt_engine_nodes = []
      for node in graph_def.node:
        if node.op == "TRTEngineOp":
          trt_engine_nodes.append(node)
      self.assertNotEmpty(trt_engine_nodes)
      for engine_node in trt_engine_nodes:
        self.assertTrue(len(engine_node.attr["calibration_data"].s))

      # Run the calibrated graph.
      # TODO(laigd): consider having some input where the answer is different.
      with ops.Graph().as_default():
        importer.import_graph_def(graph_def, name="")
        with self.session(config=self._GetConfigProto()) as sess:
          for test_data in range(10):
            expected_output = (test_data + 1.0)**2
            actual_output = sess.run(
                "output:0", feed_dict={"input:0": [[[test_data]]]})
            self.assertEqual(expected_output, actual_output)
开发者ID:perfmjs,项目名称:tensorflow,代码行数:58,代码来源:trt_convert_test.py


示例19: testMissingReturnOperation

 def testMissingReturnOperation(self):
   """Requesting a return element absent from the GraphDef raises ValueError."""
   expected_error = "Requested return node 'B' not found in graph def"
   with ops.Graph().as_default():
     with self.assertRaisesRegexp(ValueError, expected_error):
       single_node_def = self._MakeGraphDef("""
           node { name: 'A' op: 'None' }
           """)
       importer.import_graph_def(single_node_def, return_elements=["B"])
开发者ID:clsung,项目名称:tensorflow,代码行数:9,代码来源:importer_test.py


示例20: testMissingInputMap

 def testMissingInputMap(self):
   """An `input_map` key absent from the GraphDef raises ValueError."""
   with ops.Graph().as_default():
     with self.assertRaises(ValueError) as e:
       importer.import_graph_def(
           self._MakeGraphDef("""
           node { name: 'A' op: 'None' }
           """),
           input_map={"B:0": constant_op.constant(5.0)})
     # assertIn shows both operands on failure, unlike assertTrue(x in y).
     self.assertIn("not found in graph_def: [B:0]", str(e.exception))
开发者ID:pcm17,项目名称:tensorflow,代码行数:9,代码来源:importer_test.py



注:本文中的 tensorflow.python.framework.importer.import_graph_def 函数示例由纯净天空整理自 GitHub/MSDocs 等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的 License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python load_library.load_op_library函数代码示例发布时间:2022-05-27
下一篇:
Python graph_util.extract_sub_graph函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap