本文整理汇总了Python中tensorflow.python.framework.meta_graph.export_scoped_meta_graph函数的典型用法代码示例。如果您正苦于以下问题：Python export_scoped_meta_graph函数的具体用法？Python export_scoped_meta_graph怎么用？Python export_scoped_meta_graph使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了export_scoped_meta_graph函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _testExportImportAcrossScopes
def _testExportImportAcrossScopes(self, graph_fn):
    """Checks that a graph survives an export/import round trip across scopes.

    A graph is built under "dropA/dropB/keepA", exported with the
    "dropA/dropB" prefix stripped, and re-imported under "importA". The
    result must match a graph built directly under "importA/keepA".

    Args:
        graph_fn: A closure that creates a graph on the current scope.
    """
    # Build the source graph under a deeply nested scope.
    with ops.Graph().as_default() as source_graph:
        with variable_scope.variable_scope("dropA/dropB/keepA"):
            graph_fn()
    exported_def, _ = meta_graph.export_scoped_meta_graph(
        graph=source_graph, export_scope="dropA/dropB")

    # Re-import the exported MetaGraphDef under a new prefix.
    with ops.Graph().as_default() as roundtrip_graph:
        meta_graph.import_scoped_meta_graph(
            exported_def, import_scope="importA")

    # Build the reference graph directly under the expected final scope.
    with ops.Graph().as_default() as reference_graph:
        with variable_scope.variable_scope("importA/keepA"):
            graph_fn()

    actual_def, _ = meta_graph.export_scoped_meta_graph(graph=roundtrip_graph)
    reference_def, _ = meta_graph.export_scoped_meta_graph(
        graph=reference_graph)
    self.assertProtoEquals(reference_def, actual_def)
开发者ID:SylChan,项目名称:tensorflow,代码行数:25,代码来源:meta_graph_test.py
示例2: testScopedImportWithSelectedCollections
def testScopedImportWithSelectedCollections(self):
    """Checks that `restore_collections_predicate` filters restored collections."""
    export_path = os.path.join(
        _TestDir("selected_collections_import"), "meta_graph.pb")
    global_key = ops.GraphKeys.GLOBAL_VARIABLES
    trainable_key = ops.GraphKeys.TRAINABLE_VARIABLES

    source_graph = ops.Graph()
    # A single variable populates both collections. Nothing tested here is
    # specific to variables; they are just a convenient way to fill them.
    with source_graph.as_default():
        variables.Variable(initial_value=1.0, trainable=True)
    self.assertTrue(
        all(source_graph.get_collection(k) for k in [global_key, trainable_key]))
    meta_graph.export_scoped_meta_graph(filename=export_path, graph=source_graph)

    def _check_import(wanted_keys, unwanted_keys):
        """Imports the saved graph, restoring only the `wanted_keys` collections."""
        assert set(wanted_keys).isdisjoint(unwanted_keys)
        target_graph = ops.Graph()
        scope_name = "some_scope_name"

        def _keep(collection_key):
            return (collection_key in wanted_keys
                    and collection_key not in unwanted_keys)

        meta_graph.import_scoped_meta_graph(
            export_path,
            graph=target_graph,
            import_scope=scope_name,
            restore_collections_predicate=_keep)

        def _restored(keys):
            return [
                target_graph.get_collection(name=k, scope=scope_name)
                for k in keys
            ]

        # Every requested collection is non-empty; every omitted one is empty.
        self.assertTrue(all(_restored(wanted_keys)))
        self.assertFalse(any(_restored(unwanted_keys)))

    for wanted, unwanted in [
            ([global_key, trainable_key], []),
            ([global_key], [trainable_key]),
            ([trainable_key], [global_key]),
            ([], [global_key, trainable_key]),
    ]:
        _check_import(wanted, unwanted)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:59,代码来源:meta_graph_test.py
示例3: testClearDevices
def testClearDevices(self):
    """Checks that `clear_devices` strips device placements on export or import."""
    source_graph = ops.Graph()
    with source_graph.as_default():
        with ops.device("/device:CPU:0"):
            var_a = variables.Variable(
                constant_op.constant(1.0, shape=[2, 2]), name="a")
        with ops.device("/job:ps/replica:0/task:0/gpu:0"):
            var_b = variables.Variable(
                constant_op.constant(2.0, shape=[2, 2]), name="b")
        with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
            math_ops.matmul(var_a, var_b, name="matmul")

    # The stored device strings are the canonical forms checked below.
    expected_devices = [
        ("a", "/device:CPU:0"),
        ("b", "/job:ps/replica:0/task:0/device:GPU:0"),
        ("matmul", "/job:localhost/replica:0/task:0/device:CPU:0"),
    ]
    for op_name, device in expected_devices:
        self.assertEqual(
            device, str(source_graph.as_graph_element(op_name).device))

    # Each scenario: (export from a GraphDef?, clear on export, clear on import).
    scenarios = [
        (False, True, False),  # Devices are cleared on export.
        (True, True, False),   # Cleared on export when passing in graph_def.
        (False, False, True),  # Devices are cleared on import.
    ]
    for from_graph_def, clear_on_export, clear_on_import in scenarios:
        if from_graph_def:
            source = {"graph_def": source_graph.as_graph_def()}
        else:
            source = {"graph": source_graph}
        exported, _ = meta_graph.export_scoped_meta_graph(
            clear_devices=clear_on_export, **source)

        target_graph = ops.Graph()
        with target_graph.as_default():
            meta_graph.import_scoped_meta_graph(
                exported, clear_devices=clear_on_import)

        for op_name, _ in expected_devices:
            self.assertEqual(
                "", str(target_graph.as_graph_element(op_name).device))
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:55,代码来源:meta_graph_test.py
示例4: testMetricsCollection
def testMetricsCollection(self):
    """Checks importing meta graphs whose LOCAL_VARIABLES hold metric state.

    Three stages:
      1. Build a `metrics.mean` over a queue, run it, and export the graph.
      2. Re-import that file; a local-variables initializer must build and run.
      3. Import an old test file where "local_variables" is a node list. The
         import works, but building an initializer must fail with
         AttributeError.

    Uses `assertRaisesRegex`. The `assertRaisesRegexp` alias was deprecated
    in Python 3.2 and removed from `unittest` in Python 3.12.
    """

    def _enqueue_vector(sess, queue, values, shape=None):
        # Without an explicit shape, enqueue `values` as a single row.
        if not shape:
            shape = (1, len(values))
        dtype = queue.dtypes[0]
        sess.run(
            queue.enqueue(
                constant_op.constant(values, dtype=dtype, shape=shape)))

    meta_graph_filename = os.path.join(
        _TestDir("metrics_export"), "meta_graph.pb")

    graph = ops.Graph()
    with self.session(graph=graph) as sess:
        values_queue = data_flow_ops.FIFOQueue(
            4, dtypes.float32, shapes=(1, 2))
        for row in ([0, 1], [-4.2, 9.1], [6.5, 0], [-3.2, 4.0]):
            _enqueue_vector(sess, values_queue, row)
        values = values_queue.dequeue()

        _, update_op = metrics.mean(values)

        self.evaluate(variables.local_variables_initializer())
        self.evaluate(update_op)

    meta_graph.export_scoped_meta_graph(
        filename=meta_graph_filename, graph=graph)

    # Verifies that importing a meta_graph with LOCAL_VARIABLES collection
    # works correctly.
    with self.session(graph=ops.Graph()):
        meta_graph.import_scoped_meta_graph(meta_graph_filename)
        self.evaluate(variables.local_variables_initializer())

    # Verifies that importing an old meta_graph where "local_variables"
    # collection is of node_list type works, but cannot build initializer
    # with the collection.
    with self.session(graph=ops.Graph()):
        meta_graph.import_scoped_meta_graph(
            test.test_src_dir_path(
                "python/framework/testdata/metrics_export_meta_graph.pb"))
        self.assertEqual(
            len(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)), 2)
        with self.assertRaisesRegex(
                AttributeError,
                "'Tensor' object has no attribute 'initializer'"):
            variables.local_variables_initializer()
开发者ID:aeverall,项目名称:tensorflow,代码行数:53,代码来源:meta_graph_test.py
示例5: testNoVariables
def testNoVariables(self):
    """Round-trips a graph with zero variables through a MetaGraphDef file.

    The graph computes `input + 42`. Its endpoints are registered in the
    "input_tensor" / "output_tensor" collections so they can be recovered
    after import. The re-imported graph must reproduce the same
    MetaGraphDef and the same output.
    """
    export_path = os.path.join(_TestDir("no_variables"), "metafile")
    feed_value = -10  # Arbitrary input value for feed_dict.

    with self.session(graph=ops.Graph()) as sess:
        # Create a minimal graph with zero variables.
        placeholder = array_ops.placeholder(
            dtypes.float32, shape=[], name="input")
        offset = constant_op.constant(42, dtype=dtypes.float32, name="offset")
        sum_tensor = math_ops.add(placeholder, offset, name="add_offset")

        # Add input and output tensors to graph collections.
        ops.add_to_collection("input_tensor", placeholder)
        ops.add_to_collection("output_tensor", sum_tensor)

        expected_value = sess.run(sum_tensor, {placeholder: feed_value})
        self.assertEqual(expected_value, 32)

        # Generates MetaGraphDef.
        exported_def, exported_vars = meta_graph.export_scoped_meta_graph(
            filename=export_path,
            graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
            collection_list=["input_tensor", "output_tensor"],
            saver_def=None)
        self.assertTrue(exported_def.HasField("meta_info_def"))
        info = exported_def.meta_info_def
        self.assertNotEqual(info.tensorflow_version, "")
        self.assertNotEqual(info.tensorflow_git_version, "")
        self.assertEqual({}, exported_vars)

    # Import the MetaGraphDef nodes into a clean graph.
    with self.session(graph=ops.Graph()) as sess:
        meta_graph.import_scoped_meta_graph(export_path)

        # Re-export the imported graph state for comparison to the original.
        reexported_def, _ = meta_graph.export_scoped_meta_graph(
            export_path + "_new")
        test_util.assert_meta_graph_protos_equal(
            self, exported_def, reexported_def)

        # The collections still lead back to the endpoints...
        restored_input = ops.get_collection("input_tensor")[0]
        restored_output = ops.get_collection("output_tensor")[0]
        # ...and the restored graph computes the same result as the original.
        restored_value = sess.run(restored_output,
                                  {restored_input: feed_value})
        self.assertEqual(restored_value, expected_value)
开发者ID:aeverall,项目名称:tensorflow,代码行数:52,代码来源:meta_graph_test.py
示例6: testDefaultAttrStripping
def testDefaultAttrStripping(self):
    """Verifies that default attributes are stripped from a graph def."""
    # Complex Op has 2 attributes with defaults:
    #   o "T"    : float32.
    #   o "Tout" : complex64.

    def _build_complex(dtype):
        # Adds real/imag constants of `dtype` and a Complex op named "complex".
        real = constant_op.constant(1.0, dtype=dtype, name="real")
        imag = constant_op.constant(2.0, dtype=dtype, name="imag")
        math_ops.complex(real, imag, name="complex")

    def _export_complex(strip):
        """Exports the default graph and returns (MetaGraphDef, "complex" NodeDef)."""
        exported, _ = meta_graph.export_scoped_meta_graph(
            graph_def=ops.get_default_graph().as_graph_def(),
            strip_default_attrs=strip)
        node = test_util.get_node_def_from_graph("complex", exported.graph_def)
        return exported, node

    # With float32 inputs, "T" is float32 and "Tout" is complex64. Both
    # equal their defaults, so they are stripped unless stripping is disabled.
    with self.cached_session():
        _build_complex(dtypes.float32)

        # strip_default_attrs is enabled.
        exported, node = _export_complex(True)
        for attr_name in ("T", "Tout"):
            self.assertNotIn(attr_name, node.attr)
        self.assertTrue(exported.meta_info_def.stripped_default_attrs)

        # strip_default_attrs is disabled.
        exported, node = _export_complex(False)
        for attr_name in ("T", "Tout"):
            self.assertIn(attr_name, node.attr)
        self.assertFalse(exported.meta_info_def.stripped_default_attrs)

    # With float64 inputs, "T" is float64 and "Tout" is complex128. Neither
    # equals its default, so both must survive stripping.
    with self.session(graph=ops.Graph()):
        _build_complex(dtypes.float64)
        exported, node = _export_complex(True)
        self.assertEqual(node.attr["T"].type, dtypes.float64)
        self.assertEqual(node.attr["Tout"].type, dtypes.complex128)
        self.assertTrue(exported.meta_info_def.stripped_default_attrs)
开发者ID:aeverall,项目名称:tensorflow,代码行数:51,代码来源:meta_graph_test.py
示例7: testWhileLoopGradients
def testWhileLoopGradients(self):
    """Checks while-loop gradients are unchanged by a scoped export/import."""

    def _rescoped(tensor_name):
        # Map a tensor name from the "export" scope onto the "import" scope.
        return "import/" + tensor_name.replace("export/", "")

    # Create a simple while loop inside the "export" scope.
    with ops.Graph().as_default():
        with ops.name_scope("export"):
            loop_var = variables.Variable(0.)
            loop_var_name = loop_var.name
            _, loop_out = control_flow_ops.while_loop(
                lambda step, acc: step < 5,
                lambda step, acc: (
                    step + 1, acc + math_ops.cast(step, dtypes.float32)),
                [0, loop_var])
            loop_out_name = loop_out.name

        # Generate a MetaGraphDef containing the while loop with an export scope.
        exported_def, _ = meta_graph.export_scoped_meta_graph(
            export_scope="export")

        # Compute reference gradients in the original graph, for comparison
        # with the imported MetaGraphDef below.
        init = variables.global_variables_initializer()
        grad = gradients_impl.gradients([loop_out], [loop_var])
        with session.Session():
            self.evaluate(init)
            expected_grad = self.evaluate(grad)

    # Restore the MetaGraphDef into a new Graph with an import scope.
    with ops.Graph().as_default():
        meta_graph.import_scoped_meta_graph(exported_def, import_scope="import")

        # Re-exporting the imported scope must reproduce the same MetaGraphDef.
        reexported_def, _ = meta_graph.export_scoped_meta_graph(
            export_scope="import")
        test_util.assert_meta_graph_protos_equal(
            self, exported_def, reexported_def)

        # Gradients built on the imported tensors must match the reference.
        graph = ops.get_default_graph()
        loop_var = graph.get_tensor_by_name(_rescoped(loop_var_name))
        loop_out = graph.get_tensor_by_name(_rescoped(loop_out_name))
        grad = gradients_impl.gradients([loop_out], [loop_var])
        init = variables.global_variables_initializer()
        with session.Session():
            self.evaluate(init)
            actual_grad = self.evaluate(grad)

    self.assertEqual(expected_grad, actual_grad)
开发者ID:aeverall,项目名称:tensorflow,代码行数:50,代码来源:meta_graph_test.py
示例8: _testExportImportAcrossScopes
def _testExportImportAcrossScopes(self, graph_fn, use_resource):
    """Tests export and importing a graph across scopes.

    The graph is built under "dropA/dropB/keepA", exported without the
    "dropA/dropB" prefix, and re-imported under "importA". It is compared
    against a graph built directly under "importA/keepA".

    Args:
        graph_fn: A closure that creates a graph on the current scope; it is
            called with a `use_resource` keyword argument.
        use_resource: A bool indicating whether or not to use ResourceVariables.
    """

    def _clear_shared_names(graph_defs):
        # shared_name attributes are supposed to be orthogonal to scopes, so
        # blank every non-empty value before comparing.
        for graph_meta in graph_defs:
            for node in graph_meta.graph_def.node:
                value = node.attr.get("shared_name", None)
                if value and value.HasField("s") and value.s:
                    node.attr["shared_name"].s = b""

    with ops.Graph().as_default() as source_graph:
        with variable_scope.variable_scope("dropA/dropB/keepA"):
            graph_fn(use_resource=use_resource)
    exported_def, _ = meta_graph.export_scoped_meta_graph(
        graph=source_graph, export_scope="dropA/dropB")

    with ops.Graph().as_default() as roundtrip_graph:
        meta_graph.import_scoped_meta_graph(
            exported_def, import_scope="importA")

    with ops.Graph().as_default() as reference_graph:
        with variable_scope.variable_scope("importA/keepA"):
            graph_fn(use_resource=use_resource)
        if use_resource:
            # Bringing in a collection that contains ResourceVariables adds
            # ops to the graph, so mimic the same behavior here.
            variable_keys = sorted([
                ops.GraphKeys.GLOBAL_VARIABLES,
                ops.GraphKeys.TRAINABLE_VARIABLES,
            ])
            for key in variable_keys:
                for resource_var in reference_graph.get_collection(key):
                    resource_var._read_variable_op()  # pylint: disable=protected-access

    actual_def, _ = meta_graph.export_scoped_meta_graph(graph=roundtrip_graph)
    reference_def, _ = meta_graph.export_scoped_meta_graph(
        graph=reference_graph)
    if use_resource:
        _clear_shared_names([actual_def, reference_def])
    self.assertProtoEquals(reference_def, actual_def)
开发者ID:autodrive,项目名称:tensorflow,代码行数:48,代码来源:meta_graph_test.py
示例9: testClearDevices
def testClearDevices(self):
    """Checks that importing with clear_devices=True drops all device placements."""
    source_graph = tf.Graph()
    with source_graph.as_default():
        with tf.device("/device:CPU:0"):
            var_a = tf.Variable(tf.constant(1.0, shape=[2, 2]), name="a")
        with tf.device("/job:ps/replica:0/task:0/gpu:0"):
            var_b = tf.Variable(tf.constant(2.0, shape=[2, 2]), name="b")
        with tf.device("/job:localhost/replica:0/task:0/cpu:0"):
            tf.matmul(var_a, var_b, name="matmul")

    # Expected (canonicalized) device string for each op in the source graph.
    placements = [
        ("a", "/device:CPU:0"),
        ("b", "/job:ps/replica:0/task:0/device:GPU:0"),
        ("matmul", "/job:localhost/replica:0/task:0/device:CPU:0"),
    ]
    for op_name, device in placements:
        self.assertEqual(
            device, str(source_graph.as_graph_element(op_name).device))

    # No clear_devices is passed on export; clearing is requested on import.
    exported, _ = meta_graph.export_scoped_meta_graph(graph=source_graph)
    target_graph = tf.Graph()
    with target_graph.as_default():
        meta_graph.import_scoped_meta_graph(exported, clear_devices=True)

    for op_name, _ in placements:
        self.assertEqual(
            "", str(target_graph.as_graph_element(op_name).device))
开发者ID:caikehe,项目名称:tensorflow,代码行数:25,代码来源:meta_graph_test.py
示例10: testPotentialCycle
def testPotentialCycle(self):
    """Checks that a scoped export with an out-of-scope input needs an input_map.

    The "hidden1" scope consumes `matmul`, which is created outside that
    scope. Importing the exported scope without an `input_map` must raise
    "Graph contains unbound inputs". Binding "$unbound_inputs_MatMul" must
    let the import succeed.

    Uses `assertRaisesRegex`. The `assertRaisesRegexp` alias was deprecated
    in Python 3.2 and removed from `unittest` in Python 3.12.
    """
    graph1 = ops.Graph()
    with graph1.as_default():
        a = constant_op.constant(1.0, shape=[2, 2])
        b = constant_op.constant(2.0, shape=[2, 2])
        matmul = math_ops.matmul(a, b)
        with ops.name_scope("hidden1"):
            c = nn_ops.relu(matmul)
            d = constant_op.constant(3.0, shape=[2, 2])
            # Only the op's presence in the graph matters; the result is unused.
            math_ops.matmul(c, d)

    orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
        export_scope="hidden1", graph=graph1)

    graph2 = ops.Graph()
    with graph2.as_default():
        with self.assertRaisesRegex(ValueError,
                                    "Graph contains unbound inputs"):
            meta_graph.import_scoped_meta_graph(
                orig_meta_graph, import_scope="new_hidden1")

        meta_graph.import_scoped_meta_graph(
            orig_meta_graph,
            import_scope="new_hidden1",
            input_map={
                "$unbound_inputs_MatMul": constant_op.constant(
                    4.0, shape=[2, 2])
            })
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:27,代码来源:meta_graph_test.py
示例11: testDefaultAttrStrippingNestedFunctions
def testDefaultAttrStrippingNestedFunctions(self):
    """Verifies that default attributes are stripped from function node defs."""
    with self.cached_session():

        @function.Defun(dtypes.float32, dtypes.float32)
        def f0(i, j):
            return math_ops.complex(i, j, name="double_nested_complex")

        @function.Defun(dtypes.float32, dtypes.float32)
        def f1(i, j):
            # Calling f0 from f1 nests the Complex op two functions deep.
            return f0(i, j)

        _ = f1(constant_op.constant(1.0), constant_op.constant(2.0))
        exported, _ = meta_graph.export_scoped_meta_graph(
            graph_def=ops.get_default_graph().as_graph_def(),
            strip_default_attrs=True)

        # First node, in function-library order, whose name marks the
        # nested Complex op; None if there is no such node.
        complex_node = next(
            (node
             for fdef in exported.graph_def.library.function
             for node in fdef.node_def
             if node.name.startswith("double_nested_complex")),
            None)

        self.assertIsNotNone(complex_node)
        self.assertNotIn("T", complex_node.attr)
        self.assertNotIn("Tout", complex_node.attr)
        self.assertTrue(exported.meta_info_def.stripped_default_attrs)
开发者ID:aeverall,项目名称:tensorflow,代码行数:30,代码来源:meta_graph_test.py
示例12: testSummaryWithFamilyMetaGraphExport
def testSummaryWithFamilyMetaGraphExport(self):
    """Checks that summaries keep their tags when re-imported under a new scope.

    Summary ops built under 'outer' are exported and re-imported under
    'new_outer'. The op names pick up the new scope, but the serialized
    Summary tags still carry the original 'outer' prefix.

    Uses `assertEqual`. The `assertEquals` alias was deprecated in Python 3.2
    and removed from `unittest` in Python 3.12.
    """
    with ops.name_scope('outer'):
        i = constant_op.constant(11)
        summ = summary_lib.scalar('inner', i)
        self.assertEqual(summ.op.name, 'outer/inner')
        summ_f = summary_lib.scalar('inner', i, family='family')
        self.assertEqual(summ_f.op.name, 'outer/family/inner')

    metagraph_def, _ = meta_graph.export_scoped_meta_graph(export_scope='outer')

    with ops.Graph().as_default() as g:
        meta_graph.import_scoped_meta_graph(metagraph_def, graph=g,
                                            import_scope='new_outer')
        # The summaries should exist, but with outer scope renamed.
        new_summ = g.get_tensor_by_name('new_outer/inner:0')
        new_summ_f = g.get_tensor_by_name('new_outer/family/inner:0')

        # However, the tags are unaffected.
        with self.cached_session() as s:
            new_summ_str, new_summ_f_str = s.run([new_summ, new_summ_f])
            new_summ_pb = summary_pb2.Summary()
            new_summ_pb.ParseFromString(new_summ_str)
            self.assertEqual('outer/inner', new_summ_pb.value[0].tag)
            new_summ_f_pb = summary_pb2.Summary()
            new_summ_f_pb.ParseFromString(new_summ_f_str)
            self.assertEqual('family/outer/family/inner',
                             new_summ_f_pb.value[0].tag)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:27,代码来源:summary_test.py
示例13: testImportsUsingSameScopeName
def testImportsUsingSameScopeName(self):
    """Checks that repeated imports into scope "s" get uniquified scope names."""
    with ops.Graph().as_default():
        variables.Variable(0, name="v")
        exported, _ = meta_graph.export_scoped_meta_graph()

    with ops.Graph().as_default():
        # The first import lands in "s"; the second is expected in "s_1".
        for suffix in ("", "_1"):
            imported = meta_graph.import_scoped_meta_graph(
                exported, import_scope="s")
            self.assertEqual(len(imported), 1)
            [(var_key, var)] = list(imported.items())
            # The key stays scope-relative; the variable name is rescoped.
            self.assertEqual(var_key, "v:0")
            self.assertEqual(var.name, "s" + suffix + "/v:0")
开发者ID:aeverall,项目名称:tensorflow,代码行数:12,代码来源:meta_graph_test.py
示例14: _testExportImportAcrossScopes
def _testExportImportAcrossScopes(self, graph_fn, use_resource):
    """Tests export and importing a graph across scopes.

    The graph is built under "dropA/dropB/keepA", exported without the
    "dropA/dropB" prefix, and re-imported under "importA". It is compared
    against a graph built directly under "importA/keepA".

    Args:
        graph_fn: A closure that creates a graph on the current scope; it is
            called with a `use_resource` keyword argument.
        use_resource: A bool indicating whether or not to use ResourceVariables.
    """

    def _clear_shared_names(graph_defs):
        # shared_name attributes are orthogonal to scopes and are not updated
        # on export/import, so blank every non-empty value before comparing.
        for graph_meta in graph_defs:
            for node in graph_meta.graph_def.node:
                value = node.attr.get("shared_name", None)
                if value and value.HasField("s") and value.s:
                    node.attr["shared_name"].s = b""

    with ops.Graph().as_default() as source_graph:
        with variable_scope.variable_scope("dropA/dropB/keepA"):
            graph_fn(use_resource=use_resource)
    exported_def, _ = meta_graph.export_scoped_meta_graph(
        graph=source_graph, export_scope="dropA/dropB")

    with ops.Graph().as_default() as roundtrip_graph:
        meta_graph.import_scoped_meta_graph(
            exported_def, import_scope="importA")

    with ops.Graph().as_default() as reference_graph:
        with variable_scope.variable_scope("importA/keepA"):
            graph_fn(use_resource=use_resource)

    actual_def, _ = meta_graph.export_scoped_meta_graph(graph=roundtrip_graph)
    reference_def, _ = meta_graph.export_scoped_meta_graph(
        graph=reference_graph)
    if use_resource:
        _clear_shared_names([actual_def, reference_def])
    test_util.assert_meta_graph_protos_equal(self, reference_def, actual_def)
开发者ID:aeverall,项目名称:tensorflow,代码行数:38,代码来源:meta_graph_test.py
示例15: testScopedImportUnderNameScope
def testScopedImportUnderNameScope(self):
    """Checks that import_scope nests inside an enclosing name_scope."""
    source_graph = ops.Graph()
    with source_graph.as_default():
        variables.Variable(initial_value=1.0, trainable=True, name="myvar")
    exported, _ = meta_graph.export_scoped_meta_graph(graph=source_graph)

    target_graph = ops.Graph()
    with target_graph.as_default():
        with ops.name_scope("foo"):
            imported = meta_graph.import_scoped_meta_graph(
                exported, import_scope="bar")
            self.assertEqual(len(imported), 1)
            # "foo" (name_scope) wraps "bar" (import_scope).
            (only_var,) = list(imported.values())
            self.assertEqual(only_var.name, "foo/bar/myvar:0")
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:14,代码来源:meta_graph_test.py
示例16: doTestExportNestedNames
def doTestExportNestedNames(self, use_resource=False):
    """Exports a nested scope, re-imports it elsewhere, and checks the names.

    Args:
        use_resource: If True, "weights" is created first and "biases" is a
            ResourceVariable. Otherwise both are plain Variables and "biases"
            is created first.
    """
    weight_values = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    source_graph = ops.Graph()
    with source_graph.as_default():
        with ops.name_scope("hidden1/hidden2/hidden3"):
            images = constant_op.constant(
                1.0, dtypes.float32, shape=[3, 2], name="images")
            # The creation order differs between the two branches.
            if use_resource:
                weights = variables.Variable(weight_values, name="weights")
                biases = resource_variable_ops.ResourceVariable(
                    [0.1] * 3, name="biases")
            else:
                biases = variables.Variable([0.1] * 3, name="biases")
                weights = variables.Variable(weight_values, name="weights")
            nn_ops.relu(math_ops.matmul(images, weights) + biases, name="relu")

    exported, var_map = meta_graph.export_scoped_meta_graph(
        export_scope="hidden1/hidden2", graph=source_graph)
    # Keys are relative to the export scope; the variables keep full names.
    self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"],
                     sorted(var_map))
    self.assertEqual(
        ["hidden1/hidden2/hidden3/biases:0",
         "hidden1/hidden2/hidden3/weights:0"],
        sorted(var.name for var in var_map.values()))
    # Every exported node has the export scope prefix removed.
    for node in exported.graph_def.node:
        self.assertTrue(node.name.startswith("hidden3"))

    target_graph = ops.Graph()
    imported_map = meta_graph.import_scoped_meta_graph(
        exported, import_scope="new_hidden1/new_hidden2", graph=target_graph)
    self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"],
                     sorted(imported_map))
    self.assertEqual(
        ["new_hidden1/new_hidden2/hidden3/biases:0",
         "new_hidden1/new_hidden2/hidden3/weights:0"],
        sorted(var.name for var in imported_map.values()))

    # Each Assign op's "_class" attr (presumably its colocation list, given
    # the "loc:@" prefix) must reference the renamed variable.
    assign_checks = [
        ("new_hidden1/new_hidden2/hidden3/biases/Assign",
         b"loc:@new_hidden1/new_hidden2/hidden3/biases"),
        ("new_hidden1/new_hidden2/hidden3/weights/Assign",
         b"loc:@new_hidden1/new_hidden2/hidden3/weights"),
    ]
    for op_name, class_value in assign_checks:
        self.assertEqual(
            [class_value],
            target_graph.get_operation_by_name(op_name).get_attr("_class"))
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:49,代码来源:meta_graph_test.py
示例17: _testScopedImportWithQueue
def _testScopedImportWithQueue(self, test_dir, exported_filename,
                               new_exported_filename):
    """Imports a queue meta graph under "new_queue1" and re-exports that scope.

    Args:
        test_dir: Directory holding the input file and receiving the output.
        exported_filename: Name of the meta graph file to import.
        new_exported_filename: Name for the re-exported meta graph file.

    Returns:
        The MetaGraphDef re-exported from the "new_queue1" scope.
    """
    graph = tf.Graph()
    meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filename),
        graph=graph,
        import_scope="new_queue1")
    # These lookups only check that the renamed elements exist (presumably
    # raising otherwise); their results are unused.
    for element_name in ("new_queue1/dequeue:0", "new_queue1/close"):
        graph.as_graph_element(element_name)

    with graph.as_default():
        reexported, _ = meta_graph.export_scoped_meta_graph(
            filename=os.path.join(test_dir, new_exported_filename),
            graph=graph,
            export_scope="new_queue1")
    return reexported
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:15,代码来源:meta_graph_test.py
示例18: _testScopedExportWithQueue
def _testScopedExportWithQueue(self, test_dir, exported_filename):
    """Builds a FIFO queue plus QueueRunner under "queue1" and exports that scope.

    Args:
        test_dir: Directory receiving the exported file.
        exported_filename: Name for the exported meta graph file.

    Returns:
        The MetaGraphDef exported from the "queue1" scope.
    """
    graph = tf.Graph()
    with graph.as_default():
        with tf.name_scope("queue1"):
            queue = tf.FIFOQueue(10, tf.float32)
            enqueue_op = queue.enqueue(9876, name="enqueue")
            close_op = queue.close(name="close")
            tf.train.add_queue_runner(
                tf.train.QueueRunner(queue, [enqueue_op], close_op))
            queue.dequeue(name="dequeue")

        exported, _ = meta_graph.export_scoped_meta_graph(
            filename=os.path.join(test_dir, exported_filename),
            graph=tf.get_default_graph(),
            export_scope="queue1")
    return exported
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:16,代码来源:meta_graph_test.py
示例19: testVariableObjectsAreSharedAmongCollections
def testVariableObjectsAreSharedAmongCollections(self):
    """Checks that one Variable object backs both variable collections, before and after import."""

    def _variable_collections(graph):
        # Returns (GLOBAL_VARIABLES, TRAINABLE_VARIABLES) for `graph`.
        return (graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES),
                graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))

    with ops.Graph().as_default() as source_graph:
        v = variables.Variable(3.0)
        # A single instance of Variable is shared among the collections:
        global_vars, trainable_vars = _variable_collections(source_graph)
        self.assertEqual(len(global_vars), 1)
        self.assertEqual(len(trainable_vars), 1)
        self.assertIs(global_vars[0], trainable_vars[0])
        self.assertIs(v, global_vars[0])

    exported, _ = meta_graph.export_scoped_meta_graph(graph=source_graph)
    del source_graph  # Avoid accidental references in code using the new graph.

    with ops.Graph().as_default() as target_graph:
        meta_graph.import_scoped_meta_graph(exported)
        global_vars, trainable_vars = _variable_collections(target_graph)
        self.assertEqual(len(global_vars), 1)
        self.assertEqual(len(trainable_vars), 1)
        # The import must also share one Variable among the collections.
        self.assertIs(global_vars[0], trainable_vars[0])
开发者ID:aeverall,项目名称:tensorflow,代码行数:22,代码来源:meta_graph_test.py
示例20: testImportWhileLoopInWhileLoop
def testImportWhileLoopInWhileLoop(self):
    """Checks that a meta graph containing a while loop imports inside another loop's body."""
    # Create a simple while loop.
    with ops.Graph().as_default():
        start = variables.Variable(0.0)
        _, doubled = control_flow_ops.while_loop(
            lambda step, acc: step < 5,
            lambda step, acc: (step + 1, acc * 2.0),
            [0, start])
        doubled_name = doubled.name

        # Export the whole graph; no export scope is passed here.
        exported, _ = meta_graph.export_scoped_meta_graph()

    # Restore the MetaGraphDef in a while loop in a new graph.
    with ops.Graph().as_default():

        def body(step, _):
            # Each iteration imports the saved graph and yields its output.
            meta_graph.import_scoped_meta_graph(exported)
            return step + 1, ops.get_default_graph().get_tensor_by_name(
                doubled_name)

        _, result = control_flow_ops.while_loop(
            lambda step, acc: step < 2, body, [0, 0.0], name="")
        with session.Session():
            self.evaluate(variables.global_variables_initializer())
            self.evaluate(result)
开发者ID:aeverall,项目名称:tensorflow,代码行数:24,代码来源:meta_graph_test.py
注：本文中的tensorflow.python.framework.meta_graph.export_scoped_meta_graph函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论