• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python meta_graph.create_meta_graph_def函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.framework.meta_graph.create_meta_graph_def函数的典型用法代码示例。如果您想了解Python create_meta_graph_def函数的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。



下文共展示了create_meta_graph_def函数的20个代码示例，这些例子默认按受欢迎程度排序。您可以为喜欢或觉得有用的代码点赞，您的评价将帮助我们的系统推荐更好的Python代码示例。

示例1: testSmallNetwork

  def testSmallNetwork(self):
    """Checks that the cost report covers a small conv net trained with Adam.

    Builds a conv + dense softmax classifier in the default graph, minimizes
    its cross-entropy with Adam, exports the graph as a MetaGraphDef and
    asserts that `cost_analyzer.GenerateCostReport` mentions the forward,
    backward and optimizer ops.
    """
    image = array_ops.placeholder(dtypes.float32, shape=[1, 28, 28, 1])
    label = array_ops.placeholder(dtypes.float32, shape=[1, 10])
    w = variables.Variable(
        random_ops.truncated_normal([5, 5, 1, 32], stddev=0.1))
    b = variables.Variable(random_ops.truncated_normal([32], stddev=0.1))
    conv = nn_ops.conv2d(image, w, strides=[1, 1, 1, 1], padding="SAME")
    h_conv = nn_ops.relu(conv + b)
    h_conv_flat = array_ops.reshape(h_conv, [1, -1])

    # SAME padding with stride 1 keeps 28x28, so 28 * 28 * 32 = 25088 features.
    w_fc = variables.Variable(
        random_ops.truncated_normal([25088, 10], stddev=0.1))
    b_fc = variables.Variable(random_ops.truncated_normal([10], stddev=0.1))
    y_conv = nn_ops.softmax(math_ops.matmul(h_conv_flat, w_fc) + b_fc)

    cross_entropy = math_ops.reduce_mean(-math_ops.reduce_sum(
        label * math_ops.log(y_conv), reduction_indices=[1]))
    _ = adam.AdamOptimizer(1e-4).minimize(cross_entropy)

    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    report = cost_analyzer.GenerateCostReport(mg)

    # Print the report before asserting so it is available for debugging
    # precisely when one of the checks below fails.
    print("{}".format(report))

    # The report is compared against bytes literals. assertIn names the
    # missing op on failure, unlike assertTrue(x in y).
    for op_name in (b"MatMul", b"ApplyAdam", b"Conv2D", b"Conv2DBackpropInput",
                    b"Conv2DBackpropFilter", b"Softmax"):
      self.assertIn(op_name, report)
开发者ID:ajaybhat,项目名称:tensorflow,代码行数:31,代码来源:cost_analyzer_test.py


示例2: testFromStringHandle

  def testFromStringHandle(self):
    """Checks the IteratorGetNext shape inferred for from_string_handle."""
    expected_shapes = [
        tensor_shape.TensorShape([]),
        tensor_shape.TensorShape([3]),
        tensor_shape.TensorShape([1, 2]),
        tensor_shape.TensorShape([1, 2, 3]),
    ]

    for expected_shape in expected_shapes:
      with ops.Graph().as_default() as graph:
        # A structure-only iterator provides the string handle that feeds the
        # handle-based iterator under test.
        source_iterator = iterator_ops.Iterator.from_structure(dtypes.int64)
        handle_iterator = iterator_ops.Iterator.from_string_handle(
            source_iterator.string_handle(),
            dtypes.int64,
            output_shapes=expected_shape)
        next_element = handle_iterator.get_next()
        # Grappler needs a fetch node; register the element as the train op.
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
        grappler_item = item.Item(
            meta_graph.create_meta_graph_def(graph=graph))
        properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:25,代码来源:datasets_test.py


示例3: before_run

    def before_run(self, run_context):
        """Dumps the graph once (chief only) and requests the global step.

        On the first call, and only on the chief, this writes the default
        graph (with shapes) to `<checkpoint_dir>/graph.pbtxt`, calls
        `dump_model_analysis` on the checkpoint directory and, if a summary
        writer is configured, adds the graph and a MetaGraphDef built with
        this hook's `saver_def`. No checkpoint is loaded here.

        Called before each call to run().

        Args:
            run_context: A `SessionRunContext` object (not used).

        Returns:
            A `SessionRunArgs` object that fetches `self._global_step`.
        """
        # The graph and saver_def are written on the first call to before_run
        # rather than in begin(): other hooks may still change the graph and
        # add variables in their begin(), and the graph is only finalized after
        # all begin() calls.
        if self._is_chief and self._first_call:
            graph = ops.get_default_graph()
            # The graph is finalized at this point, so serialize it once and
            # reuse the result for both the pbtxt dump and the MetaGraphDef.
            graph_def = graph.as_graph_def(add_shapes=True)
            training_util.write_graph(
                graph_def,
                self._checkpoint_dir,
                "graph.pbtxt")
            # Dump model details; presumably "model_analysis.txt" and the
            # model configs, per the original comments (see
            # dump_model_analysis).
            dump_model_analysis(self._checkpoint_dir)
            meta_graph_def = meta_graph.create_meta_graph_def(
                graph_def=graph_def,
                saver_def=self._saver.saver_def)
            if self._summary_writer is not None:
                self._summary_writer.add_graph(graph)
                self._summary_writer.add_meta_graph(meta_graph_def)
            tf.logging.info("CheckpointSaverHook (before_run): dump graph...")
        self._first_call = False
        return tf.train.SessionRunArgs(self._global_step)
开发者ID:KIngpon,项目名称:NJUNMT-tf,代码行数:30,代码来源:hooks.py


示例4: testUpdates

  def testUpdates(self):
    """tf_item stays the same until the metagraph changes, then stabilizes."""
    with ops.Graph().as_default() as graph:
      lhs = constant_op.constant(10)
      rhs = constant_op.constant(20)
      total = lhs + rhs
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
      grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=graph))

    def place_all_on_cpu():
      # Mutates the item's metagraph in place.
      for node_def in grappler_item.metagraph.graph_def.node:
        node_def.device = '/cpu:0'

    before = grappler_item.tf_item
    # Reading tf_item again without any edit must give an equal item.
    self.assertEqual(before, grappler_item.tf_item)

    # Changing the placement must produce a different item.
    place_all_on_cpu()
    placed = grappler_item.tf_item
    self.assertNotEqual(before, placed)

    # Re-applying the identical placement must not.
    place_all_on_cpu()
    self.assertEqual(placed, grappler_item.tf_item)
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:25,代码来源:item_test.py


示例5: testMap

  def testMap(self):
    """Checks shape inference through Dataset.map(array_ops.transpose)."""
    # (input value, expected shape of the transposed element)
    cases = [
        (0, tensor_shape.TensorShape([])),
        (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
        (np.array([[1, 2, 3]]), tensor_shape.TensorShape([3, 1])),
        (np.array([[[1, 2, 3], [4, 5, 6]]]),
         tensor_shape.TensorShape([3, 2, 1])),
    ]

    for value, transposed_shape in cases:
      with ops.Graph().as_default() as graph:
        next_element = (
            dataset_ops.Dataset.from_tensors(value)
            .map(array_ops.transpose)
            .make_one_shot_iterator()
            .get_next())
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
        grappler_item = item.Item(
            meta_graph.create_meta_graph_def(graph=graph))
        properties = grappler_item.GetOpProperties()
        self.assertEqual(transposed_shape,
                         properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:28,代码来源:datasets_test.py


示例6: testKeepNodes

  def testKeepNodes(self):
    """Nodes referenced by collections must survive graph optimization."""
    graph = ops.Graph()
    with graph.as_default():
      # Kept because it is in the 'variables' collection.
      var = variables.VariableV1(1.0)
      kept_const = constant_op.constant(0, shape=[50, 50], name='keep')
      # Kept because it is explicitly added to a collection.
      ops.add_to_collection('a2', kept_const)
      lhs = constant_op.constant(1, shape=[100, 10])
      rhs = constant_op.constant(0, shape=[10, 30])
      fetch = math_ops.matmul(lhs, rhs)
      # The fetch node of the optimized graph.
      ops.add_to_collection('train_op', fetch)

    meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
    config = rewriter_config_pb2.RewriterConfig()
    config.min_graph_nodes = -1
    optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

    # Exact names, in exact order, of the surviving nodes.
    expected_names = [
        fetch.op.name,
        var.op.name,
        'Variable/initial_value',
        kept_const.op.name,
        'Variable/Assign',
    ]
    self.assertEqual(expected_names, [node.name for node in optimized.node])
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:25,代码来源:tf_optimizer_test.py


示例7: testPaddedBatch

  def testPaddedBatch(self):
    """Checks shape inference through Dataset.padded_batch."""
    # (input value, expected batched shape with an unknown batch dimension)
    cases = [
        (0, tensor_shape.TensorShape([None])),
        (np.array([1, 2, 3]), tensor_shape.TensorShape([None, 4])),
        (np.array([[1, 2, 3]]), tensor_shape.TensorShape([None, 2, 4])),
    ]

    for value, batched_shape in cases:
      # Per-element padded shape: everything after the batch dimension.
      element_shape = batched_shape[1:]
      with ops.Graph().as_default() as graph:
        dataset = dataset_ops.Dataset.from_tensors(value).padded_batch(
            42, padded_shapes=element_shape)
        next_element = dataset.make_one_shot_iterator().get_next()
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
        grappler_item = item.Item(
            meta_graph.create_meta_graph_def(graph=graph))
        inferred = self.as_tensor_shape(
            grappler_item.GetOpProperties()['IteratorGetNext'][0].shape)
        # The expected batch dimension is None, so only compatibility is
        # checked for it; the remaining dimensions must match exactly.
        self.assertTrue(batched_shape.dims[0].is_compatible_with(inferred[0]))
        self.assertEqual(element_shape, inferred[1:])
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:28,代码来源:datasets_test.py


示例8: testInterleave

  def testInterleave(self):
    """Checks shape inference through Dataset.interleave."""

    def make_dataset(tensor):
      # Builds the interleave function: `tensor` repeated n times.

      def dataset_fn(n):
        return dataset_ops.Dataset.from_tensors(tensor).repeat(n)

      return dataset_fn

    cases = [
        (0, tensor_shape.TensorShape([])),
        (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
        (np.array([[1, 2, 3]]), tensor_shape.TensorShape([1, 3])),
    ]

    for value, expected_shape in cases:
      with ops.Graph().as_default() as graph:
        dataset = dataset_ops.Dataset.range(42).interleave(
            make_dataset(value), cycle_length=42)
        next_element = dataset.make_one_shot_iterator().get_next()
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
        grappler_item = item.Item(
            meta_graph.create_meta_graph_def(graph=graph))
        properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:34,代码来源:datasets_test.py


示例9: testKeepNodes

  def testKeepNodes(self):
    """Collection members and do-not-remove nodes must survive optimization."""
    graph = ops.Graph()
    with graph.as_default():
      # Kept because it is in the 'variables' collection.
      var = variables.VariableV1(1.0)
      in_collection = constant_op.constant(0, shape=[50, 50], name='keep')
      # Kept because it is explicitly added to a collection.
      ops.add_to_collection('a2', in_collection)
      # Kept because of the _grappler_do_not_remove attribute.
      keep_attr = {'_grappler_do_not_remove': attr_value_pb2.AttrValue(b=True)}
      with graph._attr_scope(keep_attr):  # pylint: disable=protected-access
        pinned = constant_op.constant(0, name='keep2')
      lhs = constant_op.constant(1, shape=[100, 10])
      rhs = constant_op.constant(0, shape=[10, 30])
      fetch = math_ops.matmul(lhs, rhs)
      # The fetch node of the optimized graph.
      ops.add_to_collection('train_op', fetch)

    meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)
    config = config_pb2.ConfigProto()
    config.graph_options.rewrite_options.min_graph_nodes = -1
    optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

    # Same count plus membership: the surviving set is exactly the expected one.
    actual_names = [node.name for node in optimized.node]
    expected_names = [
        fetch.op.name,
        var.op.name,
        in_collection.op.name,
        pinned.op.name,
        'Variable/initial_value',
        'Variable/Assign',
    ]
    self.assertEqual(len(actual_names), len(expected_names))
    self.assertAllInSet(actual_names, expected_names)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:30,代码来源:tf_optimizer_test.py


示例10: testSimpleSwap

  def testSimpleSwap(self):
    """Checks that a manual swap-to-host annotation on `d` is honored."""
    a = constant_op.constant(10, name='a')
    b = constant_op.constant(20, name='b')
    c = math_ops.add_n([a, b], name='c')
    d = math_ops.add_n([b, c], name='d')
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(d)

    # Request that input 0 of `d` be swapped to host memory.
    # NOTE(review): relies on `op.node_def` exposing the live NodeDef rather
    # than a copy; confirm for the TF version in use.
    d.op.node_def.attr['_swap_to_host'].i = 0

    meta_graph_def = meta_graph.create_meta_graph_def(
        graph=ops.get_default_graph())
    config = rewriter_config_pb2.RewriterConfig(
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

    # The four original nodes plus one swap-out / swap-in pair.
    self.assertEqual(len(optimized.node), 6)
    self.assertItemsEqual([node.name for node in optimized.node], [
        'a',
        'b',
        'c',
        'd',
        'swap_in_d_0',
        'swap_out_d_0',
    ])

    nodes_by_name = {node.name: node for node in optimized.node}
    swap_in = nodes_by_name['swap_in_d_0']
    self.assertEqual('swap_out_d_0', swap_in.input[0])
    self.assertEqual('^b', swap_in.input[1])
    self.assertEqual('b', nodes_by_name['swap_out_d_0'].input[0])
    consumer = nodes_by_name['d']
    self.assertEqual('swap_in_d_0', consumer.input[0])
    self.assertEqual('c', consumer.input[1])
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:35,代码来源:memory_optimizer_test.py


示例11: _assertEventsWithGraph

  def _assertEventsWithGraph(self, test_dir, g, has_shapes):
    """Asserts that the event file in `test_dir` holds exactly the graph events.

    The expected sequence is: a file-version event, a step-0 event carrying
    `g`'s GraphDef, a step-0 event carrying a MetaGraphDef built from that
    GraphDef, and nothing else.

    Args:
      test_dir: Directory whose events are read via `self._EventsReader`.
      g: The `Graph` expected to have been written.
      has_shapes: Whether the written GraphDef was serialized with
        `add_shapes=True`.
    """
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=g.as_graph_def(add_shapes=has_shapes))

    rr = self._EventsReader(test_dir)

    # The first event should list the file_version.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual("brain.Event:2", ev.file_version)

    # The next event should have the graph.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(0, ev.step)
    ev_graph = graph_pb2.GraphDef()
    ev_graph.ParseFromString(ev.graph_def)
    self.assertProtoEquals(g.as_graph_def(add_shapes=has_shapes), ev_graph)

    # The next event should have the metagraph.
    ev = next(rr)
    self._assertRecent(ev.wall_time)
    self.assertEqual(0, ev.step)
    ev_meta_graph = meta_graph_pb2.MetaGraphDef()
    ev_meta_graph.ParseFromString(ev.meta_graph_def)
    self.assertProtoEquals(meta_graph_def, ev_meta_graph)

    # We should be done.
    self.assertRaises(StopIteration, next, rr)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:29,代码来源:writer_test.py


示例12: testSimpleSwap

  def testSimpleSwap(self):
    """Checks that a manual swap-to-host annotation on `d` is honored."""
    a = variables.Variable(10, name='a')
    b = variables.Variable(20, name='b')
    c = math_ops.add_n([a, b], name='c')
    d = math_ops.add_n([b, c], name='d')
    ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(d)

    # Request that input 0 of `d` be swapped to host memory.
    # NOTE(review): relies on `op.node_def` exposing the live NodeDef rather
    # than a copy; confirm for the TF version in use.
    d.op.node_def.attr['_swap_to_host'].i = 0

    meta_graph_def = meta_graph.create_meta_graph_def(
        graph=ops.get_default_graph())
    original_node_count = len(meta_graph_def.graph_def.node)

    config = rewriter_config_pb2.RewriterConfig(
        disable_model_pruning=True,
        memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
    optimized = tf_optimizer.OptimizeGraph(config, meta_graph_def)

    # Exactly one swap-out / swap-in pair is added to the original graph.
    self.assertEqual(len(optimized.node), original_node_count + 2)
    nodes_by_name = {node.name: node for node in optimized.node}
    self.assertTrue(
        set(nodes_by_name) >
        {'a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0'})

    swap_in = nodes_by_name['swap_in_d_0']
    self.assertEqual('swap_out_d_0', swap_in.input[0])
    self.assertEqual('^b/read', swap_in.input[1])
    self.assertEqual('b/read', nodes_by_name['swap_out_d_0'].input[0])
    consumer = nodes_by_name['d']
    self.assertEqual('swap_in_d_0', consumer.input[0])
    self.assertEqual('c', consumer.input[1])
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:32,代码来源:memory_optimizer_test.py


示例13: testBasicMemory

  def testBasicMemory(self):
    """Make sure arguments can be passed correctly.

    Builds four scalar constants/sums on CPU and checks the peak-usage line
    and the per-tensor lines of `cost_analyzer.GenerateMemoryReport`.
    """
    with test_util.device(use_gpu=False):
      a = constant_op.constant(10, name="a")
      b = constant_op.constant(20, name="b")
      c = math_ops.add_n([a, b], name="c")
      d = math_ops.add_n([b, c], name="d")
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(d)
      mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

    report = cost_analyzer.GenerateMemoryReport(mg)

    # Print the report to make it easier to debug
    print("{}".format(report))

    # Check the report. assertIn shows the missing line on failure, unlike
    # assertTrue(x in y).
    self.assertIn(
        "Peak usage for device /job:localhost/replica:0/task:0/device:CPU:0: "
        "16 bytes",
        report)
    for expected_line in ("  a:0 uses 4 bytes", "  b:0 uses 4 bytes",
                          "  c:0 uses 4 bytes", "  d:0 uses 4 bytes"):
      self.assertIn(expected_line, report)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:25,代码来源:cost_analyzer_test.py


示例14: test_train_summaries

 def test_train_summaries(self):
   """Checks the summaries and checkpoint produced by `_monitored_train`.

   Builds a graph whose train op increments the global step (after the ops
   returned by `self._build_inference_graph()`), registers a scalar 'loss'
   summary, runs `learn.graph_actions._monitored_train` for one step, and then
   checks the returned loss, the graphs/meta graphs/summaries seen by
   `_assert_summaries`, and the `_assert_ckpt` state before (False) and after
   (True) training.
   """
   with ops.Graph().as_default() as g, self.test_session(g):
     # The train op only runs after the inference graph's ops.
     with ops.control_dependencies(self._build_inference_graph()):
       train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
     loss_op = constant_op.constant(2.0)
     summary.scalar('loss', loss_op)
     writer = learn.graph_actions.get_summary_writer(self._output_dir)
     # Baseline before training: presumably no summaries and no checkpoint yet.
     self._assert_summaries(self._output_dir, writer)
     self._assert_ckpt(self._output_dir, False)
     loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
         g,
         output_dir=self._output_dir,
         train_op=train_op,
         loss_op=loss_op,
         steps=1)
     # Expected MetaGraphDef: the shaped GraphDef plus the saver_def of a
     # default Scaffold. Presumably this mirrors what the training loop wrote;
     # TODO confirm against _monitored_train.
     meta_graph_def = meta_graph.create_meta_graph_def(
         graph_def=g.as_graph_def(add_shapes=True),
         saver_def=monitored_session.Scaffold().finalize().saver.saver_def)
     self.assertEqual(2.0, loss)
     # One step increments the global step by 1, so the loss summary is
     # expected at step 1.
     self._assert_summaries(
         self._output_dir,
         writer,
         expected_graphs=[g],
         expected_meta_graphs=[meta_graph_def],
         expected_summaries={1: {
             'loss': 2.0
         }})
     self._assert_ckpt(self._output_dir, True)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:28,代码来源:graph_actions_test.py


示例15: testStandardServicesWithGlobalStep

 def testStandardServicesWithGlobalStep(self):
   """Checks the events and checkpoint written by a Supervisor with global_step.

   A graph with a `global_step` variable set to 123 runs under a
   `tf.train.Supervisor`. The test then reads back the event file (version,
   graph, meta graph, START / CHECKPOINT / STOP session logs and the
   `global_step/sec` summary) and restores the checkpoint saved at step 123.
   """
   logdir = _test_dir("standard_services_with_global_step")
   # Create a checkpoint.
   with tf.Graph().as_default():
     # Named "global_step"; its value 123 is the step number of the
     # checkpoint waited for below.
     tf.Variable([123], name="global_step")
     sv = tf.train.Supervisor(logdir=logdir)
     meta_graph_def = meta_graph.create_meta_graph_def(
         saver_def=sv.saver.saver_def)
     sess = sv.prepare_or_wait_for_session("")
     # This is where the checkpoint will appear, with step number 123.
     save_path = "%s-123" % sv.save_path
     self._wait_for_glob(save_path, 3.0)
     self._wait_for_glob(os.path.join(logdir, "*events*"), 3.0)
     # Wait to make sure everything is written to file before stopping.
     time.sleep(1)
     sv.stop()
   # There should be an event file with a version number.
   rr = _summary_iterator(logdir)
   ev = next(rr)
   self.assertEqual("brain.Event:2", ev.file_version)
   # Next comes the graph.
   ev = next(rr)
   ev_graph = tf.GraphDef()
   ev_graph.ParseFromString(ev.graph_def)
   self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
   # Then the meta graph, whose embedded graph_def must match the session's.
   ev = next(rr)
   ev_meta_graph = meta_graph_pb2.MetaGraphDef()
   ev_meta_graph.ParseFromString(ev.meta_graph_def)
   self.assertProtoEquals(meta_graph_def, ev_meta_graph)
   self.assertProtoEquals(
       sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
   ev = next(rr)
   # It is actually nondeterministic whether SessionLog.START gets written
   # before the summary or the checkpoint, but this works when run 10000 times.
   self.assertEqual(123, ev.step)
   self.assertEqual(tf.SessionLog.START, ev.session_log.status)
   first = next(rr)
   second = next(rr)
   # It is nondeterministic whether the value gets written before the
   # checkpoint since they are on separate threads, so we check for both
   # conditions.
   if first.HasField("summary"):
     self.assertProtoEquals("""value { tag: 'global_step/sec'
                                       simple_value: 0.0 }""",
                            first.summary)
     self.assertEqual(123, second.step)
     self.assertEqual(tf.SessionLog.CHECKPOINT, second.session_log.status)
   else:
     self.assertEqual(123, first.step)
     self.assertEqual(tf.SessionLog.CHECKPOINT, first.session_log.status)
     self.assertProtoEquals("""value { tag: 'global_step/sec'
                                       simple_value: 0.0 }""",
                            second.summary)
   ev = next(rr)
   self.assertEqual(tf.SessionLog.STOP, ev.session_log.status)
   self.assertRaises(StopIteration, next, rr)
   # There should be a checkpoint file with the variable "global_step".
   with tf.Graph().as_default(), self.test_session() as sess:
     v = tf.Variable([-12], name="global_step")
     sav = tf.train.Saver([v])
     sav.restore(sess, save_path)
     self.assertEqual(123, v.eval()[0])
开发者ID:KalraA,项目名称:tensorflow,代码行数:60,代码来源:supervisor_test.py


示例16: testVirtualCluster

  def testVirtualCluster(self):
    """Measures costs of a tiny graph on a virtual single-GPU cluster."""
    with ops.Graph().as_default() as graph:
      x = random_ops.random_uniform(shape=())
      y = random_ops.random_uniform(shape=())
      total = x + y
      ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
      grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=graph))

      gpu = device_properties_pb2.NamedDevice(
          properties=device_properties_pb2.DeviceProperties(
              type='GPU',
              frequency=1000,
              num_cores=60,
              environment={'architecture': '7'}),
          name='/GPU:0')
      virtual_cluster = cluster.Cluster(devices=[gpu])

      op_perfs, run_time, _ = virtual_cluster.MeasureCosts(grappler_item)
      self.assertGreater(run_time, 0)
      self.assertEqual(len(op_perfs), 15)

      self.assertEqual(7680.0, virtual_cluster.EstimatePerformance(gpu))
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:25,代码来源:cluster_test.py


示例17: testFromGenerator

  def testFromGenerator(self):
    """Checks shape inference for Dataset.from_generator with output_shapes."""

    def make_generator(tensor):
      # Returns a generator function that yields `tensor` once.

      def generator():
        yield tensor

      return generator

    cases = [
        (0, tensor_shape.TensorShape([])),
        (np.array([1, 2, 3]), tensor_shape.TensorShape([3])),
        (np.array([[1, 2, 3]]), tensor_shape.TensorShape([1, 3])),
    ]

    for value, expected_shape in cases:
      with ops.Graph().as_default() as graph:
        dataset = dataset_ops.Dataset.from_generator(
            make_generator(value), dtypes.int64, output_shapes=expected_shape)
        next_element = dataset.make_one_shot_iterator().get_next()
        ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(next_element)
        grappler_item = item.Item(
            meta_graph.create_meta_graph_def(graph=graph))
        properties = grappler_item.GetOpProperties()
        self.assertEqual(expected_shape,
                         properties['IteratorGetNext'][0].shape)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:35,代码来源:datasets_test.py


示例18: GetOptimizedGraph

 def GetOptimizedGraph():
   """Runs Grappler over the current default graph.

   Constant folding is turned off and memory optimization is set to MANUAL.

   Returns:
     The result of `tf_optimizer.OptimizeGraph` for the default graph.
   """
   rewrite_options = rewriter_config_pb2.RewriterConfig(
       constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
       memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
   config = config_pb2.ConfigProto()
   config.graph_options.rewrite_options.CopyFrom(rewrite_options)
   meta_graph_def = meta_graph.create_meta_graph_def(
       graph=ops.get_default_graph())
   return tf_optimizer.OptimizeGraph(config, meta_graph_def)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:8,代码来源:while_v2_test.py


示例19: testImportantOps

 def testImportantOps(self):
   """Both constants and their sum are reported as important ops."""
   with ops.Graph().as_default() as graph:
     ten = constant_op.constant(10)
     twenty = constant_op.constant(20)
     total = ten + twenty
     ops.get_collection_ref(ops.GraphKeys.TRAIN_OP).append(total)
     grappler_item = item.Item(meta_graph.create_meta_graph_def(graph=graph))
     self.assertEqual([b'Const', b'Const_1', b'add'],
                      grappler_item.IdentifyImportantOps())
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:11,代码来源:item_test.py


示例20: testInvalidItem

  def testInvalidItem(self):
    """Building an Item from a metagraph without a train op must fail."""
    with ops.Graph().as_default() as graph:
      _ = constant_op.constant(10) + constant_op.constant(20)
      meta_graph_def = meta_graph.create_meta_graph_def(graph=graph)

    # Nothing was added to TRAIN_OP, so there is no fetch node and item.Item
    # is expected to raise InvalidArgumentError.
    with self.assertRaises(errors_impl.InvalidArgumentError):
      item.Item(meta_graph_def)
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:11,代码来源:item_test.py



注:本文中的tensorflow.python.framework.meta_graph.create_meta_graph_def函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python meta_graph.export_scoped_meta_graph函数代码示例发布时间:2022-05-27
下一篇:
Python load_library.load_op_library函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap