本文整理汇总了Python中tensorflow.python.debug.lib.debug_utils.watch_graph函数的典型用法代码示例。如果您正苦于以下问题:Python watch_graph函数的具体用法?Python watch_graph怎么用?Python watch_graph使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了watch_graph函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: createAndRunGraphWithWhileLoop
def createAndRunGraphWithWhileLoop(self):
  """Create and run a TensorFlow Graph with a while loop to generate dumps.

  Sets `self.dump_root`, `self.curr_file_path`, `self.traceback_first_line`
  and `self.dump` (a `debug_data.DebugDumpDir` with the Python graph
  attached) for use by the calling test.
  """
  self.dump_root = self.get_temp_dir()
  self.curr_file_path = os.path.abspath(
      tf_inspect.getfile(tf_inspect.currentframe()))

  # Run a simple TF graph to generate some debug dumps that can be used in
  # source annotation.
  with session.Session() as sess:
    # NOTE(review): `line_number_above()` presumably returns the number of the
    # line directly above its call site (the `loop_body` line). Do not insert
    # any line between the next two statements.
    loop_body = lambda i: math_ops.add(i, 2)
    self.traceback_first_line = line_number_above()

    # The counter starts at 10 and is increased by 2 while it is < 16.
    loop_cond = lambda i: math_ops.less(i, 16)
    i = constant_op.constant(10, name="i")
    loop = control_flow_ops.while_loop(loop_cond, loop_body, [i])

    # Partition graphs are requested so that run_metadata.partition_graphs can
    # be handed to DebugDumpDir below.
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    # Watch the graph and dump into the temporary directory via file:// URL.
    debug_utils.watch_graph(
        run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
    run_metadata = config_pb2.RunMetadata()
    sess.run(loop, options=run_options, run_metadata=run_metadata)

    # Load the dumps back and attach the Python graph to them.
    self.dump = debug_data.DebugDumpDir(
        self.dump_root, partition_graphs=run_metadata.partition_graphs)
    self.dump.set_python_graph(sess.graph)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:27,代码来源:source_utils_test.py
示例2: before_run
def before_run(self, run_context):
  """Initialize the dumping wrapper on first use and return watched run args.

  Args:
    run_context: The session-run context of the upcoming `Session.run()`.

  Returns:
    A `session_run_hook.SessionRunArgs` with no extra fetches or feeds whose
    `options` carry the debug tensor watches for this run.
  """
  if not self._wrapper_initialized:
    dumping_wrapper.DumpingDebugWrapperSession.__init__(
        self,
        run_context.session,
        self._session_root,
        watch_fn=self._watch_fn,
        log_usage=self._log_usage)
    self._wrapper_initialized = True

  # Counted on every run, not only the first one.
  self._run_call_count += 1

  original_args = run_context.original_args
  watch_config = self._prepare_run_watch_config(
      original_args.fetches, original_args.feed_dict)
  urls, ops, node_regex, op_type_regex = watch_config

  options = config_pb2.RunOptions()
  debug_utils.watch_graph(
      options,
      run_context.session.graph,
      debug_urls=urls,
      debug_ops=ops,
      node_name_regex_whitelist=node_regex,
      op_type_regex_whitelist=op_type_regex)
  return session_run_hook.SessionRunArgs(
      None, feed_dict=None, options=options)
开发者ID:brainwy12,项目名称:tensorflow,代码行数:27,代码来源:hooks.py
示例3: testWatchGraph_allNodes
def testWatchGraph_allNodes(self):
  """Without any whitelist, every node with output tensors gets a watch."""
  debug_utils.watch_graph(
      self._run_options,
      self._graph,
      debug_ops=["DebugIdentity", "DebugNanCount"],
      debug_urls="file:///tmp/tfdbg_1")

  debug_watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
  self.assertEqual(self._expected_num_nodes, len(debug_watch_opts))

  # Verify that each of the nodes in the graph with output tensors in the
  # graph have debug tensor watch.
  node_names = self._verify_watches(debug_watch_opts, 0,
                                    ["DebugIdentity", "DebugNanCount"],
                                    ["file:///tmp/tfdbg_1"])

  # Verify the node names. assertIn (unlike assertTrue(x in y)) names the
  # missing node and the container in its failure message.
  for expected_name in ("a1_init", "a1", "a1/Assign", "a1/read",
                        "b_init", "b", "b/Assign", "b/read",
                        "c", "p1", "s"):
    self.assertIn(expected_name, node_names)
开发者ID:aeverall,项目名称:tensorflow,代码行数:30,代码来源:debug_utils_test.py
示例4: before_run
def before_run(self, run_context):
  """Create the dumping wrapper lazily and return watched run arguments.

  Args:
    run_context: The session-run context of the upcoming `Session.run()`.

  Returns:
    A `session_run_hook.SessionRunArgs` with no extra fetches or feeds whose
    `options` carry the debug tensor watches for this run.
  """
  if not self._session_wrapper:
    self._session_wrapper = dumping_wrapper.DumpingDebugWrapperSession(
        run_context.session,
        self._session_root,
        watch_fn=self._watch_fn,
        thread_name_filter=self._thread_name_filter,
        log_usage=self._log_usage)

  wrapper = self._session_wrapper
  wrapper.increment_run_call_count()

  original_args = run_context.original_args
  # pylint: disable=protected-access
  debug_urls, watch_options = wrapper._prepare_run_watch_config(
      original_args.fetches, original_args.feed_dict)
  # pylint: enable=protected-access

  # Every watch_graph keyword comes straight from the prepared watch config.
  watch_kwargs = {
      "debug_urls": debug_urls,
      "debug_ops": watch_options.debug_ops,
      "node_name_regex_whitelist": watch_options.node_name_regex_whitelist,
      "op_type_regex_whitelist": watch_options.op_type_regex_whitelist,
      "tensor_dtype_regex_whitelist":
          watch_options.tensor_dtype_regex_whitelist,
      "tolerate_debug_op_creation_failures":
          watch_options.tolerate_debug_op_creation_failures,
  }
  run_options = config_pb2.RunOptions()
  debug_utils.watch_graph(
      run_options, run_context.session.graph, **watch_kwargs)
  return session_run_hook.SessionRunArgs(
      None, feed_dict=None, options=run_options)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:30,代码来源:hooks.py
示例5: before_run
def before_run(self, run_context):
  """Initialize the wrapper base class once, then build watched RunOptions.

  Args:
    run_context: The session-run context of the upcoming `Session.run()`.

  Returns:
    A `session_run_hook.SessionRunArgs` with no extra fetches or feeds whose
    `options` carry the debug tensor watches for this run.
  """
  if not self._wrapper_initialized:
    # TODO(cais): Make this hook have a DumpingDebugWrapperSession property
    # instead of subclassing DumpingDebugWrapperSession.
    dumping_wrapper.DumpingDebugWrapperSession.__init__(
        self,
        run_context.session,
        self._session_root,
        watch_fn=self._watch_fn,
        thread_name_filter=self._thread_name_filter,
        log_usage=self._log_usage)
    self._wrapper_initialized = True

  # Counted on every run, not only the first one.
  self._run_call_count += 1

  args = run_context.original_args
  urls, opts = self._prepare_run_watch_config(args.fetches, args.feed_dict)

  options = config_pb2.RunOptions()
  debug_utils.watch_graph(
      options,
      run_context.session.graph,
      debug_urls=urls,
      debug_ops=opts.debug_ops,
      node_name_regex_whitelist=opts.node_name_regex_whitelist,
      op_type_regex_whitelist=opts.op_type_regex_whitelist,
      tensor_dtype_regex_whitelist=opts.tensor_dtype_regex_whitelist,
      tolerate_debug_op_creation_failures=(
          opts.tolerate_debug_op_creation_failures))
  return session_run_hook.SessionRunArgs(
      None, feed_dict=None, options=options)
开发者ID:finardi,项目名称:tensorflow,代码行数:32,代码来源:hooks.py
示例6: testToggleBreakpointsWorks
def testToggleBreakpointsWorks(self):
  """Gated-gRPC breakpoints on delta_1 and delta_2 can be toggled per run.

  Breakpoints are requested before runs 0 and 2 and removed before runs 1
  and 3; the server state and received tensors are checked after each run.
  """
  with session.Session(
      config=session_debug_testlib.no_rewrite_session_config()) as sess:
    v_1 = variables.VariableV1(50.0, name="v_1")
    v_2 = variables.VariableV1(-50.0, name="v_2")
    delta_1 = constant_op.constant(5.0, name="delta_1")
    delta_2 = constant_op.constant(-5.0, name="delta_2")
    inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
    inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

    sess.run([v_1.initializer, v_2.initializer])

    run_metadata = config_pb2.RunMetadata()
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    debug_utils.watch_graph(
        run_options,
        sess.graph,
        debug_ops=["DebugIdentity(gated_grpc=true)"],
        debug_urls=[self._debug_server_url_1])

    for i in xrange(4):
      self._server_1.clear_data()

      if i in (0, 2):
        # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2.
        self._server_1.request_watch(
            "delta_1", 0, "DebugIdentity", breakpoint=True)
        self._server_1.request_watch(
            "delta_2", 0, "DebugIdentity", breakpoint=True)
      else:
        # Disable the breakpoint in runs 1 and 3.
        self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
        self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")

      output = sess.run([inc_v_1, inc_v_2],
                        options=run_options, run_metadata=run_metadata)
      # Each run adds 5.0 to v_1 and -5.0 to v_2.
      self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

      if i in (0, 2):
        # During runs 0 and 2, the server should have received the published
        # debug tensors delta_[1,2]:0:DebugIdentity. The breakpoints should
        # have been unblocked by EventReply responses from the server.
        self.assertAllClose(
            [5.0],
            self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
        self.assertAllClose(
            [-5.0],
            self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
        # After the runs, the server should have properly registered the
        # breakpoints due to the request_watch calls.
        self.assertSetEqual({("delta_1", 0, "DebugIdentity"),
                             ("delta_2", 0, "DebugIdentity")},
                            self._server_1.breakpoints)
      else:
        # After the end of runs 1 and 3, the server has received the requests
        # to disable the breakpoints at delta_[1,2]:0:DebugIdentity.
        self.assertSetEqual(set(), self._server_1.breakpoints)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:57,代码来源:session_debug_grpc_test.py
示例7: _decorate_options_for_debug
def _decorate_options_for_debug(self, options, graph):
  """Add debug tensor watches for `graph` to `options`, in place.

  Args:
    options: (config_pb2.RunOptions) The RunOptions instance to be modified.
    graph: A TensorFlow Graph object.
  """
  run_debug_urls = self._get_run_debug_urls()
  debug_utils.watch_graph(options, graph, debug_urls=run_debug_urls)
  # The run is also asked to report its partition graphs.
  options.output_partition_graphs = True
开发者ID:aravindvcyber,项目名称:tensorflow,代码行数:11,代码来源:hooks.py
示例8: testWatchGraph_tensorDTypeWhitelist
def testWatchGraph_tensorDTypeWhitelist(self):
  """Only tensors whose dtype matches the ".*_ref" regex are watched."""
  debug_utils.watch_graph(
      self._run_options,
      self._graph,
      tensor_dtype_regex_whitelist=".*_ref",
      debug_urls="file:///tmp/tfdbg_1")
  watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
  watched_nodes = self._verify_watches(
      watch_opts, 0, ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
  self.assertItemsEqual(["a1", "a1/Assign", "b", "b/Assign"], watched_nodes)
开发者ID:aeverall,项目名称:tensorflow,代码行数:11,代码来源:debug_utils_test.py
示例9: testWatchGraph_opTypeWhitelist
def testWatchGraph_opTypeWhitelist(self):
  """Only nodes whose op type matches "(Variable|MatMul)" are watched."""
  op_type_regex = "(Variable|MatMul)"
  debug_utils.watch_graph(
      self._run_options,
      self._graph,
      debug_urls="file:///tmp/tfdbg_1",
      op_type_regex_whitelist=op_type_regex)
  watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
  watched_nodes = self._verify_watches(
      watch_opts, 0, ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
  expected_nodes = sorted(["a1", "b", "p1"])
  self.assertEqual(expected_nodes, sorted(watched_nodes))
开发者ID:aeverall,项目名称:tensorflow,代码行数:11,代码来源:debug_utils_test.py
示例10: testWatchGraph_nodeNameAndOpTypeWhitelists
def testWatchGraph_nodeNameAndOpTypeWhitelists(self):
  """With both a node-name and an op-type whitelist, only "p1" is watched."""
  debug_utils.watch_graph(
      self._run_options,
      self._graph,
      debug_urls="file:///tmp/tfdbg_1",
      node_name_regex_whitelist="([a-z]+1$)",
      op_type_regex_whitelist="(MatMul)")
  watch_opts = self._run_options.debug_options.debug_tensor_watch_opts
  watched_nodes = self._verify_watches(
      watch_opts, 0, ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
  self.assertEqual(["p1"], watched_nodes)
开发者ID:aeverall,项目名称:tensorflow,代码行数:12,代码来源:debug_utils_test.py
示例11: _decorate_options_for_debug
def _decorate_options_for_debug(self, options, graph, watch_options):
  """Modify RunOptions.debug_options.debug_tensor_watch_opts for debugging.

  Args:
    options: (config_pb2.RunOptions) The RunOptions instance to be modified
      in place; its `output_partition_graphs` field is also set to True.
    graph: A TensorFlow Graph object whose nodes are to be watched.
    watch_options: Watch configuration whose `debug_ops`, whitelist regexes
      and `tolerate_debug_op_creation_failures` fields are forwarded to
      `debug_utils.watch_graph`.
  """
  debug_utils.watch_graph(
      options,
      graph,
      debug_urls=self._get_run_debug_urls(),
      # Forward the requested debug ops too, as the before_run hooks do;
      # without this watch_graph falls back to its default debug op and the
      # caller's choice in watch_options is silently ignored.
      debug_ops=watch_options.debug_ops,
      node_name_regex_whitelist=watch_options.node_name_regex_whitelist,
      op_type_regex_whitelist=watch_options.op_type_regex_whitelist,
      tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist,
      tolerate_debug_op_creation_failures=(
          watch_options.tolerate_debug_op_creation_failures))
  options.output_partition_graphs = True
开发者ID:finardi,项目名称:tensorflow,代码行数:12,代码来源:hooks.py
示例12: testWatchGraph_nodeNameWhitelist
def testWatchGraph_nodeNameWhitelist(self):
  """Only nodes whose names match the node-name regex are watched."""
  name_regex = "(a1$|a1_init$|a1/.*|p1$)"
  debug_utils.watch_graph(
      self._run_options,
      self._graph,
      debug_urls="file:///tmp/tfdbg_1",
      node_name_regex_whitelist=name_regex)
  watched_nodes = self._verify_watches(
      self._run_options.debug_options.debug_tensor_watch_opts, 0,
      ["DebugIdentity"], ["file:///tmp/tfdbg_1"])
  expected_nodes = ["a1_init", "a1", "a1/Assign", "a1/read", "p1"]
  self.assertEqual(sorted(expected_nodes), sorted(watched_nodes))
开发者ID:aeverall,项目名称:tensorflow,代码行数:13,代码来源:debug_utils_test.py
示例13: testGradientsValuesFromDumpWorks
def testGradientsValuesFromDumpWorks(self):
  """gradient_values_from_dump() reads watched gradients back from a dump.

  Gradients of z = (w - 1)^2 with respect to self.w, self.u and y are
  registered with a GradientsDebugger. One training step is run with
  file:// debug dumping, and the gradient values are then read from the dump.
  """
  y = math_ops.add(self.w, -1.0, name="y")
  z = math_ops.square(y, name="z")

  grad_debugger = debug_gradients.GradientsDebugger()
  with grad_debugger.watch_gradients_by_tensors(
      self.sess.graph, [self.w, self.u, y]):
    train_op = gradient_descent.GradientDescentOptimizer(0.1).minimize(z)

  self.sess.run(variables.global_variables_initializer())

  run_options = config_pb2.RunOptions(output_partition_graphs=True)
  dump_dir = tempfile.mkdtemp()
  # Register removal immediately so the dump directory is cleaned up even if
  # an assertion below fails; a trailing rmtree would be skipped then.
  self.addCleanup(shutil.rmtree, dump_dir)
  debug_url = "file://" + dump_dir
  debug_utils.watch_graph(
      run_options,
      self.sess.graph,
      debug_urls=debug_url)
  run_metadata = config_pb2.RunMetadata()
  self.assertAllClose(2.0, self.sess.run(self.u))
  self.sess.run(train_op, options=run_options, run_metadata=run_metadata)
  self.assertAllClose(-1.0, self.sess.run(self.u))

  dump = debug_data.DebugDumpDir(
      dump_dir, partition_graphs=run_metadata.partition_graphs)
  dump.set_python_graph(self.sess.graph)

  y_grad_values = debug_gradients.gradient_values_from_dump(
      grad_debugger, y, dump)
  self.assertEqual(1, len(y_grad_values))
  self.assertAllClose(10.0, y_grad_values[0])

  w_grad_values = debug_gradients.gradient_values_from_dump(
      grad_debugger, self.w, dump)
  self.assertEqual(1, len(w_grad_values))
  self.assertAllClose(10.0, w_grad_values[0])

  u_grad_values = debug_gradients.gradient_values_from_dump(
      grad_debugger, self.u, dump)
  self.assertEqual(1, len(u_grad_values))
  self.assertAllClose(30.0, u_grad_values[0])

  # self.v was not among the watched x-tensors, so the lookup must fail.
  with self.assertRaisesRegexp(
      LookupError,
      r"This GradientsDebugger has not received any gradient tensor for "
      r"x-tensor v:0"):
    debug_gradients.gradient_values_from_dump(grad_debugger, self.v, dump)
开发者ID:Lin-jipeng,项目名称:tensorflow,代码行数:50,代码来源:debug_gradients_test.py
示例14: testToggleBreakpointWorks
def testToggleBreakpointWorks(self):
  """A gated-gRPC breakpoint on delta:0:DebugIdentity can be toggled.

  Breakpoints are requested before runs 0 and 2 and removed before runs 1
  and 3. Each request takes effect only in the run after the one in which it
  is issued.
  """
  with session.Session(config=no_rewrite_session_config()) as sess:
    v = variables.Variable(50.0, name="v")
    delta = constant_op.constant(5.0, name="delta")
    inc_v = state_ops.assign_add(v, delta, name="inc_v")

    sess.run(v.initializer)

    run_metadata = config_pb2.RunMetadata()
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    debug_utils.watch_graph(
        run_options,
        sess.graph,
        debug_ops=["DebugIdentity(gated_grpc=true)"],
        debug_urls=[self._debug_server_url_1])

    server = self._server_1
    for run_index in xrange(4):
      server.clear_data()
      breakpoint_run = run_index in (0, 2)

      # N.B.: These requests will be fulfilled not in this debugged
      # Session.run() invocation, but in the next one.
      if breakpoint_run:
        # Enable the breakpoint at delta:0:DebugIdentity in runs 0 and 2.
        server.request_watch("delta", 0, "DebugIdentity", breakpoint=True)
      else:
        # Disable the breakpoint in runs 1 and 3.
        server.request_unwatch("delta", 0, "DebugIdentity")

      output = sess.run(inc_v, options=run_options, run_metadata=run_metadata)
      self.assertAllClose(50.0 + 5.0 * (run_index + 1), output)

      if not breakpoint_run:
        # Runs 1 and 3 publish delta:0:DebugIdentity, enabled by the previous
        # run's request; the breakpoint is released by the server's
        # EventReply responses.
        self.assertAllClose(
            [5.0], server.debug_tensor_values["delta:0:DebugIdentity"])
        # The unwatch request issued in this run clears the breakpoint set.
        self.assertSetEqual(set(), server.breakpoints)
      else:
        # After runs 0 and 2 the server knows about the enabled breakpoint.
        self.assertSetEqual({("delta", 0, "DebugIdentity")},
                            server.breakpoints)
开发者ID:chdinh,项目名称:tensorflow,代码行数:48,代码来源:session_debug_grpc_test.py
示例15: testToggleWatchesOnCoreMetadata
def testToggleWatchesOnCoreMetadata(self):
  """Watches toggled by the server on core metadata alternate across runs.

  The server is started with `toggle_watch_on_core_metadata` naming
  toggled_1:0:DebugIdentity and toggled_2:0:DebugIdentity. The test expects
  both tensors in runs 0 and 2 and no debug tensors in runs 1 and 3.
  """
  (_, debug_server_url, _, server_thread,
   server) = grpc_debug_test_server.start_server_on_separate_thread(
       dump_to_filesystem=False,
       toggle_watch_on_core_metadata=[("toggled_1", 0, "DebugIdentity"),
                                      ("toggled_2", 0, "DebugIdentity")])
  # Recorded on the fixture (presumably so the server and its thread are shut
  # down after the test -- confirm in tearDown).
  self._servers_and_threads.append((server, server_thread))

  with session.Session(
      config=session_debug_testlib.no_rewrite_session_config()) as sess:
    v_1 = variables.VariableV1(50.0, name="v_1")
    v_2 = variables.VariableV1(-50.0, name="v_2")
    # These two nodes have names that match those in the
    # toggle_watch_on_core_metadata argument used when calling
    # start_server_on_separate_thread().
    toggled_1 = constant_op.constant(5.0, name="toggled_1")
    toggled_2 = constant_op.constant(-5.0, name="toggled_2")
    inc_v_1 = state_ops.assign_add(v_1, toggled_1, name="inc_v_1")
    inc_v_2 = state_ops.assign_add(v_2, toggled_2, name="inc_v_2")

    sess.run([v_1.initializer, v_2.initializer])

    run_metadata = config_pb2.RunMetadata()
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    debug_utils.watch_graph(
        run_options,
        sess.graph,
        debug_ops=["DebugIdentity(gated_grpc=true)"],
        debug_urls=[debug_server_url])

    for i in xrange(4):
      server.clear_data()
      sess.run([inc_v_1, inc_v_2],
               options=run_options, run_metadata=run_metadata)
      if i % 2 == 0:
        self.assertEqual(2, len(server.debug_tensor_values))
        self.assertAllClose(
            [5.0],
            server.debug_tensor_values["toggled_1:0:DebugIdentity"])
        self.assertAllClose(
            [-5.0],
            server.debug_tensor_values["toggled_2:0:DebugIdentity"])
      else:
        self.assertEqual(0, len(server.debug_tensor_values))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:46,代码来源:session_debug_grpc_test.py
示例16: testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes
def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self):
  """Alternate DebugIdentity and DebugNumericSummary watches on delta:0.

  Each request takes effect in the run after it is issued. So run 0 publishes
  nothing, and every later run publishes exactly one debug tensor: the one
  enabled by the previous run's request.
  """
  with session.Session(config=no_rewrite_session_config()) as sess:
    v = variables.Variable(50.0, name="v")
    delta = constant_op.constant(5.0, name="delta")
    inc_v = state_ops.assign_add(v, delta, name="inc_v")

    sess.run(v.initializer)

    run_metadata = config_pb2.RunMetadata()
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    debug_utils.watch_graph(
        run_options,
        sess.graph,
        debug_ops=["DebugIdentity(gated_grpc=true)",
                   "DebugNumericSummary(gated_grpc=true)"],
        debug_urls=[self._debug_server_url_1])

    server = self._server_1
    for run_index in xrange(4):
      server.clear_data()

      # N.B.: These requests will be fulfilled not in this debugged
      # Session.run() invocation, but in the next one.
      if run_index % 2 == 0:
        server.request_watch("delta", 0, "DebugIdentity")
        server.request_unwatch("delta", 0, "DebugNumericSummary")
      else:
        server.request_unwatch("delta", 0, "DebugIdentity")
        server.request_watch("delta", 0, "DebugNumericSummary")

      sess.run(inc_v, options=run_options, run_metadata=run_metadata)

      if run_index == 0:
        # No request has been fulfilled yet during the first run.
        self.assertEqual(0, len(server.debug_tensor_values))
        continue

      self.assertEqual(1, len(server.debug_tensor_values))
      if run_index % 2 == 1:
        self.assertAllClose(
            [5.0], server.debug_tensor_values["delta:0:DebugIdentity"])
      else:
        self.assertAllClose(
            [[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 5.0, 5.0, 5.0,
              0.0, 1.0, 0.0]],
            server.debug_tensor_values["delta:0:DebugNumericSummary"])
开发者ID:chdinh,项目名称:tensorflow,代码行数:45,代码来源:session_debug_grpc_test.py
示例17: createAndRunGraphHelper
def createAndRunGraphHelper(self):
  """Create and run a TensorFlow Graph to generate debug dumps.

  This is intentionally done in a separate method, to make it easier to test
  the stack-top mode of source annotation.

  Sets `self.dump_root`, `self.curr_file_path`, the graph elements
  `self.u_init`, `self.u`, `self.v_init`, `self.v`, `self.w` together with
  their `*_line_number` attributes, and `self.dump`.
  """
  self.dump_root = self.get_temp_dir()
  self.curr_file_path = os.path.abspath(
      tf_inspect.getfile(tf_inspect.currentframe()))

  # Run a simple TF graph to generate some debug dumps that can be used in
  # source annotation.
  with session.Session() as sess:
    # NOTE(review): each `*_line_number = line_number_above()` presumably
    # records the line directly above it, so it must stay immediately after
    # the statement it refers to. Do not insert lines between a definition
    # and its `*_line_number` assignment.
    self.u_init = constant_op.constant(
        np.array([[5.0, 3.0], [-1.0, 0.0]]), shape=[2, 2], name="u_init")
    self.u_init_line_number = line_number_above()

    self.u = variables.Variable(self.u_init, name="u")
    self.u_line_number = line_number_above()

    self.v_init = constant_op.constant(
        np.array([[2.0], [-1.0]]), shape=[2, 1], name="v_init")
    self.v_init_line_number = line_number_above()

    self.v = variables.Variable(self.v_init, name="v")
    self.v_line_number = line_number_above()

    self.w = math_ops.matmul(self.u, self.v, name="w")
    self.w_line_number = line_number_above()

    sess.run(self.u.initializer)
    sess.run(self.v.initializer)

    # Partition graphs are requested so that run_metadata.partition_graphs can
    # be handed to DebugDumpDir below.
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    debug_utils.watch_graph(
        run_options, sess.graph, debug_urls=["file://%s" % self.dump_root])
    run_metadata = config_pb2.RunMetadata()
    sess.run(self.w, options=run_options, run_metadata=run_metadata)

    # Load the dumps back and attach the Python graph to them.
    self.dump = debug_data.DebugDumpDir(
        self.dump_root, partition_graphs=run_metadata.partition_graphs)
    self.dump.set_python_graph(sess.graph)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:43,代码来源:source_utils_test.py
示例18: testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers
def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self):
  """Watch toggles sent to one gRPC server do not leak into the other.

  Server 1 toggles delta:0:DebugIdentity and server 2 toggles
  v:0:DebugIdentity. In even runs each server receives exactly its own
  tensor; in odd runs neither receives anything.
  """
  with session.Session(
      config=session_debug_testlib.no_rewrite_session_config()) as sess:
    v = variables.VariableV1(50.0, name="v")
    delta = constant_op.constant(5.0, name="delta")
    inc_v = state_ops.assign_add(v, delta, name="inc_v")

    sess.run(v.initializer)

    run_metadata = config_pb2.RunMetadata()
    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    debug_utils.watch_graph(
        run_options,
        sess.graph,
        debug_ops=["DebugIdentity(gated_grpc=true)"],
        debug_urls=[self._debug_server_url_1, self._debug_server_url_2])

    server_1, server_2 = self._server_1, self._server_2
    for run_index in xrange(4):
      server_1.clear_data()
      server_2.clear_data()

      watching = run_index % 2 == 0
      if watching:
        server_1.request_watch("delta", 0, "DebugIdentity")
        server_2.request_watch("v", 0, "DebugIdentity")
      else:
        server_1.request_unwatch("delta", 0, "DebugIdentity")
        server_2.request_unwatch("v", 0, "DebugIdentity")

      sess.run(inc_v, options=run_options, run_metadata=run_metadata)

      if not watching:
        self.assertEqual(0, len(server_1.debug_tensor_values))
        self.assertEqual(0, len(server_2.debug_tensor_values))
        continue

      self.assertEqual(1, len(server_1.debug_tensor_values))
      self.assertEqual(1, len(server_2.debug_tensor_values))
      self.assertAllClose(
          [5.0], server_1.debug_tensor_values["delta:0:DebugIdentity"])
      self.assertAllClose(
          [50 + 5.0 * run_index],
          server_2.debug_tensor_values["v:0:DebugIdentity"])
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:42,代码来源:session_debug_grpc_test.py
示例19: testMultiGPUSessionRun
def testMultiGPUSessionRun(self):
  """Debug-dump a graph whose ops are placed on two different GPUs.

  The test skips itself when fewer than two GPUs are listed locally.
  """
  gpu_device_names = sorted(
      device.name for device in device_lib.list_local_devices()
      if device.device_type == "GPU")
  num_gpus = len(gpu_device_names)
  if num_gpus < 2:
    self.skipTest(
        "This test requires at least 2 GPUs, but only %d is available." %
        num_gpus)

  first_gpu, second_gpu = gpu_device_names[:2]
  with session.Session() as sess:
    v = variables.Variable([10.0, 15.0], dtype=dtypes.float32, name="v")
    with ops.device(first_gpu):
      u0 = math_ops.add(v, v, name="u0")
    with ops.device(second_gpu):
      u1 = math_ops.multiply(v, v, name="u1")
    w = math_ops.subtract(u1, u0, name="w")

    sess.run(v.initializer)

    run_options = config_pb2.RunOptions(output_partition_graphs=True)
    debug_utils.watch_graph(run_options, sess.graph,
                            debug_urls="file://" + self._dump_root)
    run_metadata = config_pb2.RunMetadata()
    # w = v * v - (v + v) = [100 - 20, 225 - 30].
    self.assertAllClose(
        [80.0, 195.0],
        sess.run(w, options=run_options, run_metadata=run_metadata))

    dump = debug_data.DebugDumpDir(
        self._dump_root, partition_graphs=run_metadata.partition_graphs)
    # NOTE(review): presumably one CPU plus the two GPUs -- confirm.
    self.assertEqual(3, len(dump.devices()))
    self.assertAllClose(
        [10.0, 15.0], dump.get_tensors("v", 0, "DebugIdentity")[0])
    self.assertAllClose(
        [20.0, 30.0], dump.get_tensors("u0", 0, "DebugIdentity")[0])
    self.assertAllClose(
        [100.0, 225.0], dump.get_tensors("u1", 0, "DebugIdentity")[0])
开发者ID:1000sprites,项目名称:tensorflow,代码行数:41,代码来源:session_debug_multi_gpu_test.py
示例20: _compareOriginalAndReconstructedGraphDefs
def _compareOriginalAndReconstructedGraphDefs(self,
                                              sess,
                                              fetches,
                                              feed_dict=None,
                                              expected_output=None):
  """Check that debug-instrumented partition graphs reconstruct the originals.

  `fetches` is run twice: once without debug watches, and once after
  `debug_utils.watch_graph` has instrumented the same RunOptions to dump to
  `self._debug_url`. The non-debug graphs reconstructed from the dump (both
  via the DebugDumpDir and via `debug_graphs.reconstruct_non_debug_graph_def`)
  must equal the first run's partition graphs, ignoring blacklisted nodes.

  Args:
    sess: The session to run `fetches` in.
    fetches: Fetches passed to `sess.run`.
    feed_dict: Optional feed dict passed to `sess.run`.
    expected_output: If not None, both runs' outputs must be close to it.
  """
  run_options = config_pb2.RunOptions(output_partition_graphs=True)

  def run_fetches():
    """Run `fetches` once with `run_options`; return that run's metadata."""
    metadata = config_pb2.RunMetadata()
    output = sess.run(fetches, feed_dict=feed_dict, options=run_options,
                      run_metadata=metadata)
    if expected_output is not None:
      self.assertAllClose(expected_output, output)
    return metadata

  # Baseline run, before any debug watches are added.
  non_debug_graph_defs = run_fetches().partition_graphs

  # Instrument the same RunOptions in place, then run again.
  debug_utils.watch_graph(
      run_options, sess.graph, debug_urls=self._debug_url)
  run_metadata = run_fetches()

  dump = debug_data.DebugDumpDir(
      self._dump_dir, partition_graphs=run_metadata.partition_graphs,
      validate=True)
  reconstructed = dump.reconstructed_non_debug_partition_graphs()

  self.assertEqual(len(non_debug_graph_defs), len(reconstructed))
  strip = self._graphDefWithoutBlacklistedNodes
  for index, non_debug_graph_def in enumerate(non_debug_graph_defs):
    device_name = debug_graphs._infer_device_name(non_debug_graph_def)
    test_util.assert_equal_graph_def(
        strip(reconstructed[device_name]), strip(non_debug_graph_def))

    # Test debug_graphs.reconstruct_non_debug_graph_def.
    reconstructed_again = debug_graphs.reconstruct_non_debug_graph_def(
        run_metadata.partition_graphs[index])
    test_util.assert_equal_graph_def(
        strip(reconstructed_again), strip(non_debug_graph_def))
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:40,代码来源:debug_graph_reconstruction_test.py
注:本文中的tensorflow.python.debug.lib.debug_utils.watch_graph函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论