• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python session_debug_testlib.no_rewrite_session_config函数代码示例

原作者：见各示例下方的开发者ID 来自：见各示例下方的代码来源 收藏 邀请

本文整理汇总了Python中tensorflow.python.debug.lib.session_debug_testlib.no_rewrite_session_config函数的典型用法代码示例。如果您正苦于以下问题：Python no_rewrite_session_config函数的具体用法？Python no_rewrite_session_config怎么用？Python no_rewrite_session_config使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了no_rewrite_session_config函数的19个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: testSendingLargeStringTensorWorks

  def testSendingLargeStringTensorWorks(self):
    """Streams a multi-megabyte string tensor through the gRPC debug wrapper.

    Random-length byte strings are accumulated until their combined length
    reaches 5000 * 1024 bytes; the test then checks that the debug server
    received the initial-value constant unchanged.
    """
    session_config = session_debug_testlib.no_rewrite_session_config()
    with self.test_session(use_gpu=True, config=session_config) as sess:
      size_threshold = 5000 * 1024
      accumulated = 0
      expected_strings = []
      while accumulated < size_threshold:
        length = np.random.randint(200)
        expected_strings.append(b"A" * length)
        accumulated += length

      init_tensor = constant_op.constant(
          expected_strings, dtype=dtypes.string, name="u_init")
      u = variables.Variable(init_tensor, name="u")

      def watch_fn(fetches, feeds):
        # Always watch only the initial-value constant.
        del fetches, feeds
        return framework.WatchOptions(
            debug_ops=["DebugIdentity"],
            node_name_regex_whitelist=r"u_init")

      server_address = "localhost:%d" % self.debug_server_port
      sess = grpc_wrapper.GrpcDebugWrapperSession(
          sess, server_address, watch_fn=watch_fn)
      sess.run(u.initializer)

      received = self.debug_server.debug_tensor_values[
          "u_init:0:DebugIdentity"][0]
      self.assertAllEqual(expected_strings, received)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:28,代码来源:grpc_large_data_test.py


示例2: testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks

  def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self):
    """TensorBoardDebugHook sends no tracebacks or source when disabled.

    The hook is built with send_traceback_and_source_code=False. After one
    hooked run, the debug server must hold neither an op traceback for
    "u/read" nor any source file. The underlying Session is opened in a
    ``with`` block so it is closed deterministically at the end of the test.
    """
    u = variables.VariableV1(2.1, name="u")
    v = variables.VariableV1(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      sess.run(variables.global_variables_initializer())

      grpc_debug_hook = hooks.TensorBoardDebugHook(
          ["localhost:%d" % self._server_port],
          send_traceback_and_source_code=False)
      sess = monitored_session._HookedSession(sess, [grpc_debug_hook])

      # Activate watch point on a tensor before calling sess.run().
      self._server.request_watch("u/read", 0, "DebugIdentity")
      self.assertAllClose(42.0, sess.run(w))

      # Check that the server has _not_ received any tracebacks, as a result
      # of the disabling above.
      with self.assertRaisesRegexp(
          ValueError, r"Op .*u/read.* does not exist"):
        self.assertTrue(self._server.query_op_traceback("u/read"))
      with self.assertRaisesRegexp(
          ValueError, r".* has not received any source file"):
        self._server.query_source_file_line(__file__, 1)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:26,代码来源:session_debug_grpc_test.py


示例3: testGrpcDebugWrapperSessionWithWatchFnWorks

  def testGrpcDebugWrapperSessionWithWatchFnWorks(self):
    """GrpcDebugWrapperSession accepts a watch_fn that returns a plain tuple.

    The watch_fn returns a tuple rather than a WatchOptions object. It selects
    DebugIdentity and DebugNumericSummary on every node matching ".*/read".
    The underlying Session is opened in a ``with`` block so it is closed
    deterministically at the end of the test.
    """
    def watch_fn(feeds, fetch_keys):
      del feeds, fetch_keys
      return ["DebugIdentity", "DebugNumericSummary"], r".*/read", None

    u = variables.VariableV1(2.1, name="u")
    v = variables.VariableV1(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      sess.run(u.initializer)
      sess.run(v.initializer)

      sess = grpc_wrapper.GrpcDebugWrapperSession(
          sess, "localhost:%d" % self._server_port, watch_fn=watch_fn)
      w_result = sess.run(w)
      self.assertAllClose(42.0, w_result)

      # u/read and v/read, each watched by two debug ops -> 4 dumped tensors.
      dump = debug_data.DebugDumpDir(self._dump_root)
      self.assertEqual(4, dump.size)
      self.assertAllClose(
          [2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
      self.assertEqual(
          14, len(dump.get_tensors("u/read", 0, "DebugNumericSummary")[0]))
      self.assertAllClose(
          [20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
      self.assertEqual(
          14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:27,代码来源:session_debug_grpc_test.py


示例4: testSendingLargeGraphDefsWorks

  def testSendingLargeGraphDefsWorks(self):
    """Checks that a partition graph larger than 4 MiB reaches the server."""
    session_config = session_debug_testlib.no_rewrite_session_config()
    with self.test_session(use_gpu=True, config=session_config) as sess:
      u = variables.Variable(42.0, name="original_u")
      # Chain 50,000 identity ops onto the variable to inflate the GraphDef.
      for _ in xrange(50 * 1000):
        u = array_ops.identity(u)
      sess.run(variables.global_variables_initializer())

      def watch_fn(fetches, feeds):
        del fetches, feeds
        return framework.WatchOptions(
            debug_ops=["DebugIdentity"],
            node_name_regex_whitelist=r"original_u")

      sess = grpc_wrapper.GrpcDebugWrapperSession(
          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
      self.assertAllClose(42.0, sess.run(u))

      server = self.debug_server
      self.assertAllClose(
          [42.0], server.debug_tensor_values["original_u:0:DebugIdentity"])
      expected_graph_count = 2 if test.is_gpu_available() else 1
      self.assertEqual(expected_graph_count, len(server.partition_graph_defs))
      serialized_sizes = (
          len(graph_def.SerializeToString())
          for graph_def in server.partition_graph_defs)
      self.assertGreater(max(serialized_sizes), 4 * 1024 * 1024)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:27,代码来源:grpc_large_data_test.py


示例5: testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks

  def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self):
    """send_traceback_and_source_code=False suppresses both payloads.

    Over four runs of the wrapped session, the debug server must never hold an
    op traceback for "delta_1" or any source file.
    """
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.VariableV1(50.0, name="v_1")
      v_2 = variables.VariableV1(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run(variables.global_variables_initializer())

      # Wrap the session with traceback and source-code sending turned off.
      sess = grpc_wrapper.TensorBoardDebugWrapperSession(
          sess, self._debug_server_url_1, send_traceback_and_source_code=False)

      server = self._server_1
      for run_index in xrange(4):
        server.clear_data()

        if run_index == 0:
          # Only the first run sets a breakpoint, on delta_1.
          server.request_watch(
              "delta_1", 0, "DebugIdentity", breakpoint=True)

        output = sess.run([inc_v_1, inc_v_2])
        steps = run_index + 1
        self.assertAllClose([50.0 + 5.0 * steps, -50 - 5.0 * steps], output)

        # With sending disabled, the server must hold neither tracebacks nor
        # source files after any run.
        with self.assertRaisesRegexp(
            ValueError, r"Op .*delta_1.* does not exist"):
          self.assertTrue(server.query_op_traceback("delta_1"))
        with self.assertRaisesRegexp(
            ValueError, r".* has not received any source file"):
          server.query_source_file_line(__file__, 1)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:34,代码来源:session_debug_grpc_test.py


示例6: testToggleBreakpointsWorks

  def testToggleBreakpointsWorks(self):
    """Breakpoints set via request_watch/request_unwatch toggle per run.

    Runs 0 and 2 request breakpoints on delta_1:0 and delta_2:0
    (DebugIdentity); runs 1 and 3 unwatch them. After each run, the server's
    registered breakpoint set is compared against what was requested.
    """
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.VariableV1(50.0, name="v_1")
      v_2 = variables.VariableV1(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      # Attach gated DebugIdentity ops to the graph; they publish only when
      # the server has requested a watch on them.
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)"],
          debug_urls=[self._debug_server_url_1])

      for i in xrange(4):
        self._server_1.clear_data()

        if i in (0, 2):
          # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2.
          self._server_1.request_watch(
              "delta_1", 0, "DebugIdentity", breakpoint=True)
          self._server_1.request_watch(
              "delta_2", 0, "DebugIdentity", breakpoint=True)
        else:
          # Disable the breakpoint in runs 1 and 3.
          self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
          self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")

        output = sess.run([inc_v_1, inc_v_2],
                          options=run_options, run_metadata=run_metadata)
        self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

        if i in (0, 2):
          # During runs 0 and 2, the server should have received the published
          # debug tensors delta_[1,2]:0:DebugIdentity. The breakpoints should
          # have been unblocked by EventReply responses from the server.
          self.assertAllClose(
              [5.0],
              self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
          # After the runs, the server should have properly registered the
          # breakpoints due to the request_watch calls above.
          self.assertSetEqual({("delta_1", 0, "DebugIdentity"),
                               ("delta_2", 0, "DebugIdentity")},
                              self._server_1.breakpoints)
        else:
          # After the end of runs 1 and 3, the server has received the requests
          # to disable the breakpoints at delta_[1,2]:0:DebugIdentity.
          self.assertSetEqual(set(), self._server_1.breakpoints)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:57,代码来源:session_debug_grpc_test.py


示例7: testToggleWatchesOnCoreMetadata

  def testToggleWatchesOnCoreMetadata(self):
    """Watches listed in toggle_watch_on_core_metadata alternate per run.

    The server is started with toggle_watch_on_core_metadata for toggled_1:0
    and toggled_2:0 (DebugIdentity). The assertions below expect both values
    on even-numbered runs and no debug tensors at all on odd-numbered runs.
    """
    (_, debug_server_url, _, server_thread,
     server) = grpc_debug_test_server.start_server_on_separate_thread(
         dump_to_filesystem=False,
         toggle_watch_on_core_metadata=[("toggled_1", 0, "DebugIdentity"),
                                        ("toggled_2", 0, "DebugIdentity")])
    self._servers_and_threads.append((server, server_thread))

    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.VariableV1(50.0, name="v_1")
      # Named "v_2" to match its Python name and avoid clashing with v_1.
      v_2 = variables.VariableV1(-50.0, name="v_2")
      # These two nodes have names that match those in the
      # toggle_watch_on_core_metadata argument used when calling
      # start_server_on_separate_thread().
      toggled_1 = constant_op.constant(5.0, name="toggled_1")
      toggled_2 = constant_op.constant(-5.0, name="toggled_2")
      inc_v_1 = state_ops.assign_add(v_1, toggled_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, toggled_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)"],
          debug_urls=[debug_server_url])

      for i in xrange(4):
        server.clear_data()

        sess.run([inc_v_1, inc_v_2],
                 options=run_options, run_metadata=run_metadata)

        if i % 2 == 0:
          self.assertEqual(2, len(server.debug_tensor_values))
          self.assertAllClose(
              [5.0],
              server.debug_tensor_values["toggled_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              server.debug_tensor_values["toggled_2:0:DebugIdentity"])
        else:
          self.assertEqual(0, len(server.debug_tensor_values))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:46,代码来源:session_debug_grpc_test.py


示例8: testTensorBoardDebugHookWorks

  def testTensorBoardDebugHookWorks(self):
    """TensorBoardDebugHook streams tensors, and sends tracebacks/source once.

    The first hooked run must deliver the watched u/read tensor, op tracebacks
    and this file's source to the server. After the server data is cleared, a
    second run of the same graph must not resend tracebacks or source. The
    underlying Session is opened in a ``with`` block so it is closed
    deterministically at the end of the test.
    """
    u = variables.VariableV1(2.1, name="u")
    v = variables.VariableV1(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      sess.run(u.initializer)
      sess.run(v.initializer)

      grpc_debug_hook = hooks.TensorBoardDebugHook(
          ["localhost:%d" % self._server_port])
      sess = monitored_session._HookedSession(sess, [grpc_debug_hook])

      # Activate watch point on a tensor before calling sess.run().
      self._server.request_watch("u/read", 0, "DebugIdentity")
      self.assertAllClose(42.0, sess.run(w))

      dump = debug_data.DebugDumpDir(self._dump_root)
      self.assertAllClose(
          [2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))

      # Check that the server has received the stack trace.
      for op_name in ("u", "u/read", "v", "v/read", "w"):
        self.assertTrue(self._server.query_op_traceback(op_name))

      # Check that the server has received the python file content.
      # Query an arbitrary line to make sure that is the case.
      with open(__file__, "rt") as this_source_file:
        first_line = this_source_file.readline().strip()
      self.assertEqual(
          first_line, self._server.query_source_file_line(__file__, 1))

      self._server.clear_data()
      # Call sess.run() again, and verify that this time the traceback and
      # source code are not sent, because the graph version is not newer.
      self.assertAllClose(42.0, sess.run(w))
      # Query an op that exists in this graph and whose traceback was received
      # after the first run, so a ValueError here actually shows that nothing
      # was resent.
      with self.assertRaises(ValueError):
        self._server.query_op_traceback("u/read")
      with self.assertRaises(ValueError):
        self._server.query_source_file_line(__file__, 1)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:44,代码来源:session_debug_grpc_test.py


示例9: testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers

  def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self):
    """Watches requested by one debug server must not leak to the other.

    On even runs, server 1 watches "delta" and server 2 watches "v"; on odd
    runs, both watches are withdrawn. Each server must see only its own
    tensor.
    """
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v = variables.VariableV1(50.0, name="v")
      delta = constant_op.constant(5.0, name="delta")
      inc_v = state_ops.assign_add(v, delta, name="inc_v")

      sess.run(v.initializer)

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      # Gated DebugIdentity ops that publish to both servers.
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)"],
          debug_urls=[self._debug_server_url_1, self._debug_server_url_2])

      server_1, server_2 = self._server_1, self._server_2
      for i in xrange(4):
        watching = i % 2 == 0
        server_1.clear_data()
        server_2.clear_data()

        if watching:
          server_1.request_watch("delta", 0, "DebugIdentity")
          server_2.request_watch("v", 0, "DebugIdentity")
        else:
          server_1.request_unwatch("delta", 0, "DebugIdentity")
          server_2.request_unwatch("v", 0, "DebugIdentity")

        sess.run(inc_v, options=run_options, run_metadata=run_metadata)

        values_1 = server_1.debug_tensor_values
        values_2 = server_2.debug_tensor_values
        if not watching:
          self.assertEqual(0, len(values_1))
          self.assertEqual(0, len(values_2))
          continue

        self.assertEqual(1, len(values_1))
        self.assertEqual(1, len(values_2))
        self.assertAllClose([5.0], values_1["delta:0:DebugIdentity"])
        self.assertAllClose([50 + 5.0 * i], values_2["v:0:DebugIdentity"])
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:42,代码来源:session_debug_grpc_test.py


示例10: testGrpcDebugHookWithStatelessWatchFnWorks

  def testGrpcDebugHookWithStatelessWatchFnWorks(self):
    """GrpcDebugHook streams the tensors selected by a WatchOptions watch_fn.

    The underlying Session is opened in a ``with`` block so it is closed
    deterministically at the end of the test.
    """
    # Perform some set up. Specifically, construct a simple TensorFlow graph and
    # create a watch function for certain ops.
    def watch_fn(feeds, fetch_keys):
      del feeds, fetch_keys
      return framework.WatchOptions(
          debug_ops=["DebugIdentity", "DebugNumericSummary"],
          node_name_regex_whitelist=r".*/read",
          op_type_regex_whitelist=None,
          tolerate_debug_op_creation_failures=True)

    u = variables.VariableV1(2.1, name="u")
    v = variables.VariableV1(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      sess.run(u.initializer)
      sess.run(v.initializer)

      # Create a hook. One could use this hook with say a tflearn Estimator.
      # However, we use a HookedSession in this test to avoid depending on the
      # internal implementation of Estimators.
      grpc_debug_hook = hooks.GrpcDebugHook(
          ["localhost:%d" % self._server_port], watch_fn=watch_fn)
      sess = monitored_session._HookedSession(sess, [grpc_debug_hook])

      # Run the hooked session. This should stream tensor data to the GRPC
      # endpoints.
      w_result = sess.run(w)

      # Verify that the hook monitored the correct tensors: u/read and v/read,
      # each watched by two debug ops -> 4 dumped tensors.
      self.assertAllClose(42.0, w_result)
      dump = debug_data.DebugDumpDir(self._dump_root)
      self.assertEqual(4, dump.size)
      self.assertAllClose(
          [2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
      self.assertEqual(
          14, len(dump.get_tensors("u/read", 0, "DebugNumericSummary")[0]))
      self.assertAllClose(
          [20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
      self.assertEqual(
          14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:41,代码来源:session_debug_grpc_test.py


示例11: testSendingEmptyStringTensorWorks

  def testSendingEmptyStringTensorWorks(self):
    """An empty (shape [0]) string tensor reaches the debug server intact.

    The received value must be an object-dtype array of length zero.
    """
    with self.test_session(
        use_gpu=True,
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      u_init = constant_op.constant(
          [], dtype=dtypes.string, shape=[0], name="u_init")
      u = variables.Variable(u_init, name="u")

      def watch_fn(fetches, feeds):
        del fetches, feeds
        return framework.WatchOptions(
            debug_ops=["DebugIdentity"],
            node_name_regex_whitelist=r"u_init")
      sess = grpc_wrapper.GrpcDebugWrapperSession(
          sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn)
      sess.run(u.initializer)

      u_init_value = self.debug_server.debug_tensor_values[
          "u_init:0:DebugIdentity"][0]
      # `np.object` was only a deprecated alias of the builtin `object` and
      # was removed in NumPy 1.24, so compare against `object` directly.
      self.assertEqual(object, u_init_value.dtype)
      self.assertEqual(0, len(u_init_value))
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:21,代码来源:grpc_large_data_test.py


示例12: testSendingStringTensorWithAlmostTooLargeStringsWorks

  def testSendingStringTensorWithAlmostTooLargeStringsWorks(self):
    """Sends a string tensor holding two 2500 * 1024-byte elements.

    Short and empty strings surround the two large ones; the debug server
    must receive the whole list unchanged.
    """
    config = session_debug_testlib.no_rewrite_session_config()
    with self.test_session(use_gpu=True, config=config) as sess:
      big_a = b"A" * 2500 * 1024
      big_b = b"B" * 2500 * 1024
      u_init_val = [b"", b"spam", big_a, big_b, b"egg", b""]
      init_tensor = constant_op.constant(
          u_init_val, dtype=dtypes.string, name="u_init")
      u = variables.Variable(init_tensor, name="u")

      def watch_fn(fetches, feeds):
        # Watch only the initial-value constant.
        del fetches, feeds
        return framework.WatchOptions(
            debug_ops=["DebugIdentity"],
            node_name_regex_whitelist=r"u_init")

      address = "localhost:%d" % self.debug_server_port
      sess = grpc_wrapper.GrpcDebugWrapperSession(
          sess, address, watch_fn=watch_fn)
      sess.run(u.initializer)

      received = self.debug_server.debug_tensor_values[
          "u_init:0:DebugIdentity"][0]
      self.assertAllEqual(u_init_val, received)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:22,代码来源:grpc_large_data_test.py


示例13: testGrpcDebugWrapperSessionWithoutWatchFnWorks

  def testGrpcDebugWrapperSessionWithoutWatchFnWorks(self):
    """Without a watch_fn, the wrapper dumps every tensor via DebugIdentity.

    The graph has five nodes (u, u/read, v, v/read, w), so five dumps are
    expected. The underlying Session is opened in a ``with`` block so it is
    closed deterministically at the end of the test.
    """
    u = variables.VariableV1(2.1, name="u")
    v = variables.VariableV1(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      sess.run(u.initializer)
      sess.run(v.initializer)

      sess = grpc_wrapper.GrpcDebugWrapperSession(
          sess, "localhost:%d" % self._server_port)
      w_result = sess.run(w)
      self.assertAllClose(42.0, w_result)

      dump = debug_data.DebugDumpDir(self._dump_root)
      self.assertEqual(5, dump.size)
      self.assertAllClose([2.1], dump.get_tensors("u", 0, "DebugIdentity"))
      self.assertAllClose(
          [2.1], dump.get_tensors("u/read", 0, "DebugIdentity"))
      self.assertAllClose([20.0], dump.get_tensors("v", 0, "DebugIdentity"))
      self.assertAllClose(
          [20.0], dump.get_tensors("v/read", 0, "DebugIdentity"))
      self.assertAllClose([42.0], dump.get_tensors("w", 0, "DebugIdentity"))
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:22,代码来源:session_debug_grpc_test.py


示例14: testSendingLargeFloatTensorWorks

  def testSendingLargeFloatTensorWorks(self):
    """Sends a float32 tensor larger than 4 MiB through the gRPC debugger."""
    config = session_debug_testlib.no_rewrite_session_config()
    with self.test_session(use_gpu=True, config=config) as sess:
      # 1200 * 1024 float32 elements at 4 bytes each = 4800 KiB (> 4 MiB).
      u_init_val_array = list(xrange(1200 * 1024))

      init_tensor = constant_op.constant(
          u_init_val_array, dtype=dtypes.float32, name="u_init")
      u = variables.Variable(init_tensor, name="u")

      def watch_fn(fetches, feeds):
        # Fetches and feeds are irrelevant: always watch u_init only.
        del fetches, feeds
        return framework.WatchOptions(
            debug_ops=["DebugIdentity"],
            node_name_regex_whitelist=r"u_init")

      address = "localhost:%d" % self.debug_server_port
      sess = grpc_wrapper.GrpcDebugWrapperSession(
          sess, address, watch_fn=watch_fn)
      sess.run(u.initializer)

      received = self.debug_server.debug_tensor_values[
          "u_init:0:DebugIdentity"][0]
      self.assertAllEqual(u_init_val_array, received)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:23,代码来源:grpc_large_data_test.py


示例15: testAllowsDifferentWatchesOnDifferentRuns

  def testAllowsDifferentWatchesOnDifferentRuns(self):
    """Test watching different tensors on different runs of the same graph."""

    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      u_init_val = [[5.0, 3.0], [-1.0, 0.0]]
      v_init_val = [[2.0], [-1.0]]

      # Both variables share the "diff_Watch" namespace (i.e., the same parent
      # dump directory) to exercise concurrent, non-racing directory creation.
      u_name = "diff_Watch/u"
      v_name = "diff_Watch/v"

      u = variables.VariableV1(
          constant_op.constant(u_init_val, shape=[2, 2]), name=u_name)
      v = variables.VariableV1(
          constant_op.constant(v_init_val, shape=[2, 1]), name=v_name)

      w = math_ops.matmul(u, v, name="diff_Watch/matmul")

      u.initializer.run()
      v.initializer.run()

      # Run 0 watches u/read only; run 1 watches v/read only.
      runs = ((u_name, u_init_val), (v_name, v_init_val))
      for i, (var_name, expected_value) in enumerate(runs):
        watched_node = "%s/read" % var_name
        run_options = config_pb2.RunOptions(output_partition_graphs=True)

        run_dump_root = self._debug_dump_dir(run_number=i)
        debug_urls = self._debug_urls(run_number=i)

        debug_utils.add_debug_tensor_watch(
            run_options, watched_node, 0, debug_urls=debug_urls)

        run_metadata = config_pb2.RunMetadata()
        sess.run(w, options=run_options, run_metadata=run_metadata)

        self.assertEqual(self._expected_partition_graph_count,
                         len(run_metadata.partition_graphs))

        dump = debug_data.DebugDumpDir(
            run_dump_root, partition_graphs=run_metadata.partition_graphs)
        self.assertTrue(dump.loaded_partition_graphs())

        # Each run should have generated only one dumped tensor, not two.
        self.assertEqual(1, dump.size)

        self.assertAllClose(
            [expected_value],
            dump.get_tensors(watched_node, 0, "DebugIdentity"))
        self.assertGreaterEqual(
            dump.get_rel_timestamps(watched_node, 0, "DebugIdentity")[0], 0)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:67,代码来源:session_debug_file_test.py


示例16: testUseInvalidWatchFnTypeWithGrpcDebugWrapperSessionRaisesException

 def testUseInvalidWatchFnTypeWithGrpcDebugWrapperSessionRaisesException(self):
   """A string watch_fn makes GrpcDebugWrapperSession raise TypeError.

   The Session is opened in a ``with`` block so it is closed
   deterministically even though the wrapper construction raises.
   """
   with session.Session(
       config=session_debug_testlib.no_rewrite_session_config()) as sess:
     with self.assertRaises(TypeError):
       grpc_wrapper.GrpcDebugWrapperSession(
           sess, "localhost:%d" % self._server_port, watch_fn="foo")
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:6,代码来源:session_debug_grpc_test.py


示例17: testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException2

 def testConstructGrpcDebugWrapperSessionWithInvalidTypeRaisesException2(self):
   """A non-str entry in the server-address list raises TypeError.

   The Session is opened in a ``with`` block so it is closed
   deterministically even though the wrapper construction raises.
   """
   with session.Session(
       config=session_debug_testlib.no_rewrite_session_config()) as sess:
     with self.assertRaisesRegexp(
         TypeError, "Expected type str in list grpc_debug_server_addresses"):
       grpc_wrapper.GrpcDebugWrapperSession(sess, ["localhost:1337", 1338])
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:6,代码来源:session_debug_grpc_test.py


示例18: testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes

  def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self):
    """Toggling one debug op on a tensor does not affect another debug op.

    Even runs watch delta_[1,2]:0 with DebugIdentity only; odd runs watch
    them with DebugNumericSummary only. Each run must publish exactly the two
    requested debug tensors.
    """
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.VariableV1(50.0, name="v_1")
      # Named "v_2" to match its Python name and avoid clashing with v_1.
      v_2 = variables.VariableV1(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      run_metadata = config_pb2.RunMetadata()
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options,
          sess.graph,
          debug_ops=["DebugIdentity(gated_grpc=true)",
                     "DebugNumericSummary(gated_grpc=true)"],
          debug_urls=[self._debug_server_url_1])

      for i in xrange(4):
        self._server_1.clear_data()

        if i % 2 == 0:
          self._server_1.request_watch("delta_1", 0, "DebugIdentity")
          self._server_1.request_watch("delta_2", 0, "DebugIdentity")
          self._server_1.request_unwatch("delta_1", 0, "DebugNumericSummary")
          self._server_1.request_unwatch("delta_2", 0, "DebugNumericSummary")
        else:
          self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
          self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")
          self._server_1.request_watch("delta_1", 0, "DebugNumericSummary")
          self._server_1.request_watch("delta_2", 0, "DebugNumericSummary")

        sess.run([inc_v_1, inc_v_2],
                 options=run_options, run_metadata=run_metadata)

        # Watched debug tensors are:
        #   Run 0: delta_[1,2]:0:DebugIdentity
        #   Run 1: delta_[1,2]:0:DebugNumericSummary
        #   Run 2: delta_[1,2]:0:DebugIdentity
        #   Run 3: delta_[1,2]:0:DebugNumericSummary
        self.assertEqual(2, len(self._server_1.debug_tensor_values))
        if i % 2 == 0:
          self.assertAllClose(
              [5.0],
              self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
        else:
          self.assertAllClose(
              [[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 5.0, 5.0, 5.0,
                0.0, 1.0, 0.0]],
              self._server_1.debug_tensor_values[
                  "delta_1:0:DebugNumericSummary"])
          self.assertAllClose(
              [[1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -5.0, -5.0, -5.0,
                0.0, 1.0, 0.0]],
              self._server_1.debug_tensor_values[
                  "delta_2:0:DebugNumericSummary"])
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:62,代码来源:session_debug_grpc_test.py


示例19: testTensorBoardDebuggerWrapperToggleBreakpointsWorks

  def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
    """Breakpoints toggle per run under TensorBoardDebugWrapperSession.

    Runs 0 and 2 request breakpoints on delta_[1,2]:0:DebugIdentity; runs 1
    and 3 unwatch them. Op tracebacks and source are expected only after the
    first run.
    """
    with session.Session(
        config=session_debug_testlib.no_rewrite_session_config()) as sess:
      v_1 = variables.VariableV1(50.0, name="v_1")
      v_2 = variables.VariableV1(-50.0, name="v_2")
      delta_1 = constant_op.constant(5.0, name="delta_1")
      delta_2 = constant_op.constant(-5.0, name="delta_2")
      inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
      inc_v_2 = state_ops.assign_add(v_2, delta_2, name="inc_v_2")

      sess.run([v_1.initializer, v_2.initializer])

      # The TensorBoardDebugWrapperSession should add a DebugIdentity debug op
      # with attribute gated_grpc=True for every tensor in the graph.
      sess = grpc_wrapper.TensorBoardDebugWrapperSession(
          sess, self._debug_server_url_1)

      for i in xrange(4):
        self._server_1.clear_data()

        if i in (0, 2):
          # Enable breakpoint at delta_[1,2]:0:DebugIdentity in runs 0 and 2.
          self._server_1.request_watch(
              "delta_1", 0, "DebugIdentity", breakpoint=True)
          self._server_1.request_watch(
              "delta_2", 0, "DebugIdentity", breakpoint=True)
        else:
          # Disable the breakpoint in runs 1 and 3.
          self._server_1.request_unwatch("delta_1", 0, "DebugIdentity")
          self._server_1.request_unwatch("delta_2", 0, "DebugIdentity")

        output = sess.run([inc_v_1, inc_v_2])
        self.assertAllClose([50.0 + 5.0 * (i + 1), -50 - 5.0 * (i + 1)], output)

        if i in (0, 2):
          # During runs 0 and 2, the server should have received the published
          # debug tensors delta_[1,2]:0:DebugIdentity. The breakpoints should
          # have been unblocked by EventReply responses from the server.
          self.assertAllClose(
              [5.0],
              self._server_1.debug_tensor_values["delta_1:0:DebugIdentity"])
          self.assertAllClose(
              [-5.0],
              self._server_1.debug_tensor_values["delta_2:0:DebugIdentity"])
          # After the runs, the server should have properly registered the
          # breakpoints due to the request_watch calls above.
          self.assertSetEqual({("delta_1", 0, "DebugIdentity"),
                               ("delta_2", 0, "DebugIdentity")},
                              self._server_1.breakpoints)
        else:
          # After the end of runs 1 and 3, the server has received the requests
          # to disable the breakpoints at delta_[1,2]:0:DebugIdentity.
          self.assertSetEqual(set(), self._server_1.breakpoints)

        if i == 0:
          # Check that the server has received the stack trace.
          self.assertTrue(self._server_1.query_op_traceback("delta_1"))
          self.assertTrue(self._server_1.query_op_traceback("delta_2"))
          self.assertTrue(self._server_1.query_op_traceback("inc_v_1"))
          self.assertTrue(self._server_1.query_op_traceback("inc_v_2"))
          # Check that the server has received the python file content.
          # Query an arbitrary line to make sure that is the case.
          with open(__file__, "rt") as this_source_file:
            first_line = this_source_file.readline().strip()
          self.assertEqual(
              first_line, self._server_1.query_source_file_line(__file__, 1))
        else:
          # In later Session.run() calls, the traceback shouldn't have been sent
          # because it is already sent in the 1st call. So calling
          # query_op_traceback() should lead to an exception, because the test
          # debug server clears the data at the beginning of every iteration.
          with self.assertRaises(ValueError):
            self._server_1.query_op_traceback("delta_1")
          with self.assertRaises(ValueError):
            self._server_1.query_source_file_line(__file__, 1)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:72,代码来源:session_debug_grpc_test.py



注:本文中的tensorflow.python.debug.lib.session_debug_testlib.no_rewrite_session_config函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap