本文整理汇总了Python中tensorflow.python.saved_model.loader.load函数的典型用法代码示例。如果您正苦于以下问题：Python load函数的具体用法？Python load怎么用？Python load使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了load函数的20个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: testSaveAsText
def testSaveAsText(self):
  """Saves a two-MetaGraph SavedModel in text format and restores both tags.

  The "foo" MetaGraph is added together with its variable values (42). The
  "bar" MetaGraph is added graph-only, so restoring either tag must yield the
  value checkpointed alongside "foo".
  """
  export_dir = os.path.join(
      compat.as_bytes(tf.test.get_temp_dir()), compat.as_bytes("astext"))
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  # First MetaGraph: a single variable, recorded together with its weights.
  with self.test_session(graph=tf.Graph()) as sess:
    var = tf.Variable(42, name="v")
    sess.run(tf.initialize_all_variables())
    self.assertEqual(42, var.eval())
    builder.add_meta_graph_and_variables(sess, ["foo"])

  # Second MetaGraph: same variable name, but only the graph is recorded; the
  # checkpointed weights are left untouched.
  with self.test_session(graph=tf.Graph()) as sess:
    var = tf.Variable(43, name="v")
    sess.run(tf.initialize_all_variables())
    self.assertEqual(43, var.eval())
    builder.add_meta_graph(["bar"])

  # Write the SavedModel using the text protobuf format.
  builder.save(as_text=True)

  # Both tags resolve to the single checkpoint written for "foo".
  for tag in ("foo", "bar"):
    with self.test_session(graph=tf.Graph()) as sess:
      loader.load(sess, [tag], export_dir)
      restored = tf.get_collection(tf.GraphKeys.VARIABLES)[0]
      self.assertEqual(42, restored.eval())
开发者ID:apollos,项目名称:tensorflow,代码行数:33,代码来源:saved_model_test.py
示例2: testLegacyInitOp
def testLegacyInitOp(self):
  """Checks that the legacy_init_op stored with a MetaGraph runs on load.

  `v3` is created with `trainable=False, collections=[]`. Presumably it is
  therefore neither initialized by `global_variables_initializer()` nor
  checkpointed (NOTE(review): confirm). After loading, its value is expected
  to come from the legacy_init_op, which assigns `v1 + v2` to it.
  """
  export_dir = self._get_export_dir("test_legacy_init_op")
  builder = saved_model_builder.SavedModelBuilder(export_dir)
  with self.test_session(graph=ops.Graph()) as sess:
    # Two ordinary variables, also tracked in the custom "v" collection.
    first = variables.Variable(1, name="v1")
    ops.add_to_collection("v", first)
    second = variables.Variable(2, name="v2")
    ops.add_to_collection("v", second)
    # A third variable (initial value 42) kept out of the default collections.
    third = variables.Variable(42, name="v3", trainable=False, collections=[])
    ops.add_to_collection("v", third)
    # The legacy_init_op writes v1 + v2 into v3.
    assign_sum = state_ops.assign(third, math_ops.add(first, second))
    legacy_init_op = control_flow_ops.group(assign_sum, name="legacy_init_op")
    sess.run(variables.global_variables_initializer())
    builder.add_meta_graph_and_variables(
        sess, ["foo"], legacy_init_op=legacy_init_op)

  # Persist the SavedModel.
  builder.save()

  with self.test_session(graph=ops.Graph()) as sess:
    loader.load(sess, ["foo"], export_dir)
    restored = ops.get_collection("v")
    # v1 and v2 come back as saved; v3 holds their sum, assigned by the
    # legacy_init_op after the restore.
    for position, expected in enumerate((1, 2, 3)):
      self.assertEqual(expected, restored[position].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:33,代码来源:saved_model_test.py
示例3: testSaveAsText
def testSaveAsText(self):
  """Round-trips a text-format SavedModel holding two MetaGraphs.

  Only "foo" is added with its variable values (42). "bar" is added
  graph-only, so loading either tag yields the value saved with "foo".
  """
  export_dir = self._get_export_dir("test_astext")
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  # "foo": variable "v" = 42 via the `_init_and_validate_variable` helper
  # (defined elsewhere); graph and weights are both recorded.
  with self.test_session(graph=ops.Graph()) as sess:
    self._init_and_validate_variable(sess, "v", 42)
    builder.add_meta_graph_and_variables(sess, ["foo"])

  # "bar": same variable name with value 43, but only the graph is recorded.
  with self.test_session(graph=ops.Graph()) as sess:
    self._init_and_validate_variable(sess, "v", 43)
    builder.add_meta_graph(["bar"])

  # Write the SavedModel using the text protobuf format.
  builder.save(as_text=True)

  # Each tag restores the weights checkpointed with "foo".
  for tag in ("foo", "bar"):
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, [tag], export_dir)
      global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
      self.assertEqual(42, global_vars[0].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:30,代码来源:saved_model_test.py
示例4: testCustomMainOp
def testCustomMainOp(self):
  """Checks that a custom main_op saved with a MetaGraph runs on load.

  The custom op depends on `main_op.main_op()` through a control dependency
  and then assigns `v1 + v2` to `v3`, so after loading `v3` must equal 3.
  """
  export_dir = self._get_export_dir("test_main_op")
  builder = saved_model_builder.SavedModelBuilder(export_dir)
  with self.test_session(graph=ops.Graph()) as sess:
    # Two ordinary variables, tracked in the custom "v" collection.
    first = variables.Variable(1, name="v1")
    ops.add_to_collection("v", first)
    second = variables.Variable(2, name="v2")
    ops.add_to_collection("v", second)
    # v3 starts at 42 and is overwritten by the custom main_op.
    third = variables.Variable(42, name="v3")
    ops.add_to_collection("v", third)
    # Create the sum/assign ops under a control dependency on the default
    # main_op (presumably so v1/v2 are initialized before being read).
    with ops.control_dependencies([main_op.main_op()]):
      summed = math_ops.add(first._ref(), second._ref())
      custom_main_op = control_flow_ops.group(state_ops.assign(third, summed))
    sess.run(custom_main_op)
    builder.add_meta_graph_and_variables(
        sess, ["foo"], main_op=custom_main_op)

  # Persist the SavedModel.
  builder.save()

  with self.test_session(graph=ops.Graph()) as sess:
    loader.load(sess, ["foo"], export_dir)
    restored = ops.get_collection("v")
    # v3 holds v1 + v2, written by the main_op that runs after the restore.
    for position, expected in enumerate((1, 2, 3)):
      self.assertEqual(expected, restored[position].eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:34,代码来源:saved_model_test.py
示例5: testTrainOpGroup
def testTrainOpGroup(self):
  """Checks that a grouped (empty) train op survives a save/load round trip.

  The train op is `control_flow_ops.group()` with no inputs. After loading,
  the first entry of the TRAIN_OP_KEY collection must be an `ops.Operation`.
  """
  export_dir = self._get_export_dir("test_train_op_group")
  builder = saved_model_builder.SavedModelBuilder(export_dir)
  with self.test_session(graph=ops.Graph()) as sess:
    # Register v1 = 1 and v2 = 2 in the custom "v" collection, in that order.
    for var_name, initial in (("v1", 1), ("v2", 2)):
      ops.add_to_collection("v", variables.Variable(initial, name=var_name))
    sess.run(variables.global_variables_initializer())
    noop_train_op = control_flow_ops.group()
    sess.run(noop_train_op)
    # TODO(karmel): remove explicit call when in the public method.
    builder._add_train_op(noop_train_op)
    builder.add_meta_graph_and_variables(sess, ["foo"])

  # Persist the SavedModel.
  builder.save()

  with self.test_session(graph=ops.Graph()) as sess:
    loader.load(sess, ["foo"], export_dir)
    restored = ops.get_collection("v")
    self.assertEqual(1, restored[0].eval())
    self.assertEqual(2, restored[1].eval())
    train_ops = ops.get_collection(constants.TRAIN_OP_KEY)
    self.assertIsInstance(train_ops[0], ops.Operation)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:28,代码来源:saved_model_test.py
示例6: testTrainOpAfterVariables
def testTrainOpAfterVariables(self):
  """A train op registered after the weights are saved belongs only to "foo".

  "pre_foo" is recorded before the train op is registered, and "foo" after
  `_add_train_op`. Loading "foo" must expose the train op (the test expects
  an `ops.Tensor`); loading "pre_foo" must expose none.
  """
  export_dir = self._get_export_dir("test_train_op_after_variables")
  builder = saved_model_builder.SavedModelBuilder(export_dir)
  with self.test_session(graph=ops.Graph()) as sess:
    # Two variables, tracked in the custom "v" collection.
    lhs = variables.Variable(1, name="v1")
    ops.add_to_collection("v", lhs)
    rhs = variables.Variable(2, name="v2")
    ops.add_to_collection("v", rhs)
    sess.run(variables.global_variables_initializer())
    # Record graph plus weights before any train op exists.
    builder.add_meta_graph_and_variables(sess, ["pre_foo"])

    increment = state_ops.assign_add(lhs, rhs)
    sess.run(increment)
    # TODO(karmel): remove explicit call when in the public method.
    builder._add_train_op(increment)
    builder.add_meta_graph(["foo"])

  # Persist the SavedModel.
  builder.save()

  with self.test_session(graph=ops.Graph()) as sess:
    loader.load(sess, ["foo"], export_dir)
    train_ops = ops.get_collection(constants.TRAIN_OP_KEY)
    self.assertIsInstance(train_ops[0], ops.Tensor)

  with self.test_session(graph=ops.Graph()) as sess:
    loader.load(sess, ["pre_foo"], export_dir)
    self.assertFalse(ops.get_collection(constants.TRAIN_OP_KEY))
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:31,代码来源:saved_model_test.py
示例7: testCustomSaveable
def testCustomSaveable(self):
  """A custom saveable (a CheckpointedOp table) round-trips through SavedModel.

  According to the test's own notes, the table registers itself in the
  SAVEABLE_OBJECTS collection, so its contents are checkpointed with the
  MetaGraph. After loading, the table is re-wrapped from the "table_ref"
  collection and its single entry ("k1" -> 3.0) is checked.
  """
  export_dir = self._get_export_dir("custom_saveable")
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  def two_cpu_session():
    # A session on a fresh graph, configured with two CPU devices.
    return session.Session(
        graph=ops.Graph(),
        config=config_pb2.ConfigProto(device_count={"CPU": 2}))

  with two_cpu_session() as sess:
    table = saver_test_utils.CheckpointedOp(name="v1")
    variables.global_variables_initializer().run()
    table.insert("k1", 3.0).run()
    # Keep a handle so the table can be located again after loading.
    ops.add_to_collection("table_ref", table.table_ref)
    builder.add_meta_graph_and_variables(sess, ["foo"])

  # Persist the SavedModel.
  builder.save()

  with two_cpu_session() as sess:
    loader.load(sess, ["foo"], export_dir)
    # Rebuild the wrapper around the restored table reference.
    restored = saver_test_utils.CheckpointedOp(
        name="v1", table_ref=ops.get_collection("table_ref")[0])
    self.assertEqual(b"k1", restored.keys().eval())
    self.assertEqual(3.0, restored.values().eval())
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:28,代码来源:saved_model_test.py
示例8: testGraphWithoutVariables
def testGraphWithoutVariables(self):
  """MetaGraphs that contain only constants (no variables) can be loaded.

  "foo" holds the constant 5.0 and "bar" holds 6.0. After loading each one,
  its constant is multiplied by the other value, giving 30.0 in both cases.
  """
  export_dir = self._get_export_dir("test_graph_has_variables")
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  # First MetaGraph: a lone constant, no variables.
  with self.test_session(graph=ops.Graph()) as sess:
    five_name = constant_op.constant(5.0).name
    builder.add_meta_graph_and_variables(sess, ["foo"])

  # Second MetaGraph: another lone constant.
  with self.test_session(graph=ops.Graph()) as sess:
    six_name = constant_op.constant(6.0).name
    builder.add_meta_graph(["bar"])

  # Persist the SavedModel.
  builder.save()

  # Look up each tag's constant by name and multiply it by the other value.
  for tag, tensor_name, multiplier in (("foo", five_name, 6.0),
                                       ("bar", six_name, 5.0)):
    with self.test_session(graph=ops.Graph()) as sess:
      loader.load(sess, [tag], export_dir)
      loaded = ops.get_default_graph().get_tensor_by_name(tensor_name)
      product = loaded * constant_op.constant(multiplier)
      self.assertEqual(30.0, sess.run(product))
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:34,代码来源:saved_model_test.py
示例9: export_fn
def export_fn(estimator, export_dir, checkpoint_path=None, eval_result=None):
  """A wrapper to export to SavedModel, and convert it to other formats.

  Exports through `base_strategy` and reloads the SERVING MetaGraph. It then
  reads the serialized tree ensemble, optionally hands it to `convert_fn`,
  and writes per-feature importances to
  `<result_dir>/assets.extra/feature_importances`.

  Args:
    estimator: Estimator forwarded to `base_strategy.export`.
    export_dir: Export directory forwarded to `base_strategy.export`.
    checkpoint_path: Optional checkpoint forwarded to `base_strategy.export`.
    eval_result: Optional evaluation result forwarded to
      `base_strategy.export` and to `convert_fn`.

  Returns:
    The directory returned by `base_strategy.export`.
  """
  from tensorflow.python.util import compat  # pylint: disable=g-import-not-at-top

  result_dir = base_strategy.export(estimator, export_dir,
                                    checkpoint_path,
                                    eval_result)
  with ops.Graph().as_default() as graph:
    with tf_session.Session(graph=graph) as sess:
      saved_model_loader.load(
          sess, [tag_constants.SERVING], result_dir)
      # Note: This is GTFlow internal API and might change.
      ensemble_model = graph.get_operation_by_name(
          "ensemble_model/TreeEnsembleSerialize")
      _, dfec_str = sess.run(ensemble_model.outputs)
      dtec = tree_config_pb2.DecisionTreeEnsembleConfig()
      dtec.ParseFromString(dfec_str)
      # Export the result in the same folder as the saved model.
      if convert_fn:
        convert_fn(dtec, sorted_feature_names,
                   len(dense_floats),
                   len(sparse_float_indices),
                   len(sparse_int_indices), result_dir, eval_result)
      feature_importances = _get_feature_importances(
          dtec, sorted_feature_names,
          len(dense_floats),
          len(sparse_float_indices), len(sparse_int_indices))
      # Most important features first.
      sorted_by_importance = sorted(
          feature_importances.items(), key=lambda x: -x[1])
      # `result_dir` may come back from the export as bytes rather than str
      # (NOTE(review): confirm for the strategies in use). On Python 3,
      # os.path.join cannot mix bytes and str components, so normalize every
      # component to bytes. gfile accepts bytes paths.
      assets_dir = os.path.join(
          compat.as_bytes(result_dir), compat.as_bytes("assets.extra"))
      gfile.MakeDirs(assets_dir)
      importances_path = os.path.join(
          assets_dir, compat.as_bytes("feature_importances"))
      with gfile.GFile(importances_path, "w") as f:
        f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance))
  return result_dir
开发者ID:jiayouwyhit,项目名称:tensorflow,代码行数:33,代码来源:custom_export_strategy.py
示例10: testVariables
def testVariables(self):
  """Loads MetaGraphs whose variables are the saved set, a subset, or disjoint.

  "foo" is saved with weights for v1 = 1 and v2 = 2. "bar" (only v2) and
  "baz" (only v3) are added graph-only. Loading "bar" restores v2 = 2 from
  the checkpoint. Loading "baz" must raise NotFoundError, because v3 was
  never saved.
  """
  export_dir = os.path.join(
      compat.as_bytes(tf.test.get_temp_dir()), compat.as_bytes("variables"))
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  # "foo": two variables, recorded together with their weights.
  with self.test_session(graph=tf.Graph()) as sess:
    first = tf.Variable(1, name="v1")
    second = tf.Variable(2, name="v2")
    sess.run(tf.initialize_all_variables())
    self.assertEqual(1, first.eval())
    self.assertEqual(2, second.eval())
    builder.add_meta_graph_and_variables(sess, ["foo"])

  # Graph-only MetaGraphs. "bar" holds a subset of the saved variables and
  # "baz" a disjoint set; their initial values are not checkpointed.
  for var_name, initial, tag in (("v2", 3, "bar"), ("v3", 4, "baz")):
    with self.test_session(graph=tf.Graph()) as sess:
      var = tf.Variable(initial, name=var_name)
      sess.run(tf.initialize_all_variables())
      self.assertEqual(initial, var.eval())
      builder.add_meta_graph([tag])

  # Persist the SavedModel.
  builder.save()

  # "foo": both variables come back with their saved values.
  with self.test_session(graph=tf.Graph()) as sess:
    loader.load(sess, ["foo"], export_dir)
    restored = tf.get_collection(tf.GraphKeys.VARIABLES)
    self.assertEqual(len(restored), 2)
    self.assertEqual(1, restored[0].eval())
    self.assertEqual(2, restored[1].eval())

  # "bar": its single variable v2 takes the value checkpointed with "foo".
  with self.test_session(graph=tf.Graph()) as sess:
    loader.load(sess, ["bar"], export_dir)
    restored = tf.get_collection(tf.GraphKeys.VARIABLES)
    self.assertEqual(len(restored), 1)
    self.assertEqual(2, restored[0].eval())

  # "baz": v3 is absent from the checkpoint, so loading must fail.
  with self.test_session(graph=tf.Graph()) as sess:
    self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"],
                      export_dir)
开发者ID:apollos,项目名称:tensorflow,代码行数:59,代码来源:saved_model_test.py
示例11: testClearExtraneousSavers
def testClearExtraneousSavers(self):
  """Savers already in the SAVERS collection are dropped by the builder.

  Two Savers are added by hand before saving; the test checks that they
  produce the "save/" and "save_1/" SaveV2/RestoreV2 ops. After loading, the
  graph must contain exactly one SaverDef, and only "save_2/" Save/Restore
  ops, i.e. the Saver added by the SavedModel builder.
  """
  export_dir = os.path.join(test.get_temp_dir(),
                            "test_clear_extraneous_savers")
  builder = saved_model_builder.SavedModelBuilder(export_dir)
  # Create a variable and a Saver.
  with ops.Graph().as_default() as graph:
    with session.Session(
        target="",
        config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      # Add two Savers, which should be removed in
      # add_meta_graph_and_variables() in favor of the locally added one.
      # Creation order matters: it determines the "save"/"save_1" scopes.
      saver1 = tf_saver.Saver()
      graph.add_to_collection(ops.GraphKeys.SAVERS, saver1)
      saver2 = tf_saver.Saver()
      graph.add_to_collection(ops.GraphKeys.SAVERS, saver2)
      # Confirm there are two SaverDefs.
      savers = graph.get_collection(ops.GraphKeys.SAVERS)
      self.assertEqual(2, len(savers))
      # Confirm there are two Save and two Restore ops.
      save_op_names = set([x.name for x in graph.get_operations()
                           if x.type == "SaveV2"])
      self.assertSetEqual(set(["save/SaveV2", "save_1/SaveV2"]),
                          save_op_names)
      restore_op_names = set([x.name for x in graph.get_operations()
                              if x.type == "RestoreV2"])
      self.assertSetEqual(set(["save/RestoreV2", "save_1/RestoreV2"]),
                          restore_op_names)
      # The SavedModel builder adds its own Saver, for a total of three.
      builder.add_meta_graph_and_variables(
          sess, [tag_constants.TRAINING], clear_devices=True)
  # Save the SavedModel to disk.
  builder.save()
  # Restore the graph.
  with ops.Graph().as_default() as graph:
    with self.test_session(graph=graph) as sess:
      loader.load(sess, [tag_constants.TRAINING], export_dir)
      self.assertEqual(
          42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
      # Confirm that the reloaded graph has only one SaverDef.
      savers = ops.get_collection(ops.GraphKeys.SAVERS)
      self.assertEqual(1, len(savers))
      # The reloaded graph should have exactly one Save and one Restore op,
      # both from the builder's Saver (the third one created, hence "save_2").
      save_op_names = set([x.name for x in graph.get_operations()
                           if x.type == "SaveV2"])
      self.assertSetEqual(set(["save_2/SaveV2"]), save_op_names)
      restore_op_names = set([x.name for x in graph.get_operations()
                              if x.type == "RestoreV2"])
      self.assertSetEqual(set(["save_2/RestoreV2"]), restore_op_names)
开发者ID:adityaatluri,项目名称:tensorflow,代码行数:59,代码来源:saved_model_test.py
示例12: testSignatureDefs
def testSignatureDefs(self):
  """Each MetaGraph keeps its own SignatureDef map across save/load.

  "foo" is saved with weights and one signature ("foo_key" -> method "foo").
  "bar" is graph-only, with two signatures ("bar_key" -> "bar", and a new
  "foo_key" -> "foo_new"). Loading each tag must return exactly its own map,
  with the variable restored to 42.
  """
  export_dir = self._get_export_dir("test_signature_defs")
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  # "foo": variable + weights and a single (empty) SignatureDef.
  with self.test_session(graph=ops.Graph()) as sess:
    self._init_and_validate_variable(sess, "v", 42)
    foo_signature = signature_def_utils.build_signature_def({}, {}, "foo")
    builder.add_meta_graph_and_variables(
        sess, ["foo"], signature_def_map={"foo_key": foo_signature})

  # "bar": graph only, with its own "bar_key" entry plus a different
  # SignatureDef stored under the same "foo_key" used by "foo".
  with self.test_session(graph=ops.Graph()) as sess:
    self._init_and_validate_variable(sess, "v", 43)
    bar_signature = signature_def_utils.build_signature_def({}, {}, "bar")
    foo_new_signature = signature_def_utils.build_signature_def(
        {}, {}, "foo_new")
    signature_map = {
        "bar_key": bar_signature,
        "foo_key": foo_new_signature,
    }
    builder.add_meta_graph(["bar"], signature_def_map=signature_map)

  # Persist the SavedModel.
  builder.save()

  # Expected (signature key, method name) pairs per tag.
  expectations = (
      ("foo", (("foo_key", "foo"),)),
      ("bar", (("bar_key", "bar"), ("foo_key", "foo_new"))),
  )
  for tag, expected_methods in expectations:
    with self.test_session(graph=ops.Graph()) as sess:
      loaded_meta_graph = loader.load(sess, [tag], export_dir)
      self.assertEqual(
          42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
      signatures = loaded_meta_graph.signature_def
      self.assertEqual(len(signatures), len(expected_methods))
      for key, method in expected_methods:
        self.assertEqual(method, signatures[key].method_name)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:59,代码来源:saved_model_test.py
示例13: testStripDefaultAttrsInconsistentConsumerDefaults
def testStripDefaultAttrsInconsistentConsumerDefaults(self):
  """Loading applies the consumer's current op-registry defaults.

  The model is saved with `strip_default_attrs=True`, so the "Complex" node
  carries no "T"/"Tout" attr values. The test then edits the registered
  "Complex" OpDef in place to simulate consumers with different defaults:
  first no defaults at all, then Tout = complex128.

  Because the OpDef obtained from `op_def_registry.get_registered_ops()` is
  mutated in place, the original definition is restored in a `finally`
  block. This keeps the edits from leaking into other tests in the process.
  """
  if ops._USE_C_API:
    return  # TODO(skyewm): get this working
  export_dir = self._get_export_dir(
      "test_strip_default_attrs_no_consumer_defaults")
  builder = saved_model_builder.SavedModelBuilder(export_dir)
  # Add a graph with two float32 variables and a Complex Op composing them
  # with strip_default_attrs enabled. This must remove the following
  # defaults for the "Complex" Op:
  #   o "T"    : float32.   (input type)
  #   o "Tout" : complex64. (output type)
  with session.Session(graph=ops.Graph()) as sess:
    real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
    imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
    math_ops.complex(real_num, imag_num, name="complex")
    sess.run(variables.global_variables_initializer())
    builder.add_meta_graph_and_variables(
        sess, ["foo"], strip_default_attrs=True)
  # Save the SavedModel to disk in text format.
  builder.save(as_text=True)

  complex_op_def = op_def_registry.get_registered_ops()["Complex"]
  original_complex_op_def = op_def_pb2.OpDef()
  original_complex_op_def.CopyFrom(complex_op_def)
  try:
    # Update the Op registry to remove defaults for all attrs ("T", "Tout")
    # from the "Complex" OpDef.
    for attr_def in complex_op_def.attr:
      attr_def.ClearField("default_value")
    # Loading the SavedModel via the loader must fail because the SavedModel
    # does not have any attr values for the "Complex" node and the current
    # op registry does not have any default values for the "Complex" op.
    with session.Session(graph=ops.Graph()) as sess:
      with self.assertRaisesRegexp(
          ValueError,
          "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
        loader.load(sess, ["foo"], export_dir)

    # Update the Op registry to change the default for attr "Tout"
    # (complex64 -> complex128).
    complex_op_def.CopyFrom(original_complex_op_def)
    for attr_def in complex_op_def.attr:
      if attr_def.name == "Tout":
        attr_def.default_value.type = types_pb2.DT_COMPLEX128
    # Loading the SavedModel via the loader must set the "Tout" attr_value
    # for the "Complex" node according to the latest default (complex128).
    # This is expected to fail the model import, because no OpKernel is
    # registered for attrs "T" (float32) and "Tout" (complex128).
    with session.Session(graph=ops.Graph()) as sess:
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          ".*No OpKernel was registered to support Op \'Complex\' with these "
          "attrs..*"):
        loader.load(sess, ["foo"], export_dir)
  finally:
    # Undo the in-place registry edits so other tests see the real OpDef.
    complex_op_def.CopyFrom(original_complex_op_def)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:57,代码来源:saved_model_test.py
示例14: freeze_saved_model
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
                       output_arrays, tag_set, signature_key):
  """Converts a SavedModel to a frozen graph.

  Args:
    saved_model_dir: SavedModel directory to convert.
    input_arrays: List of input tensors to freeze graph with. Uses input arrays
      from SignatureDef when none are provided.
    input_shapes: Dict of strings representing input tensor names to list of
      integers representing input shapes (e.g., {"foo": [1, 16, 16, 3]}).
      Automatically determined when input shapes is None (e.g., {"foo" : None}).
    output_arrays: List of output tensors to freeze graph with. Uses output
      arrays from SignatureDef when none are provided.
    tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
      analyze. All tags in the tag set must be present.
    signature_key: Key identifying SignatureDef containing inputs and outputs.

  Returns:
    frozen_graph_def: Frozen GraphDef, pruned to the ops needed to compute
      `out_tensors`.
    in_tensors: List of input tensors for the graph.
    out_tensors: List of output tensors for the graph.

  Raises:
    ValueError:
      SavedModel doesn't contain a MetaGraphDef identified by tag_set.
      signature_key is not in the MetaGraphDef.
      assets/ directory is in the MetaGraphDef.
      input_shapes does not match the length of input_arrays.
      input_arrays or output_arrays are not valid.
  """
  # Read SignatureDef.
  meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
  signature_def = _get_signature_def(meta_graph, signature_key)
  inputs, outputs = _get_inputs_outputs(signature_def)

  # Check SavedModel for assets directory.
  collection_def = meta_graph.collection_def
  if constants.ASSETS_KEY in collection_def:
    raise ValueError("SavedModels with assets/ directory are not supported.")

  graph = ops.Graph()
  with session.Session(graph=graph) as sess:
    loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)

    # Gets input and output tensors.
    # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
    in_tensors = _get_tensors(graph, inputs, input_arrays)
    out_tensors = _get_tensors(graph, outputs, output_arrays)
    set_tensor_shapes(in_tensors, input_shapes)

    # Freeze with respect to the tensors actually returned. These may come
    # from the caller's `output_arrays` rather than the SignatureDef, so
    # deriving the names from `outputs` would prune the graph to the wrong
    # subgraph.
    output_names = [tensor.op.name for tensor in out_tensors]
    frozen_graph_def = tf_graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), output_names)
    return frozen_graph_def, in_tensors, out_tensors
开发者ID:AnishShah,项目名称:tensorflow,代码行数:55,代码来源:convert_saved_model.py
示例15: _TestStaticOp
def _TestStaticOp(self, use_function_backup):
  """Runs the converted TensorRT model as a GraphDef and as a SavedModel.

  Returns immediately when TensorRT is not enabled. Each form is run with
  batch size 1, where the engine embedded in the graph is expected to run.
  It is then run with batch size 2, which exceeds max_batch_size and should
  fall back to the TF function.
  """
  if not is_tensorrt_enabled():
    return
  tmp_dir = self.get_temp_dir()
  input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
  output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
  self._WriteInputSavedModel(input_saved_model_dir)
  output_graph_def = self._ConvertGraph(
      input_saved_model_dir=input_saved_model_dir,
      output_saved_model_dir=output_saved_model_dir,
      maximum_cached_engines=2,  # This is noop, added just for testing.
      use_function_backup=use_function_backup)

  def run_batch_checks(sess):
    # (batch size, whether the embedded engine is expected to run).
    for batch_size, engine_expected in ((1, True), (2, False)):
      self._TestRun(
          sess,
          batch_size,
          use_function_backup=use_function_backup,
          expect_engine_is_run=engine_expected)

  # Check the converted GraphDef.
  with ops.Graph().as_default():
    importer.import_graph_def(output_graph_def, name="")
    with self.session(config=self._GetConfigProto()) as sess:
      run_batch_checks(sess)

  # Check the converted SavedModel.
  with ops.Graph().as_default():
    with self.session(config=self._GetConfigProto()) as sess:
      loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
      run_batch_checks(sess)
开发者ID:perfmjs,项目名称:tensorflow,代码行数:51,代码来源:trt_convert_test.py
示例16: __init__
def __init__(self,
             export_dir,
             signature_def_key=None,
             signature_def=None,
             input_names=None,
             output_names=None,
             tags=None,
             graph=None):
  """Initialize a predictor from a `SavedModel`.

  Loads the MetaGraph identified by `tags` into `graph` (or a new graph) using
  a new session. It then resolves the feed and fetch tensors, either from
  `input_names`/`output_names` or from a SignatureDef.

  Args:
    export_dir: a path to a directory containing a `SavedModel`.
    signature_def_key: Optional string specifying the signature to use. If
      `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of
      `signature_def_key` and `signature_def` should be specified.
    signature_def: A `SignatureDef` proto specifying the inputs and outputs
      for prediction. Only one of `signature_def_key` and `signature_def`
      should be specified.
    input_names: A dictionary mapping strings to `Tensor`s in the `SavedModel`
      that represent the input. The keys can be any string of the user's
      choosing.
    output_names: A dictionary mapping strings to `Tensor`s in the
      `SavedModel` that represent the output. The keys can be any string of
      the user's choosing.
    tags: Optional. A single comma-separated string of tags (it is split on
      ',') used to select the MetaGraph and retrieve the correct
      `SignatureDef`. Defaults to `DEFAULT_TAGS`.
    graph: Optional. The Tensorflow `graph` in which prediction should be
      done. A new `ops.Graph()` is created when omitted.

  Raises:
    ValueError: If more than one of signature_def_key OR signature_def OR
      (input_names AND output_names) is specified.
  """
  _check_signature_arguments(
      signature_def_key, signature_def, input_names, output_names)
  tags = tags or DEFAULT_TAGS
  self._graph = graph or ops.Graph()

  # The session is created while self._graph is the default graph, so it is
  # bound to that graph.
  with self._graph.as_default():
    self._session = session.Session()
    loader.load(self._session, tags.split(','), export_dir)

  if input_names is None:
    if signature_def is None:
      signature_def = _get_signature_def(signature_def_key, export_dir, tags)
    input_names = {k: v.name for k, v in signature_def.inputs.items()}
    output_names = {k: v.name for k, v in signature_def.outputs.items()}

  self._feed_tensors = {k: self._graph.get_tensor_by_name(v)
                        for k, v in input_names.items()}
  self._fetch_tensors = {k: self._graph.get_tensor_by_name(v)
                         for k, v in output_names.items()}
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:51,代码来源:saved_model_predictor.py
示例17: testVariables
def testVariables(self):
  """Loads MetaGraphs whose variables are the saved set, a subset, or disjoint.

  "foo" is saved with weights for v1 = 1 and v2 = 2. "bar" (only v2) and
  "baz" (only v3) are added graph-only. Loading "bar" restores v2 = 2 from
  the checkpoint. Loading "baz" must raise NotFoundError, because v3 was
  never saved.
  """
  export_dir = self._get_export_dir("test_variables")
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  # "foo": two variables, recorded together with their weights.
  with self.test_session(graph=ops.Graph()) as sess:
    self._init_and_validate_variable(sess, "v1", 1)
    self._init_and_validate_variable(sess, "v2", 2)
    builder.add_meta_graph_and_variables(sess, ["foo"])

  # Graph-only MetaGraphs. "bar" holds a subset of the saved variables and
  # "baz" a disjoint set; their initial values are not checkpointed.
  for var_name, initial, tag in (("v2", 3, "bar"), ("v3", 4, "baz")):
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, var_name, initial)
      builder.add_meta_graph([tag])

  # Persist the SavedModel.
  builder.save()

  # "foo": both variables come back with their saved values.
  with self.test_session(graph=ops.Graph()) as sess:
    loader.load(sess, ["foo"], export_dir)
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertEqual(len(global_vars), 2)
    self.assertEqual(1, global_vars[0].eval())
    self.assertEqual(2, global_vars[1].eval())

  # "bar": its single variable v2 takes the value checkpointed with "foo".
  with self.test_session(graph=ops.Graph()) as sess:
    loader.load(sess, ["bar"], export_dir)
    global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    self.assertEqual(len(global_vars), 1)
    self.assertEqual(2, global_vars[0].eval())

  # "baz": v3 is absent from the checkpoint, so loading must fail.
  with self.test_session(graph=ops.Graph()) as sess:
    self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"],
                      export_dir)
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:51,代码来源:saved_model_test.py
示例18: testCollections
def testCollections(self):
  """Collections are restored per MetaGraph; values come from "foo"'s save.

  "foo" puts v = 42 in collection "foo_vars" and is saved with weights.
  "bar" puts v = 43 in "bar_vars" and is graph-only. Loading either tag must
  expose only that tag's collection, holding one variable with value 42.
  """
  export_dir = os.path.join(
      compat.as_bytes(tf.test.get_temp_dir()), compat.as_bytes("collections"))
  builder = saved_model_builder.SavedModelBuilder(export_dir)

  # "foo": variable in "foo_vars"; graph and weights are recorded.
  with self.test_session(graph=tf.Graph()) as sess:
    var = tf.Variable(42, name="v")
    tf.add_to_collection("foo_vars", var)
    sess.run(tf.initialize_all_variables())
    self.assertEqual(42, var.eval())
    builder.add_meta_graph_and_variables(sess, ["foo"])

  # "bar": same variable placed in "bar_vars"; only the graph is recorded.
  with self.test_session(graph=tf.Graph()) as sess:
    var = tf.Variable(43, name="v")
    tf.add_to_collection("bar_vars", var)
    sess.run(tf.initialize_all_variables())
    self.assertEqual(43, var.eval())
    builder.add_meta_graph(["bar"])

  # Persist the SavedModel.
  builder.save()

  # Each tag sees only its own collection. The variable in it holds the
  # value checkpointed with "foo".
  for tag, present, absent in (("foo", "foo_vars", "bar_vars"),
                               ("bar", "bar_vars", "foo_vars")):
    with self.test_session(graph=tf.Graph()) as sess:
      loader.load(sess, [tag], export_dir)
      members = tf.get_collection(present)
      self.assertEqual(len(members), 1)
      self.assertEqual(42, members[0].eval())
      self.assertEqual(len(tf.get_collection(absent)), 0)
开发者ID:apollos,项目名称:tensorflow,代码行数:50,代码来源:saved_model_test.py
示例19: testAssets
def testAssets(self):
  """Saves an asset collection with "foo" and validates it after loading.

  A stray "ignored.txt" is written to the temp dir first. Only the asset
  built by `_build_asset_collection` ("hello42.txt") is validated in the
  loaded MetaGraph's collection_def.
  """
  export_dir = self._get_export_dir("test_assets")
  builder = saved_model_builder.SavedModelBuilder(export_dir)
  with self.test_session(graph=ops.Graph()) as sess:
    self._init_and_validate_variable(sess, "v", 42)
    # Write a file that is not part of the asset collection.
    stray_path = os.path.join(
        compat.as_bytes(test.get_temp_dir()), compat.as_bytes("ignored.txt"))
    file_io.write_string_to_file(stray_path, "will be ignored")
    assets = self._build_asset_collection(
        "hello42.txt", "foo bar baz", "asset_file_tensor")
    builder.add_meta_graph_and_variables(
        sess, ["foo"], assets_collection=assets)

  # Persist the SavedModel.
  builder.save()

  with self.test_session(graph=ops.Graph()) as sess:
    loaded_meta_graph = loader.load(sess, ["foo"], export_dir)
    self._validate_asset_collection(
        export_dir, loaded_meta_graph.collection_def, "hello42.txt",
        "foo bar baz", "asset_file_tensor:0")
|
请发表评论