本文整理汇总了Python中tensorflow.contrib.eager.python.network.save_network_checkpoint函数的典型用法代码示例。如果您正苦于以下问题：Python save_network_checkpoint函数的具体用法？Python save_network_checkpoint怎么用？Python save_network_checkpoint使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了save_network_checkpoint函数的9个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testNetworkSaveRestoreAlreadyBuilt
def testNetworkSaveRestoreAlreadyBuilt(self):
  """Saving must fail before the first call, then round-trip once built."""
  built_net = MyNetwork(name="abcd")
  # Saving a Network that has never been called is expected to be rejected.
  with self.assertRaisesRegexp(
      ValueError, "Attempt to save the Network before it was first called"):
    network.save_network_checkpoint(built_net, self.get_temp_dir())
  built_net(constant_op.constant([[2.0]]))
  first_kernel = built_net.trainable_variables[0]
  self.evaluate(first_kernel.assign([[17.0]]))
  # Exercise the save/modify/restore cycle without and with a global step.
  for step in (None, 10):
    self._save_modify_load_network_built(built_net, global_step=step)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:9,代码来源:network_test.py
示例2: testRestoreIntoSubNetwork
def testRestoreIntoSubNetwork(self):
  """Restores whole-model and sub-Network checkpoints into a Parent Network.

  Covers deferred restoration (restoring before the first call), the rule
  that later restorations overwrite earlier ones, immediate restoration into
  an already-built sub-Network, and the NotFoundError raised when a
  checkpoint does not match the target Network.
  """
  class Parent(network.Network):
    # Tracks two MyNetwork children; call() applies `second`, then `first`.

    def __init__(self, name=None):
      super(Parent, self).__init__(name=name)
      self.first = self.track_layer(MyNetwork())
      self.second = self.track_layer(MyNetwork())

    def call(self, x):
      return self.first(self.second(x))

  one = constant_op.constant([[3.]])
  # Whole-model checkpoint: Parent variables set to 15 and 16.
  whole_model_saver = Parent()
  whole_model_saver(one)
  self.evaluate(whole_model_saver.variables[0].assign([[15.]]))
  self.evaluate(whole_model_saver.variables[1].assign([[16.]]))
  whole_model_checkpoint = network.save_network_checkpoint(
      whole_model_saver, self.get_temp_dir())
  # Standalone MyNetwork checkpoint: its single variable set to 5.
  save_from = MyNetwork()
  save_from(one)
  self.evaluate(save_from.variables[0].assign([[5.]]))
  checkpoint = network.save_network_checkpoint(save_from, self.get_temp_dir())
  # Unbuilt Parent: restore the whole model, then the `first` sub-Network.
  save_into_parent = Parent()
  network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
  network.restore_network_checkpoint(save_into_parent.first, checkpoint)
  # deferred loading multiple times is fine
  network.restore_network_checkpoint(save_into_parent.first, checkpoint)
  save_into_parent(one)  # deferred loading
  # The later sub-Network restore wins for variables[0]; variables[1] keeps
  # the whole-model value.
  self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0]))
  self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))
  # Try again with the opposite ordering, and we should get different results
  # (deferred restoration should happen the same way non-deferred happens,
  # with later restorations overwriting older ones).
  save_into_parent = Parent()
  # Restore the sub-Network first, then the whole model on top of it.
  network.restore_network_checkpoint(save_into_parent.first, checkpoint)
  network.restore_network_checkpoint(save_into_parent, whole_model_checkpoint)
  save_into_parent(one)  # deferred loading
  # We've overwritten the sub-Network restore.
  self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0]))
  self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1]))
  # Immediate (non-deferred) restore into the already-built `second` child.
  self.evaluate(save_into_parent.variables[0].assign([[3.]]))
  self.evaluate(save_into_parent.variables[1].assign([[4.]]))
  network.restore_network_checkpoint(save_into_parent.second, checkpoint)
  self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1]))
  with self.assertRaisesRegexp(errors_impl.NotFoundError,
                               "not found in checkpoint"):
    # The checkpoint is incompatible.
    network.restore_network_checkpoint(save_into_parent, checkpoint)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:53,代码来源:network_test.py
示例3: testDefaultMapCollisionErrors
def testDefaultMapCollisionErrors(self):
  """The default name mapping must report naming conflicts.

  Parent tracks an externally created Dense layer named "dense" alongside
  its own Dense layer. The test expects the default checkpoint name mapping
  to raise a "naming conflict" ValueError for Parent both when saving and
  when restoring. A Network without that layout (Compatible) still saves
  successfully.
  """
  one = constant_op.constant([[1.]])
  first = core.Dense(1, name="dense", use_bias=False)
  first(one)

  class Parent(network.Network):
    # Shares the pre-built `first` layer and adds a Dense layer of its own.

    def __init__(self, name=None):
      super(Parent, self).__init__(name=name)
      self.first = self.track_layer(first)
      self.second = self.track_layer(core.Dense(1, use_bias=False))

    def call(self, x):
      return self.first(self.second(x))

  # `one` from above is reused; no need to build another identical constant.
  make_checkpoint = Parent()
  make_checkpoint(one)
  self.evaluate(make_checkpoint.variables[0].assign([[2.]]))
  self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
  with self.assertRaisesRegexp(
      ValueError,
      ("The default checkpoint variable name mapping strategy for Network "
       "'parent' resulted in a naming conflict.")):
    network.save_network_checkpoint(make_checkpoint, self.get_temp_dir())

  class Compatible(network.Network):
    # A single tracked Dense layer, so no conflict is expected on save.

    def __init__(self, name=None):
      super(Compatible, self).__init__(name=name)
      self.first = self.track_layer(core.Dense(1, use_bias=False))

    def call(self, x):
      return self.first(x)

  successful_checkpoint = Compatible()
  successful_checkpoint(one)
  self.evaluate(successful_checkpoint.variables[0].assign([[-1.]]))
  checkpoint_path = network.save_network_checkpoint(
      successful_checkpoint, self.get_temp_dir())
  # Restoring into a (second) Parent hits the same conflict, now on load.
  load_checkpoint = Parent()
  load_checkpoint(one)
  with self.assertRaisesRegexp(
      ValueError,
      ("The default checkpoint variable name mapping strategy for Network "
       "'parent_1' resulted in a naming conflict.")):
    network.restore_network_checkpoint(load_checkpoint, checkpoint_path)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:48,代码来源:network_test.py
示例4: testSaveRestoreDefaultGlobalStep
def testSaveRestoreDefaultGlobalStep(self):
  """Without an explicit step, the default global step names the checkpoint."""
  abcd_net = MyNetwork(name="abcd")
  abcd_net(constant_op.constant([[2.0]]))
  kernel = abcd_net.variables[0]
  self.evaluate(kernel.assign([[3.]]))
  step_var = training_util.get_or_create_global_step()
  self.evaluate(step_var.assign(4242))
  # No global_step argument: the path should carry the default step's value.
  written_path = network.save_network_checkpoint(abcd_net, self.get_temp_dir())
  self.assertIn("abcd-4242", written_path)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:8,代码来源:network_test.py
示例5: testCustomMapCollisionErrors
def testCustomMapCollisionErrors(self):
  """A map_func that maps several variables to one name must be rejected.

  Checks the ValueError on save, on deferred restore (raised when the target
  Network is first called), and on immediate restore into a built Network.
  """
  class Parent(network.Network):
    # Tracks two MyNetwork children; call() applies `second`, then `first`.

    def __init__(self, name=None):
      super(Parent, self).__init__(name=name)
      self.first = self.track_layer(MyNetwork())
      self.second = self.track_layer(MyNetwork())

    def call(self, x):
      return self.first(self.second(x))

  make_checkpoint = Parent()
  one = constant_op.constant([[1.]])
  make_checkpoint(one)
  self.evaluate(make_checkpoint.variables[0].assign([[2.]]))
  self.evaluate(make_checkpoint.variables[1].assign([[3.]]))
  # Mapping both of Parent's variables to "foo" collides on save.
  with self.assertRaisesRegexp(
      ValueError,
      "The map_func passed to save_network_checkpoint for the Network "
      "'parent' resulted in two variables named 'foo'"):
    network.save_network_checkpoint(
        make_checkpoint, self.get_temp_dir(), map_func=lambda n: "foo")
  # Saving only the `first` sub-Network with the same map_func succeeds.
  checkpoint = network.save_network_checkpoint(
      network=make_checkpoint.first,
      save_path=self.get_temp_dir(),
      map_func=lambda n: "foo")
  # Deferred restore into an unbuilt Parent: no error until it is called.
  loader = Parent()
  network.restore_network_checkpoint(
      loader, checkpoint, map_func=lambda n: "foo")
  with self.assertRaisesRegexp(
      ValueError,
      ("The map_func passed to restore_network_checkpoint for the Network"
       " 'parent_1' resulted in two variables named 'foo'")):
    loader(one)
  # Immediate restore into an already-built Parent raises during the restore.
  loader = Parent()
  loader(one)
  with self.assertRaisesRegexp(
      ValueError,
      ("The map_func passed to restore_network_checkpoint for the Network"
       " 'parent_2' resulted in two variables named 'foo'")):
    network.restore_network_checkpoint(
        loader, checkpoint, map_func=lambda n: "foo")
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:43,代码来源:network_test.py
示例6: testNetworkSaveAndRestoreIntoUnbuilt
def testNetworkSaveAndRestoreIntoUnbuilt(self):
  """A checkpoint restored into a never-called Network takes effect on use."""
  save_dir = self.get_temp_dir()
  source_net = MyNetwork()
  test_input = constant_op.constant([[2.0]])
  source_net(test_input)
  self.evaluate(source_net.trainable_variables[0].assign([[17.0]]))
  save_path = network.save_network_checkpoint(source_net, save_dir)
  # With a pre-build restore we should have the same value.
  target_net = MyNetwork()
  network.restore_network_checkpoint(target_net, save_path)
  expected_output = self.evaluate(source_net(test_input))
  restored_output = self.evaluate(target_net(test_input))
  self.assertAllEqual(expected_output, restored_output)
  # Distinct variable objects that nonetheless hold the same value.
  source_kernel = source_net.variables[0]
  target_kernel = target_net.variables[0]
  self.assertIsNot(source_kernel, target_kernel)
  self.assertAllEqual(self.evaluate(source_kernel),
                      self.evaluate(target_kernel))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:15,代码来源:network_test.py
示例7: _save_modify_load_network_built
def _save_modify_load_network_built(self, net, global_step=None):
  """Saves `net`, perturbs its variables, and checks both restore routes.

  Args:
    net: a Network that has already been called (built).
    global_step: forwarded unchanged to `network.save_network_checkpoint`.
  """
  checkpoint_directory = self.get_temp_dir()
  checkpoint_path = network.save_network_checkpoint(
      network=net, save_path=checkpoint_directory, global_step=global_step)
  input_value = constant_op.constant([[42.0]])
  original_output = self.evaluate(net(input_value))

  def shift_all_variables(delta):
    # Move every variable of `net` away from its saved value by `delta`.
    for variable in net.variables:
      self.evaluate(variable.assign(variable + delta))

  shift_all_variables(1.)
  self.assertGreater(self.evaluate(net(input_value)), original_output)
  # Either the returned explicit checkpoint path or the directory should work.
  network.restore_network_checkpoint(net, save_path=checkpoint_directory)
  self.assertAllEqual(original_output, self.evaluate(net(input_value)))
  shift_all_variables(2.)
  network.restore_network_checkpoint(net, save_path=checkpoint_path)
  self.assertAllEqual(original_output, self.evaluate(net(input_value)))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:22,代码来源:network_test.py
示例8: testVariableScopeStripping
def testVariableScopeStripping(self):
  """A Network built inside variable scopes restores into an unscoped one."""
  # Two context managers in one `with` behave exactly like nested `with`s.
  with variable_scope.variable_scope("scope1"), \
      variable_scope.variable_scope("scope2"):
    scoped_net = MyNetwork()
  scoped_net(constant_op.constant([[2.0]]))
  self.evaluate(scoped_net.variables[0].assign([[42.]]))
  self.assertEqual(scoped_net.name, "scope1/scope2/my_network")
  self.assertStartsWith(
      expected_start="scope1/scope2/my_network/dense/",
      actual=scoped_net.trainable_weights[0].name)
  save_path = network.save_network_checkpoint(scoped_net, self.get_temp_dir())
  self.assertIn("scope1_scope2_my_network", save_path)
  fresh_net = MyNetwork()
  # Delayed restoration: values are expected once fresh_net is first called.
  network.restore_network_checkpoint(fresh_net, save_path)
  fresh_net(constant_op.constant([[1.0]]))
  self.assertAllEqual([[42.]], self.evaluate(fresh_net.variables[0]))
  self.evaluate(fresh_net.variables[0].assign([[-1.]]))
  # Immediate restoration into the now-built Network.
  network.restore_network_checkpoint(fresh_net, save_path)
  self.assertAllEqual([[42.]], self.evaluate(fresh_net.variables[0]))
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:23,代码来源:network_test.py
示例9: testLoadIntoUnbuiltSharedLayer
def testLoadIntoUnbuiltSharedLayer(self):
  """Restores into an unbuilt layer that another Network owns and shares.

  The shared layer is created by an Owner and reused by a User. A checkpoint
  is restored into the User before the shared layer is built, and the test
  checks the values once the User is called. A second scenario expects a
  ValueError mentioning "garbage collected" when the owning Network has been
  deleted before the restore.
  """
  class Owner(network.Network):
    # Creates and tracks a single bias-free Dense layer named "first_layer".

    def __init__(self, name=None):
      super(Owner, self).__init__(name=name)
      self.first = self.track_layer(core.Dense(
          1, name="first_layer", use_bias=False))

    def call(self, x):
      return self.first(x)

  first_owner = Owner()

  class User(network.Network):
    # Tracks an externally supplied layer plus its own "second_layer".

    def __init__(self, use_layer, name=None):
      super(User, self).__init__(name=name)
      self.first = self.track_layer(use_layer)
      self.second = self.track_layer(core.Dense(
          1, name="second_layer", use_bias=False))

    def call(self, x):
      return self.second(self.first(x))

  class LikeUserButNotSharing(network.Network):
    # Same layout as User, but creates both layers itself.

    def __init__(self, name=None):
      super(LikeUserButNotSharing, self).__init__(name=name)
      self.first = self.track_layer(core.Dense(
          1, name="first_layer", use_bias=False))
      self.second = self.track_layer(core.Dense(
          1, name="second_layer", use_bias=False))

    def call(self, x):
      return self.second(self.first(x))

  checkpoint_creator = LikeUserButNotSharing(name="checkpoint_creator")
  one = constant_op.constant([[1.0]])
  checkpoint_creator(one)
  self.assertEqual(2, len(checkpoint_creator.variables))
  self.evaluate(checkpoint_creator.variables[0].assign([[5.]]))
  self.evaluate(checkpoint_creator.variables[1].assign([[6.]]))
  # Re-map the variable names so that with default restore mapping we'll
  # attempt to restore into the unbuilt Layer.
  name_mapping = {
      "checkpoint_creator/first_layer/kernel": "owner/first_layer/kernel",
      "checkpoint_creator/second_layer/kernel": "second_layer/kernel",
  }
  save_path = network.save_network_checkpoint(
      checkpoint_creator,
      self.get_temp_dir(),
      map_func=lambda full_name: name_mapping[full_name])
  load_into = User(use_layer=first_owner.first)
  network.restore_network_checkpoint(load_into, save_path)
  # The shared layer has not been built yet, so the Owner has no variables.
  self.assertEqual(0, len(first_owner.variables))
  self.assertAllEqual(self.evaluate(checkpoint_creator(one)),
                      self.evaluate(load_into(one)))
  # Calling load_into built the shared layer, which the Owner also tracks.
  self.assertEqual(1, len(first_owner.variables))
  self.assertAllEqual([[5.]], self.evaluate(load_into.variables[0]))
  self.assertAllEqual([[6.]], self.evaluate(load_into.variables[1]))
  first_owner(one)
  self.assertAllEqual([[5.]], self.evaluate(first_owner.variables[0]))
  # Try again with a garbage collected parent.
  first_owner = Owner()
  load_into = User(use_layer=first_owner.first)
  del first_owner
  gc.collect()

  def _restore_map_func(original_name):
    # Send "owner/" names to the second Owner's scope ("owner_1/"); prefix
    # every other name with "user_1/".
    if original_name.startswith("owner/"):
      return original_name.replace("owner/", "owner_1/")
    else:
      return "user_1/" + original_name

  with self.assertRaisesRegexp(ValueError, "garbage collected"):
    network.restore_network_checkpoint(
        load_into, save_path, map_func=_restore_map_func)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:77,代码来源:network_test.py
注：本文中的tensorflow.contrib.eager.python.network.save_network_checkpoint函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论