This article collects typical usage examples of the Python function tensorflow.python.framework.meta_graph.create_meta_graph_def. If you have been wondering what create_meta_graph_def does, how to call it, or what real-world usage looks like, the curated examples below should help.
A total of 20 code examples of create_meta_graph_def are shown below, sorted by popularity by default. You can upvote the examples you like or find useful; your feedback helps the system recommend better Python code samples.
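Before the full examples, here is a minimal sketch of the basic call pattern: build a graph, pass it through the graph keyword argument, and receive a MetaGraphDef protocol buffer that bundles the GraphDef with the graph's collections and, optionally, a SaverDef. This sketch is not taken from any of the projects below; the tf.compat.v1 import and the op names a, b, c are illustrative assumptions.

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import meta_graph

g = tf.Graph()
with g.as_default():
  a = tf.constant(10, name='a')
  b = tf.constant(20, name='b')
  c = tf.add(a, b, name='c')

# Serialize the graph (plus its collections) into a MetaGraphDef proto.
mg = meta_graph.create_meta_graph_def(graph=g)
print(len(mg.graph_def.node))  # number of nodes captured in the GraphDef

The examples that follow feed the resulting MetaGraphDef into Grappler (item.Item, tf_optimizer.OptimizeGraph), the cost analyzer, or a summary writer.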
Example 1: testSmallNetwork
def testSmallNetwork(self):
  image = array_ops.placeholder(dtypes.float32, shape=[1, 28, 28, 1])
  label = array_ops.placeholder(dtypes.float32, shape=[1, 10])
  w = variables.Variable(
      random_ops.truncated_normal([5, 5, 1, 32], stddev=0.1))
  b = variables.Variable(random_ops.truncated_normal([32], stddev=0.1))
  conv = nn_ops.conv2d(image, w, strides=[1, 1, 1, 1], padding="SAME")
  h_conv = nn_ops.relu(conv + b)
  h_conv_flat = array_ops.reshape(h_conv, [1, -1])

  w_fc = variables.Variable(
      random_ops.truncated_normal([25088, 10], stddev=0.1))
  b_fc = variables.Variable(random_ops.truncated_normal([10], stddev=0.1))
  y_conv = nn_ops.softmax(math_ops.matmul(h_conv_flat, w_fc) + b_fc)
  cross_entropy = math_ops.reduce_mean(-math_ops.reduce_sum(
      label * math_ops.log(y_conv), reduction_indices=[1]))
  _ = adam.AdamOptimizer(1e-4).minimize(cross_entropy)

  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  report = cost_analyzer.GenerateCostReport(mg)

  self.assertTrue(b"MatMul" in report)
  self.assertTrue(b"ApplyAdam" in report)
  self.assertTrue(b"Conv2D" in report)
  self.assertTrue(b"Conv2DBackpropInput" in report)
  self.assertTrue(b"Conv2DBackpropFilter" in report)
  self.assertTrue(b"Softmax" in report)

  # Also print the report to make it easier to debug
  print("{}".format(report))
Author: ajaybhat, Project: tensorflow, Lines: 31, Source: cost_analyzer_test.py
Example 2: testFromStringHandle
def testFromStringHandle(self):
  test_cases = [{
      'shape': tensor_shape.TensorShape([])
  }, {
      'shape': tensor_shape.TensorShape([3])
  }, {
      'shape': tensor_shape.TensorShape([1, 2])
  }, {
      'shape': tensor_shape.TensorShape([1, 2, 3])
  }]

  for test_case in test_cases:
    with ops.Graph().as_default() as g:
      iterator = iterator_ops.Iterator.from_structure(dtypes.int64)
      handle = iterator.string_handle()
      iterator = iterator_ops.Iterator.from_string_handle(
          handle, dtypes.int64, output_shapes=test_case['shape'])
      get_next = iterator.get_next()
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(get_next)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      op_properties = grappler_item.GetOpProperties()
      self.assertEqual(test_case['shape'],
                       op_properties['IteratorGetNext'][0].shape)
Author: JonathanRaiman, Project: tensorflow, Lines: 25, Source: datasets_test.py
Example 3: before_run
def before_run(self, run_context):
  """Dumps graphs and loads checkpoint if there exists.

  Called before each call to run().

  Args:
    run_context: A `SessionRunContext` object.

  Returns: A `SessionRunArgs` object containing global_step.
  """
  # We do write graph and saver_def at the first call of before_run.
  # We cannot do this in begin, since we let other hooks change the graph and
  # add variables in begin. Graph is finalized after all begin calls.
  if self._is_chief and self._first_call:
    training_util.write_graph(
        ops.get_default_graph().as_graph_def(add_shapes=True),
        self._checkpoint_dir,
        "graph.pbtxt")
    # dump model details "model_analysis.txt"
    dump_model_analysis(self._checkpoint_dir)  # dump model configs
    graph = ops.get_default_graph()
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=graph.as_graph_def(add_shapes=True),
        saver_def=self._saver.saver_def)
    if self._summary_writer is not None:
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)
    tf.logging.info("CheckpointSaverHook (before_run): dump graph...")
  self._first_call = False
  return tf.train.SessionRunArgs(self._global_step)
Author: KIngpon, Project: NJUNMT-tf, Lines: 30, Source: hooks.py
Example 4: testUpdates
def testUpdates(self):
  with ops.Graph().as_default() as g:
    a = constant_op.constant(10)
    b = constant_op.constant(20)
    c = a + b
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(c)
    mg = meta_graph.create_meta_graph_def(graph=g)
    grappler_item = item.Item(mg)

    initial_tf_item = grappler_item.tf_item
    no_change_tf_item = grappler_item.tf_item
    self.assertEqual(initial_tf_item, no_change_tf_item)

    # Modify the placement.
    for node in grappler_item.metagraph.graph_def.node:
      node.device = '/cpu:0'
    new_tf_item = grappler_item.tf_item
    self.assertNotEqual(initial_tf_item, new_tf_item)

    # Assign the same placement.
    for node in grappler_item.metagraph.graph_def.node:
      node.device = '/cpu:0'
    newest_tf_item = grappler_item.tf_item
    self.assertEqual(new_tf_item, newest_tf_item)
Author: AnddyWang, Project: tensorflow, Lines: 25, Source: item_test.py
Example 5: testMap
def testMap(self):
  test_cases = [{
      'tensor': 0,
      'shape': tensor_shape.TensorShape([])
  }, {
      'tensor': np.array([1, 2, 3]),
      'shape': tensor_shape.TensorShape([3])
  }, {
      'tensor': np.array([[1, 2, 3]]),
      'shape': tensor_shape.TensorShape([3, 1])
  }, {
      'tensor': np.array([[[1, 2, 3], [4, 5, 6]]]),
      'shape': tensor_shape.TensorShape([3, 2, 1])
  }]

  for test_case in test_cases:
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.from_tensors(test_case['tensor'])
      dataset = dataset.map(array_ops.transpose)
      iterator = dataset.make_one_shot_iterator()
      get_next = iterator.get_next()
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(get_next)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      op_properties = grappler_item.GetOpProperties()
      self.assertEqual(test_case['shape'],
                       op_properties['IteratorGetNext'][0].shape)
Author: JonathanRaiman, Project: tensorflow, Lines: 28, Source: datasets_test.py
Example 6: testKeepNodes
def testKeepNodes(self):
  g = ops.Graph()
  with g.as_default():
    a1 = variables.VariableV1(
        1.0)  # Must be preserved since it's in the collection 'variables'.
    a2 = constant_op.constant(0, shape=[50, 50], name='keep')
    ops.add_to_collection('a2', a2)  # Explicitly add to collection.
    b = constant_op.constant(1, shape=[100, 10])
    c = constant_op.constant(0, shape=[10, 30])
    d = math_ops.matmul(b, c)
    ops.add_to_collection('train_op', d)  # d is the fetch node.

  # Optimize the graph.
  mg = meta_graph.create_meta_graph_def(graph=g)
  rewriter_config = rewriter_config_pb2.RewriterConfig()
  rewriter_config.min_graph_nodes = -1
  optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)

  # Check that the nodes referenced in various collections have been preserved
  self.assertEqual(len(optimized_graph.node), 5)
  self.assertEqual(d.op.name, optimized_graph.node[0].name)
  self.assertEqual(a1.op.name, optimized_graph.node[1].name)
  self.assertEqual('Variable/initial_value', optimized_graph.node[2].name)
  self.assertEqual(a2.op.name, optimized_graph.node[3].name)
  self.assertEqual('Variable/Assign', optimized_graph.node[4].name)
Author: ThunderQi, Project: tensorflow, Lines: 25, Source: tf_optimizer_test.py
Example 7: testPaddedBatch
def testPaddedBatch(self):
  test_cases = [{
      'tensor': 0,
      'shape': tensor_shape.TensorShape([None])
  }, {
      'tensor': np.array([1, 2, 3]),
      'shape': tensor_shape.TensorShape([None, 4])
  }, {
      'tensor': np.array([[1, 2, 3]]),
      'shape': tensor_shape.TensorShape([None, 2, 4])
  }]

  for test_case in test_cases:
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.from_tensors(test_case['tensor'])
      dataset = dataset.padded_batch(42, padded_shapes=test_case['shape'][1:])
      iterator = dataset.make_one_shot_iterator()
      get_next = iterator.get_next()
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(get_next)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      op_properties = grappler_item.GetOpProperties()
      inferred_shape = self.as_tensor_shape(
          op_properties['IteratorGetNext'][0].shape)
      self.assertTrue(test_case['shape'].dims[0].is_compatible_with(
          inferred_shape[0]))
      self.assertEqual(test_case['shape'][1:], inferred_shape[1:])
Author: JonathanRaiman, Project: tensorflow, Lines: 28, Source: datasets_test.py
Example 8: testInterleave
def testInterleave(self):
  test_cases = [{
      'tensor': 0,
      'shape': tensor_shape.TensorShape([])
  }, {
      'tensor': np.array([1, 2, 3]),
      'shape': tensor_shape.TensorShape([3])
  }, {
      'tensor': np.array([[1, 2, 3]]),
      'shape': tensor_shape.TensorShape([1, 3])
  }]

  for test_case in test_cases:
    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.range(42)

      def make_dataset(tensor):

        def dataset_fn(n):
          return dataset_ops.Dataset.from_tensors(tensor).repeat(n)

        return dataset_fn

      dataset = dataset.interleave(
          make_dataset(test_case['tensor']), cycle_length=42)
      iterator = dataset.make_one_shot_iterator()
      get_next = iterator.get_next()
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(get_next)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      op_properties = grappler_item.GetOpProperties()
      self.assertEqual(test_case['shape'],
                       op_properties['IteratorGetNext'][0].shape)
Author: JonathanRaiman, Project: tensorflow, Lines: 34, Source: datasets_test.py
Example 9: testKeepNodes
def testKeepNodes(self):
  g = ops.Graph()
  with g.as_default():
    a1 = variables.VariableV1(
        1.0)  # Must be preserved since it's in the collection 'variables'.
    a2 = constant_op.constant(0, shape=[50, 50], name='keep')
    ops.add_to_collection('a2', a2)  # Explicitly add to collection.
    with g._attr_scope(
        {'_grappler_do_not_remove': attr_value_pb2.AttrValue(b=True)}):
      a3 = constant_op.constant(0, name='keep2')
    b = constant_op.constant(1, shape=[100, 10])
    c = constant_op.constant(0, shape=[10, 30])
    d = math_ops.matmul(b, c)
    ops.add_to_collection('train_op', d)  # d is the fetch node.

  # Optimize the graph.
  mg = meta_graph.create_meta_graph_def(graph=g)
  config = config_pb2.ConfigProto()
  rewriter_config = config.graph_options.rewrite_options
  rewriter_config.min_graph_nodes = -1
  optimized_graph = tf_optimizer.OptimizeGraph(config, mg)

  # Check that the nodes referenced in various collections have been preserved
  optimized_graph_nodes = [node.name for node in optimized_graph.node]
  expected_nodes = [
      d.op.name, a1.op.name, a2.op.name, a3.op.name, 'Variable/initial_value',
      'Variable/Assign'
  ]
  self.assertEqual(len(optimized_graph_nodes), len(expected_nodes))
  self.assertAllInSet(optimized_graph_nodes, expected_nodes)
Author: adit-chandra, Project: tensorflow, Lines: 30, Source: tf_optimizer_test.py
Example 10: testSimpleSwap
def testSimpleSwap(self):
  """Check that the swap annotations are followed."""
  a = constant_op.constant(10, name='a')
  b = constant_op.constant(20, name='b')
  c = math_ops.add_n([a, b], name='c')
  d = math_ops.add_n([b, c], name='d')
  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_op.append(d)

  d.op.node_def.attr['_swap_to_host'].i = 0

  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())

  rewriter_config = rewriter_config_pb2.RewriterConfig(
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)

  self.assertEqual(len(graph.node), 6)
  self.assertItemsEqual([node.name for node in graph.node], [
      'a',
      'b',
      'c',
      'd',
      'swap_in_d_0',
      'swap_out_d_0',
  ])
  for node in graph.node:
    if node.name == 'swap_in_d_0':
      self.assertEqual('swap_out_d_0', node.input[0])
      self.assertEqual('^b', node.input[1])
    elif node.name == 'swap_out_d_0':
      self.assertEqual('b', node.input[0])
    elif node.name == 'd':
      self.assertEqual('swap_in_d_0', node.input[0])
      self.assertEqual('c', node.input[1])
Author: AutumnQYN, Project: tensorflow, Lines: 35, Source: memory_optimizer_test.py
Example 11: _assertEventsWithGraph
def _assertEventsWithGraph(self, test_dir, g, has_shapes):
  meta_graph_def = meta_graph.create_meta_graph_def(
      graph_def=g.as_graph_def(add_shapes=has_shapes))

  rr = self._EventsReader(test_dir)

  # The first event should list the file_version.
  ev = next(rr)
  self._assertRecent(ev.wall_time)
  self.assertEquals("brain.Event:2", ev.file_version)

  # The next event should have the graph.
  ev = next(rr)
  self._assertRecent(ev.wall_time)
  self.assertEquals(0, ev.step)
  ev_graph = graph_pb2.GraphDef()
  ev_graph.ParseFromString(ev.graph_def)
  self.assertProtoEquals(g.as_graph_def(add_shapes=has_shapes), ev_graph)

  # The next event should have the metagraph.
  ev = next(rr)
  self._assertRecent(ev.wall_time)
  self.assertEquals(0, ev.step)
  ev_meta_graph = meta_graph_pb2.MetaGraphDef()
  ev_meta_graph.ParseFromString(ev.meta_graph_def)
  self.assertProtoEquals(meta_graph_def, ev_meta_graph)

  # We should be done.
  self.assertRaises(StopIteration, lambda: next(rr))
Author: JonathanRaiman, Project: tensorflow, Lines: 29, Source: writer_test.py
Example 12: testSimpleSwap
def testSimpleSwap(self):
  """Check that the swap annotations are followed."""
  a = variables.Variable(10, name='a')
  b = variables.Variable(20, name='b')
  c = math_ops.add_n([a, b], name='c')
  d = math_ops.add_n([b, c], name='d')
  train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
  train_op.append(d)

  d.op.node_def.attr['_swap_to_host'].i = 0

  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  graph_size = len(mg.graph_def.node)

  rewriter_config = rewriter_config_pb2.RewriterConfig(
      disable_model_pruning=True,
      memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL)
  graph = tf_optimizer.OptimizeGraph(rewriter_config, mg)

  self.assertEqual(len(graph.node), graph_size + 2)
  self.assertTrue(
      set([node.name for node in graph.node]) > set(
          ['a', 'b', 'c', 'd', 'swap_in_d_0', 'swap_out_d_0']))
  for node in graph.node:
    if node.name == 'swap_in_d_0':
      self.assertEqual('swap_out_d_0', node.input[0])
      self.assertEqual('^b/read', node.input[1])
    elif node.name == 'swap_out_d_0':
      self.assertEqual('b/read', node.input[0])
    elif node.name == 'd':
      self.assertEqual('swap_in_d_0', node.input[0])
      self.assertEqual('c', node.input[1])
Author: Dr4KK, Project: tensorflow, Lines: 32, Source: memory_optimizer_test.py
Example 13: testBasicMemory
def testBasicMemory(self):
  """Make sure arguments can be passed correctly."""
  with test_util.device(use_gpu=False):
    a = constant_op.constant(10, name="a")
    b = constant_op.constant(20, name="b")
    c = math_ops.add_n([a, b], name="c")
    d = math_ops.add_n([b, c], name="d")
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(d)
    mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
    report = cost_analyzer.GenerateMemoryReport(mg)

    # Print the report to make it easier to debug
    print("{}".format(report))

    # Check the report
    self.assertTrue(
        "Peak usage for device /job:localhost/replica:0/task:0/device:CPU:0: "
        "16 bytes"
        in report)
    self.assertTrue(" a:0 uses 4 bytes" in report)
    self.assertTrue(" b:0 uses 4 bytes" in report)
    self.assertTrue(" c:0 uses 4 bytes" in report)
    self.assertTrue(" d:0 uses 4 bytes" in report)
Author: adit-chandra, Project: tensorflow, Lines: 25, Source: cost_analyzer_test.py
Example 14: test_train_summaries
def test_train_summaries(self):
  with ops.Graph().as_default() as g, self.test_session(g):
    with ops.control_dependencies(self._build_inference_graph()):
      train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
    loss_op = constant_op.constant(2.0)
    summary.scalar('loss', loss_op)
    writer = learn.graph_actions.get_summary_writer(self._output_dir)
    self._assert_summaries(self._output_dir, writer)
    self._assert_ckpt(self._output_dir, False)
    loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
        g,
        output_dir=self._output_dir,
        train_op=train_op,
        loss_op=loss_op,
        steps=1)
    meta_graph_def = meta_graph.create_meta_graph_def(
        graph_def=g.as_graph_def(add_shapes=True),
        saver_def=monitored_session.Scaffold().finalize().saver.saver_def)
    self.assertEqual(2.0, loss)
    self._assert_summaries(
        self._output_dir,
        writer,
        expected_graphs=[g],
        expected_meta_graphs=[meta_graph_def],
        expected_summaries={1: {
            'loss': 2.0
        }})
    self._assert_ckpt(self._output_dir, True)
Author: AliMiraftab, Project: tensorflow, Lines: 28, Source: graph_actions_test.py
Example 15: testStandardServicesWithGlobalStep
def testStandardServicesWithGlobalStep(self):
  logdir = _test_dir("standard_services_with_global_step")
  # Create a checkpoint.
  with tf.Graph().as_default():
    v = tf.Variable([123], name="global_step")
    sv = tf.train.Supervisor(logdir=logdir)
    meta_graph_def = meta_graph.create_meta_graph_def(
        saver_def=sv.saver.saver_def)
    sess = sv.prepare_or_wait_for_session("")
    # This is where the checkpoint will appear, with step number 123.
    save_path = "%s-123" % sv.save_path
    self._wait_for_glob(save_path, 3.0)
    self._wait_for_glob(os.path.join(logdir, "*events*"), 3.0)
    # Wait to make sure everything is written to file before stopping.
    time.sleep(1)
    sv.stop()

  # There should be an event file with a version number.
  rr = _summary_iterator(logdir)
  ev = next(rr)
  self.assertEquals("brain.Event:2", ev.file_version)
  ev = next(rr)
  ev_graph = tf.GraphDef()
  ev_graph.ParseFromString(ev.graph_def)
  self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
  ev = next(rr)
  ev_meta_graph = meta_graph_pb2.MetaGraphDef()
  ev_meta_graph.ParseFromString(ev.meta_graph_def)
  self.assertProtoEquals(meta_graph_def, ev_meta_graph)
  self.assertProtoEquals(
      sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
  ev = next(rr)
  # It is actually undeterministic whether SessionLog.START gets written
  # before the summary or the checkpoint, but this works when run 10000 times.
  self.assertEquals(123, ev.step)
  self.assertEquals(tf.SessionLog.START, ev.session_log.status)

  first = next(rr)
  second = next(rr)
  # It is undeterministic whether the value gets written before the checkpoint
  # since they are on separate threads, so we check for both conditions.
  if first.HasField("summary"):
    self.assertProtoEquals("""value { tag: 'global_step/sec'
                              simple_value: 0.0 }""",
                           first.summary)
    self.assertEquals(123, second.step)
    self.assertEquals(tf.SessionLog.CHECKPOINT, second.session_log.status)
  else:
    self.assertEquals(123, first.step)
    self.assertEquals(tf.SessionLog.CHECKPOINT, first.session_log.status)
    self.assertProtoEquals("""value { tag: 'global_step/sec'
                              simple_value: 0.0 }""",
                           second.summary)

  ev = next(rr)
  self.assertEquals(tf.SessionLog.STOP, ev.session_log.status)
  self.assertRaises(StopIteration, lambda: next(rr))

  # There should be a checkpoint file with the variable "foo"
  with tf.Graph().as_default(), self.test_session() as sess:
    v = tf.Variable([-12], name="global_step")
    sav = tf.train.Saver([v])
    sav.restore(sess, save_path)
    self.assertEqual(123, v.eval()[0])
Author: KalraA, Project: tensorflow, Lines: 60, Source: supervisor_test.py
Example 16: testVirtualCluster
def testVirtualCluster(self):
  with ops.Graph().as_default() as g:
    a = random_ops.random_uniform(shape=())
    b = random_ops.random_uniform(shape=())
    c = a + b
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(c)
    mg = meta_graph.create_meta_graph_def(graph=g)
    grappler_item = item.Item(mg)

    device_properties = device_properties_pb2.DeviceProperties(
        type='GPU',
        frequency=1000,
        num_cores=60,
        environment={
            'architecture': '7'
        })
    named_device = device_properties_pb2.NamedDevice(
        properties=device_properties, name='/GPU:0')
    grappler_cluster = cluster.Cluster(devices=[named_device])

    op_perfs, run_time, _ = grappler_cluster.MeasureCosts(grappler_item)
    self.assertGreater(run_time, 0)
    self.assertEqual(len(op_perfs), 15)

    estimated_perf = grappler_cluster.EstimatePerformance(named_device)
    self.assertEqual(7680.0, estimated_perf)
Author: ChengYuXiang, Project: tensorflow, Lines: 25, Source: cluster_test.py
Example 17: testFromGenerator
def testFromGenerator(self):
  test_cases = [{
      'tensor': 0,
      'shape': tensor_shape.TensorShape([])
  }, {
      'tensor': np.array([1, 2, 3]),
      'shape': tensor_shape.TensorShape([3])
  }, {
      'tensor': np.array([[1, 2, 3]]),
      'shape': tensor_shape.TensorShape([1, 3])
  }]

  for test_case in test_cases:

    def make_generator(tensor):

      def generator():
        yield tensor

      return generator

    with ops.Graph().as_default() as g:
      dataset = dataset_ops.Dataset.from_generator(
          make_generator(test_case['tensor']),
          dtypes.int64,
          output_shapes=test_case['shape'])
      iterator = dataset.make_one_shot_iterator()
      get_next = iterator.get_next()
      train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
      train_op.append(get_next)
      mg = meta_graph.create_meta_graph_def(graph=g)
      grappler_item = item.Item(mg)
      op_properties = grappler_item.GetOpProperties()
      self.assertEqual(test_case['shape'],
                       op_properties['IteratorGetNext'][0].shape)
Author: JonathanRaiman, Project: tensorflow, Lines: 35, Source: datasets_test.py
Example 18: GetOptimizedGraph
def GetOptimizedGraph():
  mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph())
  config = config_pb2.ConfigProto()
  config.graph_options.rewrite_options.CopyFrom(
      rewriter_config_pb2.RewriterConfig(
          constant_folding=rewriter_config_pb2.RewriterConfig.OFF,
          memory_optimization=rewriter_config_pb2.RewriterConfig.MANUAL))
  return tf_optimizer.OptimizeGraph(config, mg)
Author: adit-chandra, Project: tensorflow, Lines: 8, Source: while_v2_test.py
Example 19: testImportantOps
def testImportantOps(self):
  with ops.Graph().as_default() as g:
    a = constant_op.constant(10)
    b = constant_op.constant(20)
    c = a + b
    train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
    train_op.append(c)
    mg = meta_graph.create_meta_graph_def(graph=g)
    grappler_item = item.Item(mg)
    op_list = grappler_item.IdentifyImportantOps()
    self.assertEqual([b'Const', b'Const_1', b'add'], op_list)
Author: AnddyWang, Project: tensorflow, Lines: 11, Source: item_test.py
Example 20: testInvalidItem
def testInvalidItem(self):
  with ops.Graph().as_default() as g:
    a = constant_op.constant(10)
    b = constant_op.constant(20)
    c = a + b  # pylint: disable=unused-variable
    mg = meta_graph.create_meta_graph_def(graph=g)

  # The train op isn't specified: this should raise an InvalidArgumentError
  # exception.
  with self.assertRaises(errors_impl.InvalidArgumentError):
    item.Item(mg)
Author: AnddyWang, Project: tensorflow, Lines: 11, Source: item_test.py
Note: The tensorflow.python.framework.meta_graph.create_meta_graph_def examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other source-code and documentation platforms. The snippets were selected from open-source projects contributed by their respective authors; copyright of the source code remains with the original authors, and redistribution or use should follow each project's License. Do not reproduce this article without permission.