• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python monitored_session._HookedSession函数代码示例

原作者: 未注明 来自: 未注明 收藏 邀请

本文整理汇总了Python中tensorflow.python.training.monitored_session._HookedSession函数的典型用法代码示例。如果您正苦于以下问题:Python _HookedSession函数的具体用法?Python _HookedSession怎么用?Python _HookedSession使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了_HookedSession函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: _validate_print_every_n_secs

  def _validate_print_every_n_secs(self, sess, at_end):
    """Checks time-based logging of LoggingMetricHook (every_n_secs=1.0).

    Args:
      sess: session wrapped by the hooked session and used for initializers.
      at_end: forwarded to the hook; when True, `hook.end()` is expected to
        log the tensor, otherwise it is expected to log nothing.
    """
    t = tf.constant(42.0, name='foo')
    train_op = tf.constant(3)

    hook = metric_hook.LoggingMetricHook(
        tensors=[t.name], every_n_secs=1.0, at_end=at_end,
        metric_logger=self._logger)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])
    sess.run(tf.global_variables_initializer())

    # The first run is expected to log the tensor.
    mon_sess.run(train_op)
    self.assertRegexpMatches(str(self._logger.logged_metric), t.name)

    # An immediate second run is inside the 1s interval: nothing expected.
    # assertNotRegexpMatches is not available on every supported Python
    # version (Python 3 uses assertNotRegex), so absence is checked with
    # str.find() == -1 instead.
    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
    # Wait out the interval so the next run logs again.
    time.sleep(1.0)

    self._logger.logged_metric = []
    mon_sess.run(train_op)
    self.assertRegexpMatches(str(self._logger.logged_metric), t.name)

    # Only what end() logs is inspected below.
    self._logger.logged_metric = []
    hook.end(sess)
    if at_end:
      self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
    else:
      # Absence checked with str.find(): assertNotRegexpMatches is not
      # available on every supported Python version.
      self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
开发者ID:Toyben,项目名称:models,代码行数:31,代码来源:metric_hook_test.py


示例2: test_log_tensors

  def test_log_tensors(self):
    """LoggingMetricHook(at_end=True) logs both tensors only from end()."""
    with tf.Graph().as_default(), tf.Session() as sess:
      tf.train.get_or_create_global_step()
      t1 = tf.constant(42.0, name='foo')
      t2 = tf.constant(43.0, name='bar')
      train_op = tf.constant(3)
      hook = metric_hook.LoggingMetricHook(
          tensors=[t1, t2], at_end=True, metric_logger=self._logger)
      hook.begin()
      mon_sess = monitored_session._HookedSession(sess, [hook])
      sess.run(tf.global_variables_initializer())

      # While training runs, nothing is expected to be logged.
      for _ in range(3):
        mon_sess.run(train_op)
        self.assertEqual(self._logger.logged_metric, [])

      hook.end(sess)
      records = self._logger.logged_metric
      self.assertEqual(len(records), 2)
      # One record per tensor, in the order the tensors were given.
      expected_records = (("foo", 42.0), ("bar", 43.0))
      for record, (name_pattern, expected_value) in zip(records,
                                                        expected_records):
        self.assertRegexpMatches(str(record["name"]), name_pattern)
        self.assertEqual(record["value"], expected_value)
        self.assertEqual(record["unit"], None)
        self.assertEqual(record["global_step"], 0)
开发者ID:Toyben,项目名称:models,代码行数:29,代码来源:metric_hook_test.py


示例3: test_step_counter_every_n_secs

    def test_step_counter_every_n_secs(self):
        """StepCounterHook with every_n_secs=0.1 writes global_step/sec summaries.

        Three runs are separated by 0.2s sleeps; the test expects summaries at
        global steps 2 and 3, each with a positive rate. Depends on wall-clock
        timing.
        """
        with tf.Graph().as_default() as g, tf.Session() as sess:
            global_step = tf.contrib.framework.get_or_create_global_step()
            train_op = tf.assign_add(global_step, 1)
            summary_writer = testing.FakeSummaryWriter(self.log_dir, g)
            # Time-based triggering only: step-based triggering is disabled.
            hook = tf.train.StepCounterHook(summary_writer=summary_writer, every_n_steps=None, every_n_secs=0.1)

            hook.begin()
            sess.run(tf.initialize_all_variables())
            mon_sess = monitored_session._HookedSession(sess, [hook])
            mon_sess.run(train_op)
            # Sleep longer than every_n_secs so the following run can trigger.
            time.sleep(0.2)
            mon_sess.run(train_op)
            time.sleep(0.2)
            mon_sess.run(train_op)
            hook.end(sess)

            summary_writer.assert_summaries(
                test_case=self, expected_logdir=self.log_dir, expected_graph=g, expected_summaries={}
            )
            self.assertTrue(summary_writer.summaries, "No summaries were created.")
            self.assertItemsEqual([2, 3], summary_writer.summaries.keys())
            # Every written summary should carry a positive steps-per-second rate.
            for summary in summary_writer.summaries.values():
                summary_value = summary[0].value[0]
                self.assertEqual("global_step/sec", summary_value.tag)
                self.assertGreater(summary_value.simple_value, 0)
开发者ID:botonchou,项目名称:tensorflow,代码行数:26,代码来源:basic_session_run_hooks_test.py


示例4: _validate_print_every_n_steps

  def _validate_print_every_n_steps(self, sess, at_end):
    """Checks step-based logging of LoggingMetricHook (every_n_iter=10).

    Args:
      sess: session wrapped by the hooked session and used for initializers.
      at_end: forwarded to the hook; when True, `hook.end()` is expected to
        log the tensor, otherwise it is expected to log nothing.
    """
    t = tf.constant(42.0, name="foo")

    train_op = tf.constant(3)
    hook = metric_hook.LoggingMetricHook(
        tensors=[t.name], every_n_iter=10, at_end=at_end,
        metric_logger=self._logger)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.compat.v1.global_variables_initializer())
    # The very first run is expected to log.
    mon_sess.run(train_op)
    self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
    # Three full periods: 9 silent runs, then the 10th run logs.
    for _ in range(3):
      self._logger.logged_metric = []
      for _ in range(9):
        mon_sess.run(train_op)
        # Absence checked with str.find(): assertNotRegexpMatches is not
        # available on every supported Python version.
        self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
      mon_sess.run(train_op)
      self.assertRegexpMatches(str(self._logger.logged_metric), t.name)

    # Add additional run to verify proper reset when called multiple times.
    self._logger.logged_metric = []
    mon_sess.run(train_op)
    # Absence checked with str.find() (no assertNotRegexpMatches).
    self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)

    # Only what end() logs is inspected below.
    self._logger.logged_metric = []
    hook.end(sess)
    if at_end:
      self.assertRegexpMatches(str(self._logger.logged_metric), t.name)
    else:
      # Absence checked with str.find() (no assertNotRegexpMatches).
      self.assertEqual(str(self._logger.logged_metric).find(t.name), -1)
开发者ID:Exscotticus,项目名称:models,代码行数:34,代码来源:metric_hook_test.py


示例5: test_save_secs_saving_once_every_three_steps

  def test_save_secs_saving_once_every_three_steps(self, mock_time):
    """SummarySaverHook(save_secs=9) driven by a fake clock.

    Args:
      mock_time: mock whose return_value serves as the current time; it is
        advanced by hand after every run (presumably patched over the time
        source by a decorator not shown here).
    """
    mock_time.return_value = 1484695987.209386
    hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=9.,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      fake_seconds_per_step = 3.1
      for _ in range(8):
        mon_sess.run(self.train_op)
        mock_time.return_value += fake_seconds_per_step
      hook.end(sess)

    # 8 steps * 3.1s = 24.8s of fake time. Saving every 9 seconds, starting
    # from the first step, gives saves at steps 1, 4 and 7 whose summary
    # values count up 1.0, 2.0, 3.0.
    saved_steps = (1, 4, 7)
    expected = {
        step: {'my_summary': float(ordinal)}
        for ordinal, step in enumerate(saved_steps, start=1)
    }
    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries=expected)
开发者ID:Mazecreator,项目名称:tensorflow,代码行数:31,代码来源:basic_session_run_hooks_test.py


示例6: test_save_steps_saves_periodically

 def test_save_steps_saves_periodically(self):
   """CheckpointSaverHook(save_steps=2) checkpoints every other step.

   After runs 2 through 5, the global step read back from the latest
   checkpoint in model_dir is expected to be 1, 3, 3, 5.
   """
   with self.graph.as_default():
     hook = tf.train.CheckpointSaverHook(
         self.model_dir, save_steps=2, scaffold=self.scaffold)
     hook.begin()
     self.scaffold.finalize()
     with tf.Session() as sess:
       sess.run(self.scaffold.init_op)
       mon_sess = monitored_session._HookedSession(sess, [hook])
       mon_sess.run(self.train_op)
       # The checkpointed value only advances on runs where the hook saves
       # (runs 3 and 5); runs 2 and 4 leave the previous checkpoint in place.
       for checkpointed_step in (1, 3, 3, 5):
         mon_sess.run(self.train_op)
         self.assertEqual(checkpointed_step,
                          tf.contrib.framework.load_variable(
                              self.model_dir, self.global_step.name))
开发者ID:KalraA,项目名称:tensorflow,代码行数:26,代码来源:basic_session_run_hooks_test.py


示例7: test_capture

  def test_capture(self):
    """MetadataCaptureHook(params={"step": 5}) captures once past step 5.

    The test expects model_dir to stay empty at step 0 and on the run at
    step 5, and to contain run_meta, tfprof_log and timeline.json after the
    following run.
    """
    global_step = tf.contrib.framework.get_or_create_global_step()
    # Some test computation
    some_weights = tf.get_variable("weigths", [2, 128])
    computation = tf.nn.softmax(some_weights)

    hook = hooks.MetadataCaptureHook(
        params={"step": 5}, model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())
    hook.begin()

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      #pylint: disable=W0212
      mon_sess = monitored_session._HookedSession(sess, [hook])
      # Should not trigger for step 0
      sess.run(tf.assign(global_step, 0))
      mon_sess.run(computation)
      self.assertEqual(gfile.ListDirectory(self.model_dir), [])
      # Should trigger *after* step 5: the run at step 5 writes nothing yet,
      # the next run writes the captured metadata files.
      sess.run(tf.assign(global_step, 5))
      mon_sess.run(computation)
      self.assertEqual(gfile.ListDirectory(self.model_dir), [])
      mon_sess.run(computation)
      self.assertEqual(
          set(gfile.ListDirectory(self.model_dir)),
          set(["run_meta", "tfprof_log", "timeline.json"]))
开发者ID:AbhinavJain13,项目名称:seq2seq,代码行数:27,代码来源:hooks_test.py


示例8: _validate_print_every_n_steps

  def _validate_print_every_n_steps(self, sess, at_end):
    """Checks step-based logging of LoggingTensorHook (every_n_iter=10).

    Args:
      sess: session wrapped by the hooked session and used for initializers.
      at_end: forwarded to the hook; when True, `hook.end()` is expected to
        log the tensor, otherwise it is expected to log nothing.
    """
    t = constant_op.constant(42.0, name='foo')

    train_op = constant_op.constant(3)
    hook = basic_session_run_hooks.LoggingTensorHook(
        tensors=[t.name], every_n_iter=10, at_end=at_end)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])
    sess.run(variables_lib.global_variables_initializer())
    # The very first run is expected to log.
    mon_sess.run(train_op)
    self.assertRegexpMatches(str(self.logged_message), t.name)
    # Three full periods: 9 silent runs, then the 10th run logs.
    for _ in range(3):
      self.logged_message = ''
      for _ in range(9):
        mon_sess.run(train_op)
        # Absence checked with str.find(): assertNotRegexpMatches is not
        # available on every supported Python version.
        self.assertEqual(str(self.logged_message).find(t.name), -1)
      mon_sess.run(train_op)
      self.assertRegexpMatches(str(self.logged_message), t.name)

    # Add additional run to verify proper reset when called multiple times.
    self.logged_message = ''
    mon_sess.run(train_op)
    # Absence checked with str.find() (no assertNotRegexpMatches).
    self.assertEqual(str(self.logged_message).find(t.name), -1)

    # Only what end() logs is inspected below.
    self.logged_message = ''
    hook.end(sess)
    if at_end:
      self.assertRegexpMatches(str(self.logged_message), t.name)
    else:
      # Absence checked with str.find() (no assertNotRegexpMatches).
      self.assertEqual(str(self.logged_message).find(t.name), -1)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:33,代码来源:basic_session_run_hooks_test.py


示例9: test_global_step_name

  def test_global_step_name(self):
    """StepCounterHook derives its summary tag from the global step's name.

    The global step is a custom variable 'bar/foo'. With every_n_steps=1 and
    two runs, the test expects one summary at step 2 tagged 'bar/foo/sec'.
    """
    with ops.Graph().as_default() as g, session_lib.Session() as sess:
      with variable_scope.variable_scope('bar'):
        # Added to the GLOBAL_STEP collection, presumably so the hook resolves
        # this variable as the global step.
        foo_step = variable_scope.get_variable(
            'foo',
            initializer=0,
            trainable=False,
            collections=[
                ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES
            ])
      train_op = state_ops.assign_add(foo_step, 1)
      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
      hook = basic_session_run_hooks.StepCounterHook(
          summary_writer=summary_writer, every_n_steps=1, every_n_secs=None)

      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      mon_sess.run(train_op)
      mon_sess.run(train_op)
      hook.end(sess)

      summary_writer.assert_summaries(
          test_case=self,
          expected_logdir=self.log_dir,
          expected_graph=g,
          expected_summaries={})
      self.assertTrue(summary_writer.summaries, 'No summaries were created.')
      self.assertItemsEqual([2], summary_writer.summaries.keys())
      summary_value = summary_writer.summaries[2][0].value[0]
      self.assertEqual('bar/foo/sec', summary_value.tag)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:31,代码来源:basic_session_run_hooks_test.py


示例10: test_summary_saver

 def test_summary_saver(self):
   """SummarySaverHook(save_steps=8) over 30 steps saves at 1, 9, 17, 25."""
   with tf.Graph().as_default() as g, tf.Session() as sess:
     log_dir = 'log/dir'
     summary_writer = testing.FakeSummaryWriter(log_dir, g)
     var = tf.Variable(0.0)
     # Evaluating the summary increments var, so the recorded value counts
     # how many times the summary op was actually run.
     tensor = tf.assign_add(var, 1.0)
     summary_op = tf.scalar_summary('my_summary', tensor)
     global_step = tf.contrib.framework.get_or_create_global_step()
     train_op = tf.assign_add(global_step, 1)
     hook = tf.train.SummarySaverHook(
         summary_op=summary_op, save_steps=8, summary_writer=summary_writer)
     hook.begin()
     sess.run(tf.initialize_all_variables())
     mon_sess = monitored_session._HookedSession(sess, [hook])
     for _ in range(30):
       mon_sess.run(train_op)
     hook.end(sess)
     # Saves happen on steps 1, 9, 17, 25 with values 1.0 .. 4.0.
     expected = {
         step: {'my_summary': float(ordinal)}
         for ordinal, step in enumerate(range(1, 30, 8), start=1)
     }
     summary_writer.assert_summaries(
         test_case=self,
         expected_logdir=log_dir,
         expected_graph=g,
         expected_summaries=expected)
开发者ID:KalraA,项目名称:tensorflow,代码行数:28,代码来源:basic_session_run_hooks_test.py


示例11: test_stop_based_on_num_step

  def test_stop_based_on_num_step(self):
    """StopAtStepHook(num_steps=10) counts from the step at session creation.

    The global step is 5 when after_create_session runs, so the test expects
    should_stop() to stay False through step 14 and become True at step 15.
    """
    h = basic_session_run_hooks.StopAtStepHook(num_steps=10)

    with ops.Graph().as_default():
      global_step = variables.get_or_create_global_step()
      no_op = control_flow_ops.no_op()
      h.begin()
      with session_lib.Session() as sess:
        mon_sess = monitored_session._HookedSession(sess, [h])
        sess.run(state_ops.assign(global_step, 5))
        h.after_create_session(sess, None)
        mon_sess.run(no_op)
        self.assertFalse(mon_sess.should_stop())
        # (global step to set, whether a stop is expected after one run)
        for step, stop_expected in ((13, False), (14, False), (15, True)):
          sess.run(state_ops.assign(global_step, step))
          mon_sess.run(no_op)
          check = self.assertTrue if stop_expected else self.assertFalse
          check(mon_sess.should_stop())
        # Reset the private stop flag so the final assertion reflects a fresh
        # stop request for a step past the limit.
        sess.run(state_ops.assign(global_step, 16))
        mon_sess._should_stop = False
        mon_sess.run(no_op)
        self.assertTrue(mon_sess.should_stop())
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:26,代码来源:basic_session_run_hooks_test.py


示例12: test_multiple_summaries

  def test_multiple_summaries(self):
    """SummarySaverHook accepts a list of summary ops and saves all of them.

    With save_steps=8 and 10 runs, the test expects saves at steps 1 and 9,
    each containing both 'my_summary' and 'my_summary2'. The expected values
    depend on self.summary_op / self.summary_op2, which are set up outside
    this method.
    """
    hook = basic_session_run_hooks.SummarySaverHook(
        save_steps=8,
        summary_writer=self.summary_writer,
        summary_op=[self.summary_op, self.summary_op2])

    with self.test_session() as sess:
      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(10):
        mon_sess.run(self.train_op)
      hook.end(sess)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {
                'my_summary': 1.0,
                'my_summary2': 2.0
            },
            9: {
                'my_summary': 2.0,
                'my_summary2': 4.0
            },
        })
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:27,代码来源:basic_session_run_hooks_test.py


示例13: testDumpingDebugHookWithStatefulLegacyWatchFnWorks

  def testDumpingDebugHookWithStatefulLegacyWatchFnWorks(self):
    """DumpingDebugHook honors a stateful, tuple-returning watch_fn.

    The watch function alternates between watching everything and watching
    nothing. Over 4 runs the test expects 4 run_* dump directories, where only
    the even-indexed (0-based) ones contain tensor dumps.
    """
    # Mutable dict so the nested function can update the counter.
    watch_fn_state = {"run_counter": 0}

    def counting_watch_fn(fetches, feed_dict):
      # Returns a legacy tuple, presumably (debug_ops, node-name regex,
      # op-type regex) -- confirm against hooks.DumpingDebugHook.
      del fetches, feed_dict
      watch_fn_state["run_counter"] += 1
      if watch_fn_state["run_counter"] % 2 == 1:
        # If odd-index run (1-based), watch everything.
        return "DebugIdentity", r".*", r".*"
      else:
        # If even-index run, watch nothing: r"$^" cannot match a non-empty
        # name.
        return "DebugIdentity", r"$^", r"$^"

    dumping_hook = hooks.DumpingDebugHook(
        self.session_root, watch_fn=counting_watch_fn, log_usage=False)
    mon_sess = monitored_session._HookedSession(self.sess, [dumping_hook])
    for _ in range(4):
      mon_sess.run(self.inc_v)

    # Sort numerically by the index after "run_" so runs are in order.
    dump_dirs = glob.glob(os.path.join(self.session_root, "run_*"))
    dump_dirs = sorted(
        dump_dirs, key=lambda x: int(os.path.basename(x).split("_")[1]))
    self.assertEqual(4, len(dump_dirs))

    for i, dump_dir in enumerate(dump_dirs):
      self._assert_correct_run_subdir_naming(os.path.basename(dump_dir))
      dump = debug_data.DebugDumpDir(dump_dir)
      if i % 2 == 0:
        # Assumes self.inc_v increments v by 1.0 from 10.0 (set up outside
        # this method) and that the dump captures v before the increment.
        self.assertAllClose([10.0 + 1.0 * i],
                            dump.get_tensors("v", 0, "DebugIdentity"))
      else:
        self.assertEqual(0, dump.size)

      self.assertEqual(repr(self.inc_v), dump.run_fetches_info)
      self.assertEqual(repr(None), dump.run_feed_keys_info)
开发者ID:aritratony,项目名称:tensorflow,代码行数:35,代码来源:dumping_wrapper_test.py


示例14: DISABLED_test_save_steps_saves_periodically

 def DISABLED_test_save_steps_saves_periodically(self):
   """CheckpointSaverHook(save_steps=2) checkpoints every other step.

   The DISABLED_ prefix keeps the default unittest loader (which collects
   methods starting with 'test') from running it. After runs 2 through 5,
   the checkpointed global step is expected to be 1, 3, 3, 5.
   """
   with self.graph.as_default():
     hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir, save_steps=2, scaffold=self.scaffold)
     hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       mon_sess = monitored_session._HookedSession(sess, [hook])
       mon_sess.run(self.train_op)
       # The checkpointed value only advances on runs where the hook saves
       # (runs 3 and 5); runs 2 and 4 leave the previous checkpoint in place.
       for checkpointed_step in (1, 3, 3, 5):
         mon_sess.run(self.train_op)
         self.assertEqual(checkpointed_step,
                          checkpoint_utils.load_variable(
                              self.model_dir, self.global_step.name))
开发者ID:1000sprites,项目名称:tensorflow,代码行数:30,代码来源:basic_session_run_hooks_test.py


示例15: DISABLED_test_save_secs_calls_listeners_periodically

 def DISABLED_test_save_secs_calls_listeners_periodically(self):
   """CheckpointSaverHook(save_secs=2) notifies its listener on every save.

   The DISABLED_ prefix keeps the default unittest loader (which collects
   methods starting with 'test') from running it. Relies on real 2.5s sleeps.
   Expects four saves: three during runs and one from end().
   """
   with self.graph.as_default():
     listener = MockCheckpointSaverListener()
     hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir,
         save_secs=2,
         scaffold=self.scaffold,
         listeners=[listener])
     hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       mon_sess = monitored_session._HookedSession(sess, [hook])
       mon_sess.run(self.train_op)  # hook runs here
       mon_sess.run(self.train_op)
       # Sleep past save_secs so the next run saves again.
       time.sleep(2.5)
       mon_sess.run(self.train_op)  # hook runs here
       mon_sess.run(self.train_op)
       mon_sess.run(self.train_op)
       time.sleep(2.5)
       mon_sess.run(self.train_op)  # hook runs here
       mon_sess.run(self.train_op)  # hook won't run here, so it does at end
       hook.end(sess)  # hook runs here
     self.assertEqual({
         'begin': 1,
         'before_save': 4,
         'after_save': 4,
         'end': 1
     }, listener.get_counts())
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:29,代码来源:basic_session_run_hooks_test.py


示例16: testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks

  def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self):
    """TensorBoardDebugHook(send_traceback_and_source_code=False) sends none.

    Runs a watched computation through the hook and checks that the debug
    server has received neither op tracebacks nor source files.
    """
    u = variables.Variable(2.1, name="u")
    v = variables.Variable(20.0, name="v")
    w = math_ops.multiply(u, v, name="w")

    raw_sess = session.Session(config=no_rewrite_session_config())
    # Nothing else in this test closes the underlying session (the name
    # `sess` is rebound to the hooked wrapper below), so release it when the
    # test finishes.
    self.addCleanup(raw_sess.close)
    raw_sess.run(variables.global_variables_initializer())

    grpc_debug_hook = hooks.TensorBoardDebugHook(
        ["localhost:%d" % self._server_port],
        send_traceback_and_source_code=False)
    sess = monitored_session._HookedSession(raw_sess, [grpc_debug_hook])

    # Activate watch point on a tensor before calling sess.run().
    self._server.request_watch("u/read", 0, "DebugIdentity")
    self.assertAllClose(42.0, sess.run(w))  # 2.1 * 20.0

    # Check that the server has _not_ received any tracebacks, as a result of
    # the disabling above.
    with self.assertRaisesRegexp(
        ValueError, r"Op .*u/read.* does not exist"):
      self.assertTrue(self._server.query_op_traceback("u/read"))
    with self.assertRaisesRegexp(
        ValueError, r".* has not received any source file"):
      self._server.query_source_file_line(__file__, 1)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:25,代码来源:session_debug_grpc_test.py


示例17: test_save_secs_saving_once_every_three_steps

  def test_save_secs_saving_once_every_three_steps(self):
    """SummarySaverHook(save_secs=0.9) with 0.3s real sleeps between runs.

    Expects summaries at steps 1, 4 and 7. Depends on wall-clock timing.
    """
    hook = basic_session_run_hooks.SummarySaverHook(
        save_secs=0.9,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(variables_lib.global_variables_initializer())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(8):
        mon_sess.run(self.train_op)
        # Three 0.3s sleeps add up to roughly one save_secs interval.
        time.sleep(0.3)
      hook.end(sess)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {
                'my_summary': 1.0
            },
            4: {
                'my_summary': 2.0
            },
            7: {
                'my_summary': 3.0
            },
        })
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:29,代码来源:basic_session_run_hooks_test.py


示例18: testBothHooksAndUserHaveFeeds

  def testBothHooksAndUserHaveFeeds(self):
    """Feeds requested by several hooks are merged with the user's feed_dict.

    Each hook feeds a different tensor and the caller feeds a third, so the
    sum is 5 + 10 + 20 = 35. The caller's feed_dict must not be mutated.
    """
    with tf.Graph().as_default(), tf.Session() as sess:
      mock_hook = FakeHook()
      mock_hook2 = FakeHook()
      mon_sess = monitored_session._HookedSession(
          sess=sess, hooks=[mock_hook, mock_hook2])
      a_tensor = tf.constant([0], name='a_tensor')
      b_tensor = tf.constant([0], name='b_tensor')
      c_tensor = tf.constant([0], name='c_tensor')
      add_tensor = a_tensor + b_tensor + c_tensor
      # Each fake hook requests a feed for a different tensor.
      mock_hook.request = tf.train.SessionRunArgs(
          None, feed_dict={
              a_tensor: [5]
          })
      mock_hook2.request = tf.train.SessionRunArgs(
          None, feed_dict={
              b_tensor: [10]
          })
      sess.run(tf.initialize_all_variables())

      feed_dict = {c_tensor: [20]}
      self.assertEqual(
          mon_sess.run(fetches=add_tensor, feed_dict=feed_dict), [35])
      # User feed_dict should not be changed
      self.assertEqual(len(feed_dict), 1)
开发者ID:KalraA,项目名称:tensorflow,代码行数:25,代码来源:monitored_session_test.py


示例19: test_save_secs_saving_once_every_step

  def test_save_secs_saving_once_every_step(self):
    """SummarySaverHook(save_secs=0.5) with 0.5s sleeps saves on every step.

    Step-based saving is disabled (save_steps=None). Depends on wall-clock
    timing.
    """
    hook = tf.train.SummarySaverHook(
        save_steps=None,
        save_secs=0.5,
        summary_writer=self.summary_writer,
        summary_op=self.summary_op)

    with self.test_session() as sess:
      hook.begin()
      sess.run(tf.initialize_all_variables())
      mon_sess = monitored_session._HookedSession(sess, [hook])
      for _ in range(4):
        mon_sess.run(self.train_op)
        # Sleep a full save_secs interval so every run is due to save.
        time.sleep(0.5)
      hook.end(sess)

    self.summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_summaries={
            1: {'my_summary': 1.0},
            2: {'my_summary': 2.0},
            3: {'my_summary': 3.0},
            4: {'my_summary': 4.0},
        })
开发者ID:MrCrumpets,项目名称:tensorflow,代码行数:25,代码来源:basic_session_run_hooks_test.py


示例20: _validate_log_every_n_steps

  def _validate_log_every_n_steps(self, sess, every_n_steps, warm_steps):
    """Checks when ExamplesPerSecondHook emits its 'exp/sec' log message.

    Args:
      sess: session wrapped by the hooked session; also used to read
        self.global_step.
      every_n_steps: logging period passed to the hook.
      warm_steps: passed to the hook; on the checked runs, logging is expected
        only once the global step is greater than this value.
    """
    hook = hooks.ExamplesPerSecondHook(
        batch_size=256,
        every_n_steps=every_n_steps,
        warm_steps=warm_steps)
    hook.begin()
    mon_sess = monitored_session._HookedSession(sess, [hook])  # pylint: disable=protected-access
    sess.run(tf.global_variables_initializer())

    # The first every_n_steps runs are expected to log nothing.
    self.logged_message = ''
    for _ in range(every_n_steps):
      mon_sess.run(self.train_op)
      self.assertEqual(str(self.logged_message).find('exp/sec'), -1)

    # After one more run, 'exp/sec' is expected only past warm_steps.
    # Absence is checked with str.find() because assertNotRegexpMatches is
    # not available on every supported Python version.
    mon_sess.run(self.train_op)
    global_step_val = sess.run(self.global_step)
    if global_step_val > warm_steps:
      self.assertRegexpMatches(str(self.logged_message), 'exp/sec')
    else:
      self.assertEqual(str(self.logged_message).find('exp/sec'), -1)

    # Add additional run to verify proper reset when called multiple times.
    # Logging is expected again only when logging every step and past
    # warm_steps.
    self.logged_message = ''
    mon_sess.run(self.train_op)
    global_step_val = sess.run(self.global_step)
    if every_n_steps == 1 and global_step_val > warm_steps:
      self.assertRegexpMatches(str(self.logged_message), 'exp/sec')
    else:
      self.assertEqual(str(self.logged_message).find('exp/sec'), -1)

    hook.end(sess)
开发者ID:Yogurtla,项目名称:models,代码行数:32,代码来源:hooks_test.py



注:本文中的tensorflow.python.training.monitored_session._HookedSession函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap