本文整理汇总了Python中tensorflow.contrib.learn.python.learn.monitored_session._HookedSession类的典型用法代码示例。如果您正苦于以下问题：Python _HookedSession类的具体用法？Python _HookedSession怎么用？Python _HookedSession使用的例子？那么恭喜您，这里精选的代码示例或许可以为您提供帮助。
在下文中一共展示了_HookedSession类的14个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: test_summary_saver
def test_summary_saver(self):
  """SummarySaverHook(save_steps=8) writes `my_summary` at steps 1, 9, 17, 25."""
  with tf.Graph().as_default() as graph, tf.Session() as session:
    log_dir = 'log/dir'
    writer = testing.FakeSummaryWriter(log_dir, graph)
    counter = tf.Variable(0.0)
    increment = tf.assign_add(counter, 1.0)
    summary_op = tf.scalar_summary('my_summary', increment)
    step_var = tf.contrib.framework.get_or_create_global_step()
    train_op = tf.assign_add(step_var, 1)
    hook = basic_session_run_hooks.SummarySaverHook(
        summary_op=summary_op, save_steps=8, summary_writer=writer)
    hook.begin()
    session.run(tf.initialize_all_variables())
    hooked = monitored_session._HookedSession(session, [hook])
    for _ in range(30):
      hooked.run(train_op)
    hook.end(session)
    # The expected values 1.0..4.0 imply the summary op (and therefore its
    # assign_add) is only evaluated on the save steps.
    expected = {
        save_step: {'my_summary': float(count)}
        for count, save_step in enumerate((1, 9, 17, 25), start=1)
    }
    writer.assert_summaries(
        test_case=self,
        expected_logdir=log_dir,
        expected_graph=graph,
        expected_summaries=expected)
开发者ID:AriaAsuka,项目名称:tensorflow,代码行数:28,代码来源:basic_session_run_hooks_test.py
示例2: test_save_steps_saves_periodically
def test_save_steps_saves_periodically(self):
  """CheckpointSaverHook(save_steps=2) checkpoints on every other step."""
  with self.graph.as_default():
    hook = basic_session_run_hooks.CheckpointSaverHook(
        self.model_dir, save_steps=2, scaffold=self.scaffold)
    hook.begin()
    self.scaffold.finalize()
    with tf.Session() as session:
      session.run(self.scaffold.init_op)
      hooked = monitored_session._HookedSession(session, [hook])
      hooked.run(self.train_op)
      # Each entry is the global step stored in the latest checkpoint after
      # one more train step; a repeated value means that step did not save.
      for checkpointed_step in (1, 3, 3, 5):
        hooked.run(self.train_op)
        restored = tf.contrib.framework.load_variable(
            self.model_dir, self.global_step.name)
        self.assertEqual(checkpointed_step, restored)
开发者ID:AriaAsuka,项目名称:tensorflow,代码行数:26,代码来源:basic_session_run_hooks_test.py
示例3: testShouldStop
def testShouldStop(self):
  """should_stop() turns True once one of the hooks requests a stop."""
  with tf.Graph().as_default(), tf.Session() as sess:
    first_hook, second_hook = FakeHook(), FakeHook()
    hooked = monitored_session._HookedSession(
        sess=sess, hooks=[first_hook, second_hook])
    tf.constant([0], name='a_tensor')
    sess.run(tf.initialize_all_variables())
    hooked.run(fetches='a_tensor')
    self.assertFalse(hooked.should_stop())
    # Only the first hook asks to stop; that alone is enough.
    first_hook.should_stop = True
    hooked.run(fetches='a_tensor')
    self.assertTrue(hooked.should_stop())
开发者ID:JamesFysh,项目名称:tensorflow,代码行数:15,代码来源:monitored_session_test.py
示例4: testOnlyHooksHaveFeeds
def testOnlyHooksHaveFeeds(self):
  """Feeds requested by the hooks are applied when the caller feeds nothing."""
  with tf.Graph().as_default(), tf.Session() as sess:
    hook_a, hook_b = FakeHook(), FakeHook()
    hooked = monitored_session._HookedSession(
        sess=sess, hooks=[hook_a, hook_b])
    a = tf.constant([0], name='a_tensor')
    b = tf.constant([0], name='b_tensor')
    total = a + b
    hook_a.request = session_run_hook.SessionRunArgs(None, feed_dict={a: [5]})
    hook_b.request = session_run_hook.SessionRunArgs(None, feed_dict={b: [10]})
    sess.run(tf.initialize_all_variables())
    # 5 (fed by hook_a) + 10 (fed by hook_b).
    self.assertEqual(hooked.run(fetches=total), [15])
开发者ID:JamesFysh,项目名称:tensorflow,代码行数:16,代码来源:monitored_session_test.py
示例5: testRunPassesAllArguments
def testRunPassesAllArguments(self):
  """Every keyword given to run() reaches the wrapped session unchanged."""
  with tf.Graph().as_default(), tf.Session() as sess:
    fake_sess = FakeSession(sess)
    hooked = monitored_session._HookedSession(sess=fake_sess, hooks=[])
    a_tensor = tf.constant([0], name='a_tensor')
    sess.run(tf.initialize_all_variables())
    forwarded = {
        'feed_dict': 'a_feed',
        'options': 'an_option',
        'run_metadata': 'a_metadata',
    }
    output = hooked.run(fetches=a_tensor, **forwarded)
    self.assertEqual(output, [0])
    self.assertEqual(fake_sess.args_called, forwarded)
开发者ID:JamesFysh,项目名称:tensorflow,代码行数:16,代码来源:monitored_session_test.py
示例6: testHooksAndUserFeedConflicts
def testHooksAndUserFeedConflicts(self):
  """Feeding a tensor that a hook also feeds raises RuntimeError."""
  with tf.Graph().as_default(), tf.Session() as sess:
    hook_a, hook_b = FakeHook(), FakeHook()
    hooked = monitored_session._HookedSession(
        sess=sess, hooks=[hook_a, hook_b])
    a = tf.constant([0], name='a_tensor')
    b = tf.constant([0], name='b_tensor')
    total = a + b
    hook_a.request = session_run_hook.SessionRunArgs(None, feed_dict={a: [5]})
    hook_b.request = session_run_hook.SessionRunArgs(None, feed_dict={b: [10]})
    sess.run(tf.initialize_all_variables())
    # The caller's feed for `b` collides with hook_b's feed for `b`.
    with self.assertRaisesRegexp(RuntimeError, 'Same tensor is fed'):
      hooked.run(fetches=total, feed_dict={b: [10]})
开发者ID:JamesFysh,项目名称:tensorflow,代码行数:17,代码来源:monitored_session_test.py
示例7: testFetchesHookRequests
def testFetchesHookRequests(self):
  """Extra fetches requested by hooks are returned to each hook, not the caller."""
  with tf.Graph().as_default(), tf.Session() as sess:
    hook_a, hook_b = FakeHook(), FakeHook()
    hooked = monitored_session._HookedSession(
        sess=sess, hooks=[hook_a, hook_b])
    a_tensor = tf.constant([0], name='a_tensor')
    extra_fetches = [
        tf.constant([5], name='another_tensor'),
        tf.constant([10], name='third_tensor'),
    ]
    for hook, tensor in zip((hook_a, hook_b), extra_fetches):
      hook.request = session_run_hook.SessionRunArgs([tensor])
    sess.run(tf.initialize_all_variables())
    # The caller only sees its own fetch ...
    self.assertEqual(hooked.run(fetches=a_tensor), [0])
    # ... while each hook receives the value it asked for.
    for hook, expected in zip((hook_a, hook_b), ([5], [10])):
      self.assertEqual(hook.last_run_values.results, expected)
开发者ID:JamesFysh,项目名称:tensorflow,代码行数:17,代码来源:monitored_session_test.py
示例8: testCallsHooksBeginEnd
def testCallsHooksBeginEnd(self):
  """One run() call triggers exactly one before_run/after_run on every hook."""
  with tf.Graph().as_default(), tf.Session() as sess:
    hook_a, hook_b = FakeHook(), FakeHook()
    hooked = monitored_session._HookedSession(
        sess=sess, hooks=[hook_a, hook_b])
    a_tensor = tf.constant([0], name='a_tensor')
    sess.run(tf.initialize_all_variables())
    hooked.run(a_tensor)
    expected_values = session_run_hook.SessionRunValues(results=None)
    expected_args = session_run_hook.SessionRunArgs(a_tensor)
    for hook in (hook_a, hook_b):
      self.assertEqual(hook.last_run_values, expected_values)
      context = hook.last_run_context
      self.assertEqual(context.original_args, expected_args)
      self.assertEqual(context.session, sess)
      self.assertEqual(hook.call_counter['before_run'], 1)
      self.assertEqual(hook.call_counter['after_run'], 1)
开发者ID:JamesFysh,项目名称:tensorflow,代码行数:19,代码来源:monitored_session_test.py
示例9: testBothHooksAndUserHaveFeeds
def testBothHooksAndUserHaveFeeds(self):
  """Hook feeds and caller feeds are combined; the caller's dict is untouched."""
  with tf.Graph().as_default(), tf.Session() as sess:
    hook_a, hook_b = FakeHook(), FakeHook()
    hooked = monitored_session._HookedSession(
        sess=sess, hooks=[hook_a, hook_b])
    a = tf.constant([0], name='a_tensor')
    b = tf.constant([0], name='b_tensor')
    c = tf.constant([0], name='c_tensor')
    total = a + b + c
    hook_a.request = session_run_hook.SessionRunArgs(None, feed_dict={a: [5]})
    hook_b.request = session_run_hook.SessionRunArgs(None, feed_dict={b: [10]})
    sess.run(tf.initialize_all_variables())
    user_feed = {c: [20]}
    # 5 + 10 from the hooks, 20 from the caller.
    self.assertEqual(hooked.run(fetches=total, feed_dict=user_feed), [35])
    # Hook feeds must not be merged into the caller's dict in place.
    self.assertEqual(len(user_feed), 1)
开发者ID:JamesFysh,项目名称:tensorflow,代码行数:21,代码来源:monitored_session_test.py
示例10: test_stop_based_on_last_step
def test_stop_based_on_last_step(self):
  """StopAtStepHook(last_step=10) requests a stop once global_step >= 10."""
  stop_hook = basic_session_run_hooks.StopAtStepHook(last_step=10)
  with tf.Graph().as_default():
    global_step = tf.contrib.framework.get_or_create_global_step()
    no_op = tf.no_op()
    stop_hook.begin()
    with tf.Session() as sess:
      hooked = monitored_session._HookedSession(sess, [stop_hook])
      for step_value, expect_stop in ((5, False), (9, False), (10, True)):
        sess.run(tf.assign(global_step, step_value))
        hooked.run(no_op)
        check = self.assertTrue if expect_stop else self.assertFalse
        check(hooked.should_stop())
      # Past the last step: clear the session's cached stop flag so the
      # assertion below reflects a fresh request from the hook.
      sess.run(tf.assign(global_step, 11))
      hooked._should_stop = False
      hooked.run(no_op)
      self.assertTrue(hooked.should_stop())
开发者ID:AriaAsuka,项目名称:tensorflow,代码行数:21,代码来源:basic_session_run_hooks_test.py
示例11: test_print
def test_print(self):
  """LoggingTensorHook(every_n_iter=10) logs on the first run, then every 10th."""
  with tf.Graph().as_default(), tf.Session() as sess:
    t = tf.constant(42.0, name='foo')
    train_op = tf.constant(3)
    hook = basic_session_run_hooks.LoggingTensorHook(tensors=[t.name],
                                                      every_n_iter=10)
    hook.begin()
    hooked = monitored_session._HookedSession(sess, [hook])
    sess.run(tf.initialize_all_variables())
    hooked.run(train_op)
    self.assertRegexpMatches(str(self.logged_message), t.name)
    for _ in range(3):
      self.logged_message = ''
      # The 9 runs between logging points must not log the tensor.
      for _ in range(9):
        hooked.run(train_op)
        # Uses find() instead of a negative-regex assertion, whose name
        # differs across Python versions (assertNotRegexpMatches vs.
        # assertNotRegex).
        self.assertEqual(str(self.logged_message).find(t.name), -1)
      # The 10th run logs it again.
      hooked.run(train_op)
      self.assertRegexpMatches(str(self.logged_message), t.name)
开发者ID:AriaAsuka,项目名称:tensorflow,代码行数:21,代码来源:basic_session_run_hooks_test.py
示例12: test_calls_and_steps
def test_calls_and_steps(self):
  """RunHookAdapterForMonitors relays begin/step/end callbacks to each monitor."""
  with tf.Graph().as_default(), tf.Session() as sess:
    global_step_tensor = tf.contrib.framework.create_global_step()
    inc_5 = tf.assign_add(global_step_tensor, 5)
    mon_a, mon_b = FakeMonitor(), FakeMonitor()
    hook = learn.monitors.RunHookAdapterForMonitors([mon_a, mon_b])
    hook.begin()
    for mon in (mon_a, mon_b):
      self.assertEqual(mon.call_counter['begin'], 1)
    sess.run(tf.initialize_all_variables())
    sess.run(global_step_tensor.assign(10))
    hooked = monitored_session._HookedSession(sess=sess, hooks=[hook])
    # The global step starts at 10 and each run adds 5; the expected reported
    # steps (11, then 16) are the pre-run global step plus one.
    for run_count, expected_step in enumerate((11, 16), start=1):
      hooked.run(inc_5)
      for mon in (mon_a, mon_b):
        self.assertEqual(mon.output, {})
        self.assertEqual(mon.last_begin_step, expected_step)
        self.assertEqual(mon.last_end_step, expected_step)
        self.assertEqual(mon.last_post_step, expected_step)
        self.assertEqual(mon.call_counter['step_end'], run_count)
        self.assertEqual(mon.call_counter['step_begin'], run_count)
        self.assertEqual(mon.call_counter['post_step'], run_count)
    hook.end(sess)
    for mon in (mon_a, mon_b):
      self.assertEqual(mon.call_counter['end'], 1)
开发者ID:MostafaGazar,项目名称:tensorflow,代码行数:40,代码来源:monitors_test.py
示例13: test_requests
def test_requests(self):
  """Tensors that monitors request by name are fetched and handed back to them."""
  with tf.Graph().as_default(), tf.Session() as sess:
    tf.contrib.framework.create_global_step()
    mon_a, mon_b = FakeMonitor(), FakeMonitor()
    hook = learn.monitors.RunHookAdapterForMonitors([mon_a, mon_b])
    hook.begin()
    hooked = monitored_session._HookedSession(sess=sess, hooks=[hook])
    a_tensor = tf.constant([0], name='a_tensor')
    # (monitor, requested tensor name, value that tensor holds)
    requests = (
        (mon_a, 'another_tensor', 5),
        (mon_b, 'third_tensor', 10),
    )
    for _, name, value in requests:
      tf.constant([value], name=name)
    for mon, name, _ in requests:
      mon.requested_tensors = [name]
    sess.run(tf.initialize_all_variables())
    self.assertEqual(hooked.run(a_tensor), [0])
    for mon, name, value in requests:
      self.assertEqual(mon.output[name], [value])
开发者ID:MostafaGazar,项目名称:tensorflow,代码行数:22,代码来源:monitors_test.py
示例14: test_step_counter
def test_step_counter(self):
  """StepCounterHook(every_n_steps=10) writes `global_step/sec` summaries.

  Runs 30 steps and checks the summaries recorded at steps 11 and 21: the tag
  must be exactly 'global_step/sec' and the measured rate must exceed 10.
  """
  with tf.Graph().as_default() as g, tf.Session() as sess:
    global_step = tf.contrib.framework.get_or_create_global_step()
    train_op = tf.assign_add(global_step, 1)
    summary_writer = testing.FakeSummaryWriter(self.log_dir, g)
    hook = basic_session_run_hooks.StepCounterHook(
        summary_writer=summary_writer, every_n_steps=10)
    hook.begin()
    sess.run(tf.initialize_all_variables())
    mon_sess = monitored_session._HookedSession(sess, [hook])
    for _ in range(30):
      # 10 ms per step keeps the measured rate finite and at most ~100/sec.
      time.sleep(0.01)
      mon_sess.run(train_op)
    hook.end(sess)
    summary_writer.assert_summaries(
        test_case=self,
        expected_logdir=self.log_dir,
        expected_graph=g,
        expected_summaries={})
    for step in [11, 21]:
      summary_value = summary_writer.summaries[step][0].value[0]
      # assertTrue(tag, 'global_step/sec') would treat the expected tag as the
      # failure message and only check that the tag is non-empty.
      self.assertEqual(summary_value.tag, 'global_step/sec')
      # Check that at least 10 steps per second are recorded.
      self.assertGreater(summary_value.simple_value, 10)
开发者ID:AriaAsuka,项目名称:tensorflow,代码行数:24,代码来源:basic_session_run_hooks_test.py
注：本文中的tensorflow.contrib.learn.python.learn.monitored_session._HookedSession类示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论