• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python saver.export_meta_graph函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.training.saver.export_meta_graph函数的典型用法代码示例。如果您正苦于以下问题:Python export_meta_graph函数的具体用法?Python export_meta_graph怎么用?Python export_meta_graph使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了export_meta_graph函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: testNoVariables

  def testNoVariables(self):
    """Round-trips a graph with no variables through a MetaGraphDef file.

    The graph computes `input + 42`. It is exported with the module-level
    `saver_module.export_meta_graph` and re-imported into a fresh graph. The
    test then checks three things:
      * importing returns no Saver;
      * re-exporting reproduces the original proto;
      * the imported graph still computes the same value.
    """
    metafile = os.path.join(_TestDir("no_variables"), "metafile")
    feed_value = -10  # Arbitrary input value for feed_dict.
    collection_keys = ["input_tensor", "output_tensor"]

    with self.test_session(graph=tf.Graph()) as sess:
      # Minimal graph: a scalar placeholder plus a constant offset.
      x = tf.placeholder(tf.float32, shape=[], name="input")
      y = tf.add(x, tf.constant(42, dtype=tf.float32, name="offset"),
                 name="add_offset")
      for key, tensor in zip(collection_keys, (x, y)):
        tf.add_to_collection(key, tensor)

      expected_value = sess.run(y, {x: feed_value})
      self.assertEqual(expected_value, 32)

      # This is the *module-level* export_meta_graph, not the
      # Saver.export_meta_graph instance method.
      exported = saver_module.export_meta_graph(
          filename=metafile,
          graph_def=tf.get_default_graph().as_graph_def(),
          collection_list=list(collection_keys),
          saver_def=None,
      )

    with self.test_session(graph=tf.Graph()) as sess:
      # The exported graph has no variables to restore, so no Saver is built.
      self.assertIsNone(saver_module.import_meta_graph(metafile))

      # Re-exporting the imported graph must reproduce the original proto.
      reexported = saver_module.export_meta_graph(metafile + "_new")
      self.assertProtoEquals(exported, reexported)

      # The collections survive the round trip and the result is unchanged.
      imported_x = tf.get_collection("input_tensor")[0]
      imported_y = tf.get_collection("output_tensor")[0]
      self.assertEqual(sess.run(imported_y, {imported_x: feed_value}),
                       expected_value)
开发者ID:2er0,项目名称:tensorflow,代码行数:51,代码来源:saver_test.py


示例2: main

def main(_):
  """Prints a Grappler cost report for the graph named by the flags.

  The graph comes from one of two sources:
    * `--metagraphdef`: read as a MetaGraphDef.
    * `--graphdef` (otherwise): read as a GraphDef, imported into the default
      graph with the `--fetch` operation added to its 'train_op' collection,
      then exported as a MetaGraphDef.
  Files ending in `.pbtxt` are parsed as text protos; all others as binary
  protos. If `--rewriter_config` is set, the graph is first rewritten by
  Grappler with that text-format RewriterConfig.
  """
  if FLAGS.metagraphdef:
    metagraph = meta_graph_pb2.MetaGraphDef()
    # Binary protos must be read as bytes: text mode decodes the file as UTF-8
    # and ParseFromString does not accept str.
    mode = "r" if FLAGS.metagraphdef.endswith(".pbtxt") else "rb"
    with gfile.GFile(FLAGS.metagraphdef, mode) as meta_file:
      if mode == "r":
        text_format.Merge(meta_file.read(), metagraph)
      else:
        metagraph.ParseFromString(meta_file.read())
  else:
    graph_def = graph_pb2.GraphDef()
    mode = "r" if FLAGS.graphdef.endswith(".pbtxt") else "rb"
    with gfile.GFile(FLAGS.graphdef, mode) as graph_file:
      if mode == "r":
        text_format.Merge(graph_file.read(), graph_def)
      else:
        graph_def.ParseFromString(graph_file.read())
    importer.import_graph_def(graph_def, name="")
    graph = ops.get_default_graph()
    # Grappler reads its fetch nodes from the 'train_op' collection.
    fetch = graph.get_operation_by_name(FLAGS.fetch)
    graph.add_to_collection("train_op", fetch)
    metagraph = saver.export_meta_graph(
        graph_def=graph.as_graph_def(), graph=graph)

  if FLAGS.rewriter_config is not None:
    rewriter_config = rewriter_config_pb2.RewriterConfig()
    text_format.Merge(FLAGS.rewriter_config, rewriter_config)
    optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph)
    metagraph.graph_def.CopyFrom(optimized_graph)

  report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
  print(report)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:27,代码来源:cost_analyzer_tool.py


示例3: get_metagraph

def get_metagraph():
  """Constructs and returns a MetaGraphDef from the input file.

  If `--metagraphdef` is set, that file is parsed as a MetaGraphDef (text
  format when it ends in `.pbtxt`, binary otherwise). When `--fetch` is also
  given, its comma-separated node names replace the 'train_op' collection.

  Otherwise `--graphdef` is parsed the same way as a GraphDef and imported
  into the default graph. Each comma-separated `--fetch` operation is added
  to 'train_op', and the graph is exported as a MetaGraphDef.

  Returns:
    The MetaGraphDef.

  Raises:
    ValueError: If `--graphdef` is used without `--fetch`.
  """
  if FLAGS.metagraphdef:
    metagraph = meta_graph_pb2.MetaGraphDef()
    # Binary protos must be read as bytes: text mode decodes the file as UTF-8
    # and ParseFromString does not accept str.
    mode = "r" if FLAGS.metagraphdef.endswith(".pbtxt") else "rb"
    with gfile.GFile(FLAGS.metagraphdef, mode) as meta_file:
      if mode == "r":
        text_format.Merge(meta_file.read(), metagraph)
      else:
        metagraph.ParseFromString(meta_file.read())
    if FLAGS.fetch is not None:
      fetch_collection = meta_graph_pb2.CollectionDef()
      for fetch in FLAGS.fetch.split(","):
        fetch_collection.node_list.value.append(fetch)
      metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
    return metagraph

  # In this branch `--fetch` is the only source of 'train_op' entries, so
  # fail clearly instead of with an AttributeError on None.split.
  if FLAGS.fetch is None:
    raise ValueError("--fetch is required when reading a --graphdef file.")
  graph_def = graph_pb2.GraphDef()
  mode = "r" if FLAGS.graphdef.endswith(".pbtxt") else "rb"
  with gfile.GFile(FLAGS.graphdef, mode) as graph_file:
    if mode == "r":
      text_format.Merge(graph_file.read(), graph_def)
    else:
      graph_def.ParseFromString(graph_file.read())
  importer.import_graph_def(graph_def, name="")
  graph = ops.get_default_graph()
  for fetch in FLAGS.fetch.split(","):
    graph.add_to_collection("train_op", graph.get_operation_by_name(fetch))
  return saver.export_meta_graph(graph_def=graph.as_graph_def(), graph=graph)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:29,代码来源:cost_analyzer_tool.py


示例4: testGradient

  def testGradient(self):
    """Checks the layout optimizer rewrites all five conv ops to NCHW.

    A two-layer conv net is trained with SGD, so its graph holds forward
    Conv2D ops plus their backprop ops. Runs only when a CUDA GPU is present.
    """
    if not test.is_gpu_available(cuda_only=True):
      self.skipTest('GPU required')

    random_seed.set_random_seed(0)
    images = random_ops.truncated_normal([1, 200, 200, 3], seed=0)
    features = conv_layers.conv2d(
        conv_layers.conv2d(images, 32, [3, 3]), 32, [3, 3])
    train_op = gradient_descent.GradientDescentOptimizer(1e-4).minimize(
        math_ops.reduce_mean(features))

    graph = ops.get_default_graph()
    graph.add_to_collection('train_op', train_op)
    meta_graph = saver_lib.export_meta_graph(graph_def=graph.as_graph_def())

    optimized_graph = tf_optimizer.OptimizeGraph(
        rewriter_config_pb2.RewriterConfig(optimize_tensor_layout=True),
        meta_graph)

    conv_ops = ('Conv2D', 'Conv2DBackpropFilter', 'Conv2DBackpropInput')
    conv_nodes = [node for node in optimized_graph.node if node.op in conv_ops]
    for node in conv_nodes:
      self.assertEqual(node.attr['data_format'].s, 'NCHW')
    self.assertEqual(len(conv_nodes), 5)
开发者ID:SylChan,项目名称:tensorflow,代码行数:25,代码来源:layout_optimizer_test.py


示例5: _run_inline_graph_optimization

def _run_inline_graph_optimization(func):
  """Runs Grappler's function-inlining pass over the graph of `func`.

  Inlining does not work on models that contain control flow.

  Args:
    func: ConcreteFunction whose graph should be optimized.

  Returns:
    The GraphDef produced by Grappler.
  """
  meta_graph = export_meta_graph(
      graph_def=func.graph.as_graph_def(), graph=func.graph)

  # Grappler takes its fetch nodes from the 'train_op' collection; replace that
  # collection with every input and output name of `func`.
  fetch_collection = meta_graph_pb2.CollectionDef()
  fetch_collection.node_list.value.extend(
      tensor.name for tensor in func.inputs + func.outputs)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Listing only "function" in the optimizer list leaves function inlining as
  # the single enabled rewrite.
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.optimizers.append("function")
  return tf_optimizer.OptimizeGraph(config, meta_graph)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:26,代码来源:convert_to_constants.py


示例6: _ExportAndImportGraph

 def _ExportAndImportGraph(self, graph):
   """Round-trips `graph` through a MetaGraphDef into a brand-new graph.

   Every collection key of `graph` is exported, so the returned graph carries
   the same collections.
   """
   exported = saver_lib.export_meta_graph(
       graph=graph, collection_list=graph.get_all_collection_keys())
   imported_graph = ops.Graph()
   with imported_graph.as_default():
     saver_lib.import_meta_graph(exported)
   return imported_graph
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:8,代码来源:moving_averages_test.py


示例7: _CopyGraph

 def _CopyGraph(self, graph):
   """Returns an independent copy of `graph`.

   The copy is made by exporting `graph`, including all of its collections,
   to a MetaGraphDef and importing that into a new, empty graph.
   """
   graph_copy = ops.Graph()
   meta_graph = saver_lib.export_meta_graph(
       graph=graph, collection_list=graph.get_all_collection_keys())
   with graph_copy.as_default():
     saver_lib.import_meta_graph(meta_graph)
   return graph_copy
开发者ID:Eagle732,项目名称:tensorflow,代码行数:8,代码来源:fold_batch_norms_test.py


示例8: testMetagraph

  def testMetagraph(self):
    """Round-trips a resource-variable training graph through a MetaGraphDef.

    The graph holds a resource variable minimized by Momentum with
    colocate_gradients_with_ops=True. The test exports it, imports it into a
    fresh graph, re-exports that imported graph, and expects both
    MetaGraphDefs to be equal.
    """
    with ops.Graph().as_default():
      with variable_scope.variable_scope("foo", use_resource=True):
        a = variable_scope.get_variable("a", initializer=10.0)

      momentum.MomentumOptimizer(
          learning_rate=0.001, momentum=0.1).minimize(
              a,
              colocate_gradients_with_ops=True,
              global_step=training_util.get_or_create_global_step())

      graph = ops.get_default_graph()
      meta_graph_def = saver.export_meta_graph(graph=graph)

    with ops.Graph().as_default() as imported_graph:
      saver.import_meta_graph(meta_graph_def, import_scope="")
      # Export the graph that was just imported. Re-exporting `graph` here
      # would compare the original graph with itself and never exercise the
      # import path.
      meta_graph_two = saver.export_meta_graph(graph=imported_graph)
    self.assertEqual(meta_graph_def, meta_graph_two)
开发者ID:aeverall,项目名称:tensorflow,代码行数:18,代码来源:resource_variable_ops_test.py


示例9: _convert_graph_def

  def _convert_graph_def(self):
    """Converts the input GraphDef.

    Steps:
      1. Import `self._input_graph_def` into a fresh graph.
      2. Export that graph, with shape attributes added, as the MetaGraphDef
         handed to Grappler.
      3. Call `_add_nodes_blacklist()`.
      4. Call `_run_conversion()`.
    """
    source_graph = ops.Graph()
    with source_graph.as_default():
      importer.import_graph_def(self._input_graph_def, name="")
    shaped_graph_def = source_graph.as_graph_def(add_shapes=True)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=shaped_graph_def, graph=source_graph)
    self._add_nodes_blacklist()
    self._run_conversion()
开发者ID:aritratony,项目名称:tensorflow,代码行数:10,代码来源:trt_convert.py


示例10: setUp

  def setUp(self):
    """Writes a frozen MetaGraphDef for the tests into `self.base_path`.

    The graph computes y = w * x - 7.0 with a variable w = 3.0. That variable
    is converted to a constant before export. The graph also holds a "meta"
    collection containing one string.
    """
    self.base_path = os.path.join(test.get_temp_dir(), "no_vars")
    if not os.path.exists(self.base_path):
      os.mkdir(self.base_path)

    with ops.Graph().as_default() as g:
      x = array_ops.placeholder(dtypes.float32, name="x")
      w = variables.Variable(3.0)
      math_ops.subtract(w * x, 7.0, name="y")
      ops.add_to_collection("meta", "this is meta")

      # Freeze the initialized variable into a constant in the exported graph.
      with self.session(graph=g) as session:
        variables.global_variables_initializer().run()
        frozen_graph_def = graph_util.convert_variables_to_constants(
            session, g.as_graph_def(), ["y"])

      saver.export_meta_graph(
          os.path.join(self.base_path, constants.META_GRAPH_DEF_FILENAME),
          graph_def=frozen_graph_def,
          collection_list=["meta"])
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:21,代码来源:session_bundle_test.py


示例11: _simple_metagraph

def _simple_metagraph(depthwise=False):
  """Builds a two-layer conv net trained with SGD and exports it.

  Args:
    depthwise: If True, use `conv_layers.separable_conv2d` for both layers;
      otherwise use `conv_layers.conv2d`.

  Returns:
    The MetaGraphDef of the default graph. Its 'train_op' collection holds
    the SGD training op.
  """
  random_seed.set_random_seed(0)
  images = variables.Variable(
      random_ops.truncated_normal([1, 200, 200, 3], seed=0))
  if depthwise:
    layer = conv_layers.separable_conv2d
  else:
    layer = conv_layers.conv2d
  hidden = layer(images, 32, [3, 3])
  output = layer(hidden, 32, [3, 3])
  sgd = gradient_descent.GradientDescentOptimizer(1e-4)
  train_op = sgd.minimize(math_ops.reduce_mean(output))
  graph = ops.get_default_graph()
  graph.add_to_collection('train_op', train_op)
  return saver_lib.export_meta_graph(graph_def=graph.as_graph_def())
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:13,代码来源:layout_optimizer_test.py


示例12: test_meta_graph_transform

  def test_meta_graph_transform(self):
    """Checks that strip_unused_nodes prunes the graph down to a * b.

    The base graph computes both a * b and b * c. It is transformed with:
      * inputs ['a', 'b'],
      * output 'mul:0',
      * transform 'strip_unused_nodes',
      * tag 'tag_ab'.
    The result should equal the MetaGraphDef of a graph that computes only
    a * b and carries that tag.
    """
    with ops.Graph().as_default():
      with tf_session.Session(''):
        a = array_ops.placeholder(dtypes.int64, [1], name='a')
        b = array_ops.placeholder(dtypes.int64, [1], name='b')
        c = array_ops.placeholder(dtypes.int64, [1], name='c')
        _ = a * b
        _ = b * c
        base_meta_graph_def = saver.export_meta_graph()

    with ops.Graph().as_default():
      with tf_session.Session(''):
        a = array_ops.placeholder(dtypes.int64, [1], name='a')
        b = array_ops.placeholder(dtypes.int64, [1], name='b')
        _ = a * b
        meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
        meta_info_def.tags.append('tag_ab')

        expected_meta_graph_def = saver.export_meta_graph(
            meta_info_def=meta_info_def)
        # Graph rewriter clears versions field, so we expect that.
        expected_meta_graph_def.graph_def.ClearField('versions')
        # Graph rewriter adds an empty library field, so we expect that.
        expected_meta_graph_def.graph_def.library.CopyFrom(
            function_pb2.FunctionDefLibrary())

    input_names = ['a', 'b']
    output_names = ['mul:0']
    transforms = ['strip_unused_nodes']
    tags = ['tag_ab']
    transformed_meta_graph_def = meta_graph_transform.meta_graph_transform(
        base_meta_graph_def, input_names, output_names, transforms, tags)

    self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def)
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:36,代码来源:meta_graph_transform_test.py


示例13: _convert_saved_model_v2

  def _convert_saved_model_v2(self):
    """Convert the input SavedModel in 2.0 format.

    Steps:
      1. Load the SavedModel and pick the configured signature.
      2. Run `convert_variables_to_constants_v2` on that signature.
      3. Export the frozen graph to `self._grappler_meta_graph_def` and call
         `_run_conversion()`.
      4. Rebuild `self._converted_func`, a ConcreteFunction over the graph in
         `self._converted_graph_def`. It copies the original signature's
         inputs/outputs, structured signature, argument keywords, captured
         inputs and variables.
    """
    self._saved_model = load.load(self._input_saved_model_dir,
                                  self._input_saved_model_tags)
    func = self._saved_model.signatures[self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(func)
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Add a collection 'train_op' so that Grappler knows the outputs.
    # NOTE(review): the exported graph is `frozen_func.graph`, but these names
    # come from the unfrozen `func`. If freezing changes the input list (e.g.
    # drops captured variable handles), some names may not exist in the
    # exported graph. Confirm whether `frozen_func` was intended here.
    fetch_collection = meta_graph_pb2.CollectionDef()
    for array in func.inputs + func.outputs:
      fetch_collection.node_list.value.append(array.name)
    self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
        fetch_collection)

    # Run TRT optimizer in Grappler to convert the graph.
    # The converted GraphDef is read back from self._converted_graph_def below.
    self._run_conversion()

    # For each tensor in `tensors`, look up the same-named tensor in `graph`
    # and copy the original static shape onto it.
    def _get_tensor(graph, tensors):
      new_tensors = []
      for tensor in tensors:
        new_tensor = graph.get_tensor_by_name(tensor.name)
        new_tensor.set_shape(tensor.shape)
        new_tensors.append(new_tensor)
      return new_tensors

    # TODO(laigd): do we need to use different name e.g. "trt_func_graph"?
    converted_graph = func_graph.FuncGraph(func.graph.name)
    with converted_graph.as_default():
      importer.import_graph_def(self._converted_graph_def, name="")

    # Mirror the original signature's inputs, outputs and structure onto the
    # converted graph.
    converted_graph.inputs = _get_tensor(converted_graph, func.graph.inputs)
    converted_graph.outputs = _get_tensor(converted_graph, func.graph.outputs)
    converted_graph.structured_outputs = func.graph.structured_outputs
    converted_graph.structured_input_signature = (
        func.graph.structured_input_signature)

    # pylint: disable=protected-access
    # TODO(laigd): should we set up the signature as well?
    # Copy the calling-convention details (argument keywords, positional arg
    # count, captured inputs, variables) from the original function.
    self._converted_func = function.ConcreteFunction(
        converted_graph, attrs=None, signature=None)
    self._converted_func.add_to_graph()
    self._converted_func._arg_keywords = func._arg_keywords
    self._converted_func._num_positional_args = func._num_positional_args
    self._converted_func._captured_inputs = func._captured_inputs
    self._converted_func.graph.variables = func.graph.variables
开发者ID:perfmjs,项目名称:tensorflow,代码行数:47,代码来源:trt_convert.py


示例14: grappler_optimize

def grappler_optimize(graph, fetches=None, rewriter_config=None):
  """Runs Grappler over `graph` and returns the rewritten GraphDef.

  Args:
    graph: The `tf.Graph` to optimize.
    fetches: Optional list of `Tensor`s that must survive optimization.
      Grappler finds fetches through the 'train_op' collection, so each entry
      is added to that collection of `graph` (mutating `graph`). When omitted,
      the collection should already be non-empty.
    rewriter_config: Optional `RewriterConfig` controlling the rewrite. A
      default-constructed one is used when None.

  Returns:
    The optimized `GraphDef`.
  """
  config = (rewriter_config if rewriter_config is not None
            else rewriter_config_pb2.RewriterConfig())
  for fetch in (fetches if fetches is not None else ()):
    graph.add_to_collection('train_op', fetch)
  metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def())
  return tf_optimizer.OptimizeGraph(config, metagraph)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:21,代码来源:test_util.py


示例15: _convert_graph_def

  def _convert_graph_def(self):
    """Converts the input GraphDef.

    `self._input_graph_def` is re-imported into a fresh graph and exported,
    with shape attributes, as the MetaGraphDef for Grappler.

    If `self._nodes_blacklist` is non-empty, its entries are written as the
    'train_op' collection:
      * `ops.Tensor` entries contribute their names;
      * any other entry is used as-is.
    Every value passes through `_to_bytes`. `_run_conversion()` is called
    last.
    """
    graph = ops.Graph()
    with graph.as_default():
      importer.import_graph_def(self._input_graph_def, name="")
    self._grappler_meta_graph_def = saver.export_meta_graph(
        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)

    if self._nodes_blacklist:
      blacklisted_names = [
          _to_bytes(item.name if isinstance(item, ops.Tensor) else item)
          for item in self._nodes_blacklist
      ]
      output_collection = meta_graph_pb2.CollectionDef()
      output_collection.node_list.value.extend(blacklisted_names)
      # TODO(laigd): use another key as the self._nodes_blacklist are really
      # not train_op.
      self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
          output_collection)

    self._run_conversion()
开发者ID:kylin9872,项目名称:tensorflow,代码行数:21,代码来源:trt_convert.py


示例16: _test_convert_variables_with_functions

  def _test_convert_variables_with_functions(self, inline_functions):
    """Freezes a graph that calls a Defun and checks no variable ops remain.

    Args:
      inline_functions: If True, run Grappler's function-inlining pass over
        the GraphDef before freezing it.
    """

    @function.Defun(dtypes.float32)
    def plus_one(x):
      return x + 1.0

    with ops.Graph().as_default():
      variable_node = variables.Variable(1.0, name="variable_node")
      variables.Variable(1.0, name="unused_variable_node")
      math_ops_lib.multiply(plus_one(variable_node), 2.0, name="output_node")

      with session.Session() as sess:
        self.evaluate(variables.variables_initializer([variable_node]))
        graph_def = sess.graph.as_graph_def()

        if inline_functions:
          # Run Grappler to create the VarOpHandle --> Placeholder -->
          # ResourceVariable pattern.
          meta_graph = export_meta_graph(graph_def=graph_def)
          fetch_collection = meta_graph_pb2.CollectionDef()
          fetch_collection.node_list.value.extend(
              ["variable_node", "output_node"])
          meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

          # Only "function" is listed, so function inlining is the only
          # enabled rewrite.
          config = config_pb2.ConfigProto()
          config.graph_options.rewrite_options.optimizers.append("function")
          graph_def = tf_optimizer.OptimizeGraph(config, meta_graph)

        frozen_graph_def = graph_util.convert_variables_to_constants(
            sess, graph_def, ["output_node"])

    # Freezing must leave no variable-related ops behind.
    variable_ops = ["Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"]
    for node in frozen_graph_def.node:
      self.assertNotIn(node.op, variable_ops)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:40,代码来源:graph_util_test.py


示例17: testGradientOfDeserializedCond

  def testGradientOfDeserializedCond(self):
    """Takes first and second gradients of a cond_v2 loaded from a MetaGraph.

    The cond returns x**3 when `pred` is True and x otherwise, with x = 3.0.
    """
    with ops.Graph().as_default():
      pred = array_ops.placeholder(dtypes.bool, name="pred")
      x = constant_op.constant(3.0, name="x")
      ops.add_to_collection("x", x)
      ops.add_to_collection("pred", pred)

      def true_fn():
        return math_ops.pow(x, 3)

      def false_fn():
        return x

      for output in cond_v2.cond_v2(pred, true_fn, false_fn, name="cond"):
        ops.add_to_collection("cond", output)
      meta_graph = saver.export_meta_graph()

    with ops.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        saver.import_meta_graph(meta_graph)
        x = ops.get_collection("x")[0]
        pred = ops.get_collection("pred")[0]
        cond = ops.get_collection("cond")
        cond_grad = gradients_impl.gradients(cond, [x], name="cond_grad")
        cond_grad_grad = gradients_impl.gradients(
            cond_grad, [x], name="cond_grad_grad")

        # At x = 3: d[x^3]/dx = 3x^2 = 27, d[x]/dx = 1,
        #           d2[x^3]/dx2 = 6x = 18, d2[x]/dx2 = 0.
        expectations = [
            (cond_grad, True, [27.0]),
            (cond_grad, False, [1.0]),
            (cond_grad_grad, True, [18.0]),
            (cond_grad_grad, False, [0.0]),
        ]
        for fetch, branch, expected in expectations:
          self.assertEqual(sess.run(fetch, {pred: branch}), expected)
开发者ID:clsung,项目名称:tensorflow,代码行数:40,代码来源:cond_v2_test.py


示例18: grappler_optimize

def grappler_optimize(graph, fetches=None, config_proto=None):
  """Runs Grappler over `graph` and returns the rewritten GraphDef.

  Args:
    graph: The `tf.Graph` to optimize.
    fetches: Optional list of `Tensor`s that must survive optimization.
      Grappler finds fetches through the 'train_op' collection, so each entry
      is added to that collection of `graph` (mutating `graph`). When omitted,
      the collection should already be non-empty.
    config_proto: Optional `tf.ConfigProto` whose rewrite options drive the
      optimization. When None, a default one with
      `rewrite_options.min_graph_nodes = -1` is used.

  Returns:
    The optimized `GraphDef`.
  """
  if config_proto is None:
    config_proto = config_pb2.ConfigProto()
    # Presumably stops Grappler from skipping small test graphs; see
    # RewriterConfig.min_graph_nodes.
    config_proto.graph_options.rewrite_options.min_graph_nodes = -1
  if fetches is not None:
    for fetch in fetches:
      graph.add_to_collection('train_op', fetch)
  return tf_optimizer.OptimizeGraph(
      config_proto, saver.export_meta_graph(graph_def=graph.as_graph_def()))
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:22,代码来源:test_util.py


示例19: get_metagraph

def _merge_text_or_binary(data, proto):
  """Parses `data` (bytes) into `proto`, trying text format, then binary.

  Returns:
    True if either parse succeeded, False otherwise.
  """
  try:
    text_format.Merge(data, proto)
    return True
  except (text_format.ParseError, UnicodeDecodeError):
    # Binary payloads are usually not valid UTF-8 text.
    pass
  try:
    # ParseFromString clears any fields left over from the failed text merge.
    proto.ParseFromString(data)
    return True
  except message.DecodeError:
    return False


def get_metagraph():
  """Constructs and returns a MetaGraphDef from the input file.

  `FLAGS.input` may hold a SavedModel, a MetaGraphDef or a GraphDef, each in
  text or binary protobuf format. The formats are tried in that order:
    * SavedModel: its first MetaGraphDef is used; a SavedModel with no meta
      graphs falls through to the next format.
    * GraphDef: imported into the default graph and exported as a
      MetaGraphDef.
  If `FLAGS.fetch` is set, its comma-separated node names replace the
  'train_op' collection.

  Returns:
    The MetaGraphDef.

  Raises:
    ValueError: If the file cannot be parsed as any supported proto.
  """
  # Read raw bytes: text mode would decode the file as UTF-8, which fails on
  # binary protos, and ParseFromString needs bytes anyway.
  with gfile.GFile(FLAGS.input, "rb") as input_file:
    input_data = input_file.read()

  saved_model = saved_model_pb2.SavedModel()
  if _merge_text_or_binary(input_data, saved_model) and saved_model.meta_graphs:
    meta_graph = saved_model.meta_graphs[0]
  else:
    meta_graph = meta_graph_pb2.MetaGraphDef()
    if not _merge_text_or_binary(input_data, meta_graph):
      graph_def = graph_pb2.GraphDef()
      if not _merge_text_or_binary(input_data, graph_def):
        raise ValueError("Invalid input file.")
      importer.import_graph_def(graph_def, name="")
      graph = ops.get_default_graph()
      meta_graph = saver.export_meta_graph(
          graph_def=graph.as_graph_def(), graph=graph)

  if FLAGS.fetch is not None:
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_collection.node_list.value.extend(FLAGS.fetch.split(","))
    meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
  return meta_graph
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:38,代码来源:cost_analyzer_tool.py


示例20: convert

  def convert(self):
    """Converts the input SavedModel (TF 2.0 format) with the TRT optimizer.

    May only be called once per converter; a second call fails the
    `self._converted` assertion.

    Returns:
      A `def_function.function` wrapping the converted ConcreteFunction, so
      it also accepts numpy arrays as input.
    """
    assert not self._converted
    self._saved_model = load.load(self._input_saved_model_dir,
                                  self._input_saved_model_tags)
    signature = self._saved_model.signatures[
        self._input_saved_model_signature_key]
    frozen_func = convert_to_constants.convert_variables_to_constants_v2(
        signature)
    input_names = [tensor.name for tensor in frozen_func.inputs]
    output_names = [tensor.name for tensor in frozen_func.outputs]
    meta_graph_def = saver.export_meta_graph(
        graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph)

    # Grappler reads the 'train_op' collection to learn the graph's inputs and
    # outputs.
    fetch_collection = meta_graph_pb2.CollectionDef()
    fetch_collection.node_list.value.extend(input_names + output_names)
    meta_graph_def.collection_def["train_op"].CopyFrom(fetch_collection)

    # Run the TRT optimizer through Grappler, then rebuild a function from the
    # converted GraphDef using the frozen function's input/output names.
    self._converted_graph_def = self._run_conversion(meta_graph_def)
    self._converted_func = wrap_function.function_from_graph_def(
        self._converted_graph_def, input_names, output_names)
    self._converted = True

    @def_function.function
    def wrapper_func(*args, **kwargs):
      return self._converted_func(*args, **kwargs)

    return wrapper_func
开发者ID:aritratony,项目名称:tensorflow,代码行数:37,代码来源:trt_convert.py



注:本文中的tensorflow.python.training.saver.export_meta_graph函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python saver.import_meta_graph函数代码示例（发布时间：2022-05-27）
下一篇:
Python queue_runner_impl.start_queue_runners函数代码示例（发布时间：2022-05-27）
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap