本文整理汇总了Python中tensorflow.python.framework.ops.get_collection函数的典型用法代码示例。如果您正苦于以下问题:Python get_collection函数的具体用法?Python get_collection怎么用?Python get_collection使用的例子?那么恭喜您,这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了get_collection函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: __init__
def __init__(self,
             checkpoint_dir,
             display_steps=100,
             maximum_train_steps=None,
             do_summary=True,
             is_chief=True):
    """Store the hook configuration and collect the values to display.

    Args:
        checkpoint_dir: A string, base directory for the checkpoint files.
        display_steps: A python integer, display every N steps.
        maximum_train_steps: A python integer, the maximum training steps.
        do_summary: Whether to save summaries when displaying.
        is_chief: Whether this is the chief process (stored, not used yet).
    """
    tf.logging.info("Create DisplayHook.")
    self._checkpoint_dir = checkpoint_dir
    # Display cadence and related switches.
    self._display_steps = display_steps
    self._maximum_train_steps = maximum_train_steps
    self._do_summary = do_summary
    self._is_chief = is_chief  # not used now
    # Pair every registered display key with its registered value, then add
    # the global step under its own key.
    global_step = training_util.get_global_step()
    keys = ops.get_collection(Constants.DISPLAY_KEY_COLLECTION_NAME)
    values = ops.get_collection(Constants.DISPLAY_VALUE_COLLECTION_NAME)
    display_args = dict(zip(keys, values))
    display_args["global_step"] = global_step
    self._display_args = display_args
    # Timers and the summary writer start out unset.
    self._timer = None
    self._logging_timer = None
    self._summary_writer = None
开发者ID:KIngpon,项目名称:NJUNMT-tf,代码行数:34,代码来源:hooks.py
示例2: testAddWeight
def testAddWeight(self):
    """Check `_add_variable` bookkeeping for trainable, frozen and regularized variables."""
    with self.test_session():
        layer = base_layers._Layer(name="my_layer")

        # Basic variable creation.
        first_var = layer._add_variable(
            "my_var", [2, 2], initializer=init_ops.zeros_initializer)
        self.assertEqual(first_var.name, "my_var:0")
        self.assertListEqual(layer.variables, [first_var])
        self.assertListEqual(layer.trainable_variables, [first_var])
        self.assertListEqual(layer.non_trainable_variables, [])
        graph_trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
        self.assertListEqual(layer.variables, graph_trainables)

        # Non-trainable variable creation.
        # layer._add_variable should work even outside `build` and `call`.
        frozen_var = layer._add_variable(
            "non_trainable_var", [2, 2],
            initializer=init_ops.zeros_initializer,
            trainable=False)
        self.assertListEqual(layer.variables, [first_var, frozen_var])
        self.assertListEqual(layer.trainable_variables, [first_var])
        self.assertListEqual(layer.non_trainable_variables, [frozen_var])
        graph_trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
        self.assertEqual(len(graph_trainables), 1)

        # A variable created with a regularizer must register exactly one loss.
        regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
        layer._add_variable(
            "reg_var", [2, 2],
            initializer=init_ops.zeros_initializer,
            regularizer=regularizer)
        self.assertEqual(len(layer.losses), 1)
开发者ID:BloodD,项目名称:tensorflow,代码行数:28,代码来源:base_test.py
示例3: after_create_session
def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    """Run a first evaluation so eval metrics are shown before training.

    Also builds the grouped assign op that feeds placeholder values into the
    evaluation variables.

    Raises:
        ValueError: If the SAVEABLE_OBJECTS collection is non-empty.
    """
    if ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS):
        raise ValueError(
            'InMemoryEvaluator does not support saveables other than global '
            'variables.')

    train_vars_by_name = {}
    for var in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES):
        train_vars_by_name[var.name] = var
    self._var_name_to_train_var = train_vars_by_name

    # Only names that have both a placeholder and a training variable are
    # transferred.
    var_names_to_transfer = set(
        self._var_name_to_placeholder.keys()).intersection(
            self._var_name_to_train_var.keys())

    # Drop training variables that do not exist in evaluation.
    self._var_name_to_train_var = dict(
        (name, self._var_name_to_train_var[name])
        for name in var_names_to_transfer)
    # Drop evaluation variables that do not exist in training.
    self._var_name_to_eval_var = dict(
        (name, self._var_name_to_eval_var[name])
        for name in var_names_to_transfer)

    with self._graph.as_default():
        assign_ops = []
        for name in var_names_to_transfer:
            assign_ops.append(
                state_ops.assign(self._var_name_to_eval_var[name],
                                 self._var_name_to_placeholder[name]))
        self._var_feed_op = control_flow_ops.group(assign_ops)

    self._evaluate(session)
开发者ID:ChristinaEricka,项目名称:tensorflow,代码行数:30,代码来源:hooks.py
示例4: testAddWeight
def testAddWeight(self):
    """Check `add_variable` bookkeeping, including regularizers in graph mode."""
    layer = base_layers.Layer(name='my_layer')

    # Basic variable creation.
    first_var = layer.add_variable(
        'my_var', [2, 2], initializer=init_ops.zeros_initializer())
    self.assertEqual(first_var.name, 'my_layer/my_var:0')
    self.assertListEqual(layer.variables, [first_var])
    self.assertListEqual(layer.trainable_variables, [first_var])
    self.assertListEqual(layer.non_trainable_variables, [])
    graph_trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertListEqual(layer.variables, graph_trainables)

    # Non-trainable variable creation.
    # layer.add_variable should work even outside `build` and `call`.
    frozen_var = layer.add_variable(
        'non_trainable_var', [2, 2],
        initializer=init_ops.zeros_initializer(),
        trainable=False)
    self.assertListEqual(layer.variables, [first_var, frozen_var])
    self.assertListEqual(layer.trainable_variables, [first_var])
    self.assertListEqual(layer.non_trainable_variables, [frozen_var])
    graph_trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertEqual(len(graph_trainables), 1)

    if context.in_graph_mode():
        # Regularizers are only supported in GRAPH mode.
        regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
        layer.add_variable(
            'reg_var', [2, 2],
            initializer=init_ops.zeros_initializer(),
            regularizer=regularizer)
        self.assertEqual(len(layer.losses), 1)
开发者ID:keveman,项目名称:tensorflow,代码行数:33,代码来源:base_test.py
示例5: testVariableCollections
def testVariableCollections(self):
    """Variables created with custom `collections` are retrievable from each one."""
    with self.test_session():
        a = variables_lib2.variable('a', [], collections=['A', 'C'])
        b = variables_lib2.variable('b', [], collections=['B', 'C'])
        # assertEqual replaces the deprecated assertEquals alias, which was
        # removed from unittest in Python 3.12.
        self.assertEqual(a, ops.get_collection('A')[0])
        self.assertEqual(b, ops.get_collection('B')[0])
        # 'C' holds both variables, in creation order.
        self.assertListEqual([a, b], ops.get_collection('C'))
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:7,代码来源:variables_test.py
示例6: _train_op_fn
def _train_op_fn(loss):
    """Return the op that minimizes `loss` and then increments the global step."""
    train_ops = []
    global_step = training_util.get_global_step()

    # Each sub-model is optimized only over the trainable variables that live
    # under its own parent scope.
    if dnn_logits is not None:
        dnn_vars = ops.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES, scope=dnn_parent_scope)
        train_ops.append(dnn_optimizer.minimize(loss, var_list=dnn_vars))
    if linear_logits is not None:
        linear_vars = ops.get_collection(
            ops.GraphKeys.TRAINABLE_VARIABLES, scope=linear_parent_scope)
        train_ops.append(linear_optimizer.minimize(loss, var_list=linear_vars))

    minimize_all = control_flow_ops.group(*train_ops)
    # The step increment is created under a control dependency on the grouped
    # optimizer updates and colocated with the global step.
    with ops.control_dependencies([minimize_all]), ops.colocate_with(global_step):
        return state_ops.assign_add(global_step, 1)
return head.create_estimator_spec(
features=features,
mode=mode,
labels=labels,
train_op_fn=_train_op_fn,
logits=logits)
开发者ID:m-colombo,项目名称:tensorflow,代码行数:30,代码来源:dnn_linear_combined.py
示例7: testSaveAsText
def testSaveAsText(self):
    """Save a two-meta-graph SavedModel as text and restore each tag."""
    export_dir = self._get_export_dir("test_astext")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Graph with a single variable. SavedModel invoked to:
    # - add with weights.
    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 42)
        builder.add_meta_graph_and_variables(sess, ["foo"])

    # Graph with the same single variable. SavedModel invoked to:
    # - simply add the model (weights are not updated).
    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 43)
        builder.add_meta_graph(["bar"])

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Restore tag "foo" (variables were saved) and then tag "bar" (variables
    # were not saved); both are expected to read the value 42.
    for tag in ("foo", "bar"):
        with self.test_session(graph=ops.Graph()) as sess:
            loader.load(sess, [tag], export_dir)
            global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            self.assertEqual(42, global_vars[0].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:30,代码来源:saved_model_test.py
示例8: testTrainOpGroup
def testTrainOpGroup(self):
    """A grouped (empty) train op is exported and restored as an Operation."""
    export_dir = self._get_export_dir("test_train_op_group")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
        # Add `v1` and `v2` variables to the graph and to collection "v".
        for value, var_name in ((1, "v1"), (2, "v2")):
            ops.add_to_collection("v", variables.Variable(value, name=var_name))
        sess.run(variables.global_variables_initializer())

        train_op = control_flow_ops.group()
        sess.run(train_op)
        # TODO(karmel): remove explicit call when in the public method.
        builder._add_train_op(train_op)
        builder.add_meta_graph_and_variables(sess, ["foo"])

    # Save the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["foo"], export_dir)
        restored = ops.get_collection("v")
        self.assertEqual(1, restored[0].eval())
        self.assertEqual(2, restored[1].eval())
        restored_train_op = ops.get_collection(constants.TRAIN_OP_KEY)[0]
        self.assertIsInstance(restored_train_op, ops.Operation)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:28,代码来源:saved_model_test.py
示例9: testCustomMainOp
def testCustomMainOp(self):
    """A custom main_op that assigns v3 = v1 + v2 takes effect after a restore."""
    export_dir = self._get_export_dir("test_main_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
        # Add `v1` and `v2` variables to the graph.
        v1 = variables.Variable(1, name="v1")
        ops.add_to_collection("v", v1)
        v2 = variables.Variable(2, name="v2")
        ops.add_to_collection("v", v2)

        # Initialize another variable `v3` to 42.
        v3 = variables.Variable(42, name="v3")
        ops.add_to_collection("v", v3)

        # Set up an assignment op to be run as part of the main_op.
        with ops.control_dependencies([main_op.main_op()]):
            v1_plus_v2 = math_ops.add(v1._ref(), v2._ref())
            assign_v3 = state_ops.assign(v3, v1_plus_v2)
            custom_main_op = control_flow_ops.group(assign_v3)

        sess.run(custom_main_op)
        builder.add_meta_graph_and_variables(
            sess, ["foo"], main_op=custom_main_op)

    # Save the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["foo"], export_dir)
        restored = ops.get_collection("v")
        self.assertEqual(1, restored[0].eval())
        self.assertEqual(2, restored[1].eval())
        # Evaluates to the sum of the first two variables and assigned as part
        # of the main_op, following a restore.
        self.assertEqual(3, restored[2].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:34,代码来源:saved_model_test.py
示例10: testLegacyInitOp
def testLegacyInitOp(self):
    """A legacy_init_op that assigns v3 = v1 + v2 takes effect after a restore."""
    export_dir = self._get_export_dir("test_legacy_init_op")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
        # Add `v1` and `v2` variables to the graph.
        v1 = variables.Variable(1, name="v1")
        ops.add_to_collection("v", v1)
        v2 = variables.Variable(2, name="v2")
        ops.add_to_collection("v", v2)

        # Initialize another variable `v3` to 42. It is created with an empty
        # `collections` list and then added to "v" by hand.
        v3 = variables.Variable(42, name="v3", trainable=False, collections=[])
        ops.add_to_collection("v", v3)

        # Set up an assignment op to be run as part of the legacy_init_op.
        v1_plus_v2 = math_ops.add(v1, v2)
        assign_v3 = state_ops.assign(v3, v1_plus_v2)
        legacy_init_op = control_flow_ops.group(assign_v3, name="legacy_init_op")

        sess.run(variables.global_variables_initializer())
        builder.add_meta_graph_and_variables(
            sess, ["foo"], legacy_init_op=legacy_init_op)

    # Save the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["foo"], export_dir)
        restored = ops.get_collection("v")
        self.assertEqual(1, restored[0].eval())
        self.assertEqual(2, restored[1].eval())
        # Evaluates to the sum of the first two variables and assigned as part
        # of the legacy_init_op, following a restore.
        self.assertEqual(3, restored[2].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:33,代码来源:saved_model_test.py
示例11: test_example
def test_example(self):
    """Reducing three towers' metric variables accumulates them in the first tower."""
    with self.test_session() as session:
        for tower_id in range(3):
            self.create_tower_metrics(tower_id)

        session.run(
            variables.variables_initializer(
                ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

        session.run(
            replicate_model_fn._reduce_metric_variables(number_of_towers=3))

        # 1st tower = 1.3, 2.3,  [3.3, 3.5, 3.7]
        # 2nd tower = 2.6, 4.6,  [6.6, 7.0, 7.4]
        # 3rd tower = 3.9, 6.9,  [9.9, 10.5, 11.1]
        # Reduced   = 7.8, 13.8, [19.8, 21.0, 22.2]
        # Towers are accumulated in the first tower.
        local_metrics = session.run(
            ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

        self.assertNear(7.8, local_metrics[0], 0.01)
        self.assertNear(13.8, local_metrics[1], 0.01)
        # The last component is 3.7 + 7.4 + 11.1 = 22.2; the former literal
        # 22.1 contradicted the table above and only passed because of the
        # tolerance.
        self.assertAllClose([19.8, 21., 22.2], local_metrics[2], 0.01)
        # The remaining towers' variables are expected to read as zero.
        self.assertNear(0.0, local_metrics[3], 0.01)
        self.assertNear(0.0, local_metrics[4], 0.01)
        self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
        self.assertNear(0.0, local_metrics[6], 0.01)
        self.assertNear(0.0, local_metrics[7], 0.01)
        self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:29,代码来源:replicate_model_fn_test.py
示例12: testMultipleConvMaskAdded
def testMultipleConvMaskAdded(self):
    """Stacked masked_conv2d layers each register one mask and one masked weight."""
    number_of_layers = 5
    kernel_size = 3
    base_depth = 4
    depth_step = 7

    # Build the stack; layer `ix` outputs base_depth + (ix + 1) * depth_step
    # channels.
    top_layer = array_ops.ones((8, self.height, self.width, base_depth))
    for ix in range(number_of_layers):
        output_depth = base_depth + (ix + 1) * depth_step
        top_layer = layers.masked_conv2d(top_layer, output_depth, kernel_size)

    # Expected kernel shape per layer: [k, k, input depth, output depth].
    expected_shapes = [
        [kernel_size, kernel_size,
         base_depth + ix * depth_step,
         base_depth + (ix + 1) * depth_step]
        for ix in range(number_of_layers)
    ]

    masks = ops.get_collection(core_layers.MASK_COLLECTION)
    self.assertEqual(len(masks), number_of_layers)
    for mask, shape in zip(masks, expected_shapes):
        self.assertListEqual(mask.get_shape().as_list(), shape)

    masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION)
    self.assertEqual(len(masked_weight), number_of_layers)
    for weight, shape in zip(masked_weight, expected_shapes):
        self.assertListEqual(weight.get_shape().as_list(), shape)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:30,代码来源:layers_test.py
示例13: testTags
def testTags(self):
    """Meta graphs are selected by tag set; unknown or partial tag sets fail."""
    export_dir = os.path.join(test.get_temp_dir(), "test_tags")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Graph with a single variable. SavedModel invoked to:
    # - add with weights.
    # - a single tag (from predefined constants).
    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 42)
        builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])

    # Graph that updates the single variable. SavedModel invoked to:
    # - simply add the model (weights are not updated).
    # - a single tag (from predefined constants).
    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 43)
        builder.add_meta_graph([tag_constants.SERVING])

    # Graph that updates the single variable. SavedModel is invoked:
    # - to add the model (weights are not updated).
    # - multiple custom tags.
    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 44)
        builder.add_meta_graph(["foo", "bar"])

    # Save the SavedModel to disk.
    builder.save()

    # Every loadable tag set is expected to read 42:
    # - a predefined tag whose variables were saved,
    # - a predefined tag whose variables were not saved,
    # - multiple tags, with a duplicate to exercise set semantics.
    loadable_tag_sets = (
        [tag_constants.TRAINING],
        [tag_constants.SERVING],
        ["foo", "bar", "foo"],
    )
    for tags in loadable_tag_sets:
        with self.test_session(graph=ops.Graph()) as sess:
            loader.load(sess, tags, export_dir)
            global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            self.assertEqual(42, global_vars[0].eval())

    # Tag sets that must fail with a runtime error:
    # - a non-existent tag,
    # - a set where only a subset of the tags match, since tag matching for
    #   meta graph defs follows "all" semantics.
    for bad_tags in (["INVALID"], ["foo", "baz"]):
        with self.test_session(graph=ops.Graph()) as sess:
            self.assertRaises(RuntimeError, loader.load, sess, bad_tags,
                              export_dir)
开发者ID:adityaatluri,项目名称:tensorflow,代码行数:60,代码来源:saved_model_test.py
示例14: test_reduce_is_idempotent
def test_reduce_is_idempotent(self):
    """Running the metric reduction 20 times gives the same totals as running it once."""
    with self.test_session() as session:
        for tower_id in range(3):
            self.create_tower_metrics(tower_id)

        session.run(
            variables.variables_initializer(
                ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)))

        for _ in range(20):
            session.run(
                replicate_model_fn._reduce_metric_variables(number_of_towers=3))

        local_metrics = session.run(
            ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))

        self.assertNear(7.8, local_metrics[0], 0.01)
        self.assertNear(13.8, local_metrics[1], 0.01)
        # The last component sums the per-tower values 3.7 + 7.4 + 11.1 = 22.2;
        # the former literal 22.1 only passed because of the tolerance.
        self.assertAllClose([19.8, 21., 22.2], local_metrics[2], 0.01)
        # The remaining towers' variables are expected to read as zero.
        self.assertNear(0.0, local_metrics[3], 0.01)
        self.assertNear(0.0, local_metrics[4], 0.01)
        self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01)
        self.assertNear(0.0, local_metrics[6], 0.01)
        self.assertNear(0.0, local_metrics[7], 0.01)
        self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01)
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:25,代码来源:replicate_model_fn_test.py
示例15: testCreateBN
def testCreateBN(self):
    """BatchNormalization creates the expected variables, updates and collections."""
    # Call layer.
    bn = normalization_layers.BatchNormalization(axis=1)
    inputs = random_ops.random_uniform((5, 4, 3), seed=1)
    training = array_ops.placeholder(dtype='bool')
    outputs = bn.apply(inputs, training=training)

    # Verify shape.
    output_shape = outputs.get_shape().as_list()
    self.assertListEqual(output_shape, [5, 4, 3])

    # Verify layer attributes.
    self.assertEqual(len(bn.updates), 2)
    self.assertEqual(len(bn.variables), 4)
    self.assertEqual(len(bn.trainable_variables), 2)
    self.assertEqual(len(bn.non_trainable_variables), 2)

    # Test that updates were created and added to UPDATE_OPS.
    self.assertEqual(len(bn.updates), 2)
    graph_updates = ops.get_collection(ops.GraphKeys.UPDATE_OPS)
    self.assertListEqual(graph_updates, bn.updates)

    # Test that weights were created and added to TRAINABLE_VARIABLES.
    graph_trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    self.assertListEqual(graph_trainables, bn.trainable_variables)
开发者ID:adityaatluri,项目名称:tensorflow,代码行数:25,代码来源:normalization_test.py
示例16: _make_training_op
def _make_training_op(training_loss):
    """Group the DNN and linear training ops for the combined model.

    A sub-model contributes an op only when its logits are not None.
    """
    train_ops = []

    if dnn_logits is not None:
        dnn_train_op = optimizers.optimize_loss(
            loss=training_loss,
            global_step=contrib_variables.get_global_step(),
            learning_rate=_DNN_LEARNING_RATE,
            optimizer=_get_optimizer(dnn_optimizer),
            gradient_multipliers=_extract_embedding_lr_multipliers(  # pylint: disable=protected-access
                embedding_lr_multipliers, dnn_parent_scope,
                dnn_input_scope.name),
            clip_gradients=gradient_clip_norm,
            variables=ops.get_collection(dnn_parent_scope),
            name=dnn_parent_scope,
            # Empty summaries, because head already logs "loss" summary.
            summaries=[])
        train_ops.append(dnn_train_op)

    if linear_logits is not None:
        linear_train_op = optimizers.optimize_loss(
            loss=training_loss,
            global_step=contrib_variables.get_global_step(),
            learning_rate=_linear_learning_rate(len(linear_feature_columns)),
            optimizer=_get_optimizer(linear_optimizer),
            clip_gradients=gradient_clip_norm,
            variables=ops.get_collection(linear_parent_scope),
            name=linear_parent_scope,
            # Empty summaries, because head already logs "loss" summary.
            summaries=[])
        train_ops.append(linear_train_op)

    return control_flow_ops.group(*train_ops)
开发者ID:Jackhuang945,项目名称:tensorflow,代码行数:32,代码来源:dnn_linear_combined.py
示例17: _testKLPenaltyBoth
def _testKLPenaltyBoth(self, layer_class):
    """Calling the layer must add two regularization losses to the graph.

    Args:
        layer_class: The variational convolution layer class under test; the
            1D, 2D and 3D variants each get an input of matching rank.
    """
    def _make_normal(dtype, *args):  # pylint: disable=unused-argument
        zero = dtype.as_numpy_dtype(0.)
        one = dtype.as_numpy_dtype(1.)
        return normal_lib.Normal(loc=zero, scale=one)

    with self.test_session():
        layer = layer_class(
            filters=2,
            kernel_size=3,
            bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(),
            bias_prior_fn=_make_normal)

        if layer_class == prob_layers_lib.Conv1DVariational:
            inputs = random_ops.random_uniform([2, 3, 1], seed=1)
        elif layer_class == prob_layers_lib.Conv2DVariational:
            inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1)
        elif layer_class == prob_layers_lib.Conv3DVariational:
            inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1)

        # Before the layer is called there are no losses anywhere.
        losses_before = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(len(losses_before), 0)
        self.assertListEqual(layer.losses, losses_before)

        _ = layer(inputs)

        # After the call, the layer and the graph collection agree on two
        # losses.
        losses_after = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(len(losses_after), 2)
        self.assertListEqual(layer.losses, losses_after)
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:28,代码来源:layers_conv_variational_test.py
示例18: testTrainOpAfterVariables
def testTrainOpAfterVariables(self):
    """A train op added after the variables is kept only in the later meta graph."""
    export_dir = self._get_export_dir("test_train_op_after_variables")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    with self.test_session(graph=ops.Graph()) as sess:
        # Add `v1` and `v2` variables to the graph.
        v1 = variables.Variable(1, name="v1")
        ops.add_to_collection("v", v1)
        v2 = variables.Variable(2, name="v2")
        ops.add_to_collection("v", v2)

        sess.run(variables.global_variables_initializer())
        # "pre_foo" is exported before any train op is registered.
        builder.add_meta_graph_and_variables(sess, ["pre_foo"])

        train_op = state_ops.assign_add(v1, v2)
        sess.run(train_op)
        # TODO(karmel): remove explicit call when in the public method.
        builder._add_train_op(train_op)
        builder.add_meta_graph(["foo"])

    # Save the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["foo"], export_dir)
        restored_train_op = ops.get_collection(constants.TRAIN_OP_KEY)[0]
        self.assertIsInstance(restored_train_op, ops.Tensor)

    with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["pre_foo"], export_dir)
        self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY))
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:31,代码来源:saved_model_test.py
示例19: testClearExtraneousSavers
def testClearExtraneousSavers(self):
    """Savers added by hand are replaced by the builder's own Saver on export."""

    def _op_names(source_graph, op_type):
        # Names of all ops of the given type in `source_graph`.
        return {op.name for op in source_graph.get_operations()
                if op.type == op_type}

    export_dir = os.path.join(test.get_temp_dir(),
                              "test_clear_extraneous_savers")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Create a variable and a Saver.
    with ops.Graph().as_default() as graph:
        with session.Session(
                target="",
                config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
            self._init_and_validate_variable(sess, "v", 42)

            # Add two Savers, which should be removed in
            # add_meta_graph_and_variables() in favor of the locally added one.
            saver1 = tf_saver.Saver()
            graph.add_to_collection(ops.GraphKeys.SAVERS, saver1)
            saver2 = tf_saver.Saver()
            graph.add_to_collection(ops.GraphKeys.SAVERS, saver2)

            # Confirm there are two SaverDefs.
            savers = graph.get_collection(ops.GraphKeys.SAVERS)
            self.assertEqual(2, len(savers))

            # Confirm there are two Save and two Restore ops.
            self.assertSetEqual({"save/SaveV2", "save_1/SaveV2"},
                                _op_names(graph, "SaveV2"))
            self.assertSetEqual({"save/RestoreV2", "save_1/RestoreV2"},
                                _op_names(graph, "RestoreV2"))

            # The SavedModel builder adds its own Saver, for a total of three.
            builder.add_meta_graph_and_variables(
                sess, [tag_constants.TRAINING], clear_devices=True)

    # Save the SavedModel to disk.
    builder.save()

    # Restore the graph.
    with ops.Graph().as_default() as graph:
        with self.test_session(graph=graph) as sess:
            loader.load(sess, [tag_constants.TRAINING], export_dir)
            global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            self.assertEqual(42, global_vars[0].eval())

            # Confirm that the reloaded graph has only one SaverDef.
            savers = ops.get_collection(ops.GraphKeys.SAVERS)
            self.assertEqual(1, len(savers))

            # The reloaded graph should have exactly one Save and one Restore
            # op.
            self.assertSetEqual({"save_2/SaveV2"}, _op_names(graph, "SaveV2"))
            self.assertSetEqual({"save_2/RestoreV2"},
                                _op_names(graph, "RestoreV2"))
开发者ID:adityaatluri,项目名称:tensorflow,代码行数:59,代码来源:saved_model_test.py
示例20: testSignatureDefs
def testSignatureDefs(self):
    """Each meta graph restores with the signature def map it was added with."""
    export_dir = self._get_export_dir("test_signature_defs")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Graph with a single variable and a single entry in the signature def
    # map. SavedModel is invoked to add with weights.
    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 42)
        # Build and populate an empty SignatureDef for testing.
        foo_signature = signature_def_utils.build_signature_def({}, {}, "foo")
        builder.add_meta_graph_and_variables(
            sess, ["foo"], signature_def_map={"foo_key": foo_signature})

    # Graph with the same single variable and multiple entries in the
    # signature def map. No weights are saved by SavedModel.
    with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 43)
        # Build and populate a different SignatureDef for testing.
        bar_signature = signature_def_utils.build_signature_def({}, {}, "bar")
        # Also, build a different SignatureDef corresponding to "foo_key"
        # defined in the previous graph.
        foo_new_signature = signature_def_utils.build_signature_def(
            {}, {}, "foo_new")
        bar_signature_map = {
            "bar_key": bar_signature,
            "foo_key": foo_new_signature,
        }
        builder.add_meta_graph(["bar"], signature_def_map=bar_signature_map)

    # Save the SavedModel to disk.
    builder.save()

    # Restore the graph with tag "foo". The single entry in the SignatureDef
    # map corresponding to "foo_key" should exist.
    with self.test_session(graph=ops.Graph()) as sess:
        foo_graph = loader.load(sess, ["foo"], export_dir)
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertEqual(42, global_vars[0].eval())

        restored_foo = foo_graph.signature_def
        self.assertEqual(len(restored_foo), 1)
        self.assertEqual("foo", restored_foo["foo_key"].method_name)

    # Restore the graph with tag "bar". The SignatureDef map should have two
    # entries. One corresponding to "bar_key" and another corresponding to the
    # new value of "foo_key".
    with self.test_session(graph=ops.Graph()) as sess:
        bar_graph = loader.load(sess, ["bar"], export_dir)
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertEqual(42, global_vars[0].eval())

        restored_bar = bar_graph.signature_def
        self.assertEqual(len(restored_bar), 2)
        self.assertEqual("bar", restored_bar["bar_key"].method_name)
        self.assertEqual("foo_new", restored_bar["foo_key"].method_name)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:59,代码来源:saved_model_test.py
注:本文中的tensorflow.python.framework.ops.get_collection函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论