本文整理汇总了Python中tensorflow.python.saved_model.utils.build_tensor_info函数的典型用法代码示例。如果您正苦于以下问题：Python build_tensor_info函数的具体用法？Python build_tensor_info怎么用？Python build_tensor_info使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了build_tensor_info函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: predict_signature_def
def predict_signature_def(inputs, outputs):
    """Creates prediction signature from given inputs and outputs.

    This function produces signatures intended for use with the TensorFlow
    Serving Predict API (tensorflow_serving/apis/prediction_service.proto).
    This API imposes no constraints on the input and output types.

    Args:
      inputs: dict of string to `Tensor`.
      outputs: dict of string to `Tensor`.

    Returns:
      A prediction-flavored signature_def.

    Raises:
      ValueError: If inputs or outputs is `None` or empty.
    """
    # `not x` is already true for None, so a single falsiness test covers
    # both the "missing" and the "empty dict" cases.
    if not inputs:
        raise ValueError('Prediction inputs cannot be None or empty.')
    if not outputs:
        raise ValueError('Prediction outputs cannot be None or empty.')

    # Convert every tensor to its TensorInfo, keeping the caller's keys.
    signature_inputs = {key: utils.build_tensor_info(tensor)
                        for key, tensor in inputs.items()}
    signature_outputs = {key: utils.build_tensor_info(tensor)
                         for key, tensor in outputs.items()}

    return build_signature_def(
        signature_inputs, signature_outputs,
        signature_constants.PREDICT_METHOD_NAME)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:32,代码来源:signature_def_utils_impl.py
示例2: setUp
def setUp(self):
    """Write two test SavedModels: a plain one and one with a main_op.

    Both are exported from the same session and graph, which holds the
    variables x=5 and y=11 plus their sum, under the "foo_graph" tag with
    the "foo" and "bar" signatures.
    """
    with session.Session(graph=ops.Graph()) as sess:
        var_x = variables.VariableV1(5, name="x")
        var_y = variables.VariableV1(11, name="y")
        total = var_x + var_y
        self.evaluate(variables.global_variables_initializer())

        # "foo" maps x -> z.
        foo_inputs = {"foo_input": utils.build_tensor_info(var_x)}
        foo_outputs = {"foo_output": utils.build_tensor_info(total)}
        foo_sig_def = signature_def_utils.build_signature_def(
            foo_inputs, foo_outputs)

        # "bar" maps (x, y) -> z.
        bar_inputs = {
            "bar_x": utils.build_tensor_info(var_x),
            "bar_y": utils.build_tensor_info(var_y),
        }
        bar_outputs = {"bar_z": utils.build_tensor_info(total)}
        bar_sig_def = signature_def_utils.build_signature_def(
            bar_inputs, bar_outputs)

        plain_builder = saved_model_builder.SavedModelBuilder(
            SIMPLE_ADD_SAVED_MODEL)
        plain_builder.add_meta_graph_and_variables(
            sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def})
        plain_builder.save()

        # Second SavedModel: same graph, plus a main_op that assigns 7 to y.
        assign_op = control_flow_ops.group(state_ops.assign(var_y, 7))
        main_op_builder = saved_model_builder.SavedModelBuilder(
            SAVED_MODEL_WITH_MAIN_OP)
        main_op_builder.add_meta_graph_and_variables(
            sess, ["foo_graph"], {"foo": foo_sig_def, "bar": bar_sig_def},
            main_op=assign_op)
        main_op_builder.save()
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:29,代码来源:loader_test.py
示例3: regression_signature_def
def regression_signature_def(examples, predictions):
    """Builds a regression SignatureDef from examples and predictions.

    Args:
      examples: `Tensor`; must be an `ops.Tensor` whose TensorInfo dtype is
        DT_STRING.
      predictions: `Tensor`; its TensorInfo dtype must be DT_FLOAT.

    Returns:
      A regression-flavored signature_def.

    Raises:
      ValueError: If examples or predictions is `None`, if examples is not
        a string `Tensor`, or if predictions is not a float `Tensor`.
    """
    # Cheap argument checks first, before any TensorInfo is built.
    if examples is None:
        raise ValueError('Regression examples cannot be None.')
    if not isinstance(examples, ops.Tensor):
        raise ValueError('Regression examples must be a string Tensor.')
    if predictions is None:
        raise ValueError('Regression predictions cannot be None.')

    examples_info = utils.build_tensor_info(examples)
    if examples_info.dtype != types_pb2.DT_STRING:
        raise ValueError('Regression examples must be a string Tensor.')

    predictions_info = utils.build_tensor_info(predictions)
    if predictions_info.dtype != types_pb2.DT_FLOAT:
        raise ValueError('Regression output must be a float Tensor.')

    return build_signature_def(
        {signature_constants.REGRESS_INPUTS: examples_info},
        {signature_constants.REGRESS_OUTPUTS: predictions_info},
        signature_constants.REGRESS_METHOD_NAME)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:35,代码来源:signature_def_utils_impl.py
示例4: testBuildSignatureDef
def testBuildSignatureDef(self):
    """Checks that build_signature_def keeps method name, inputs and outputs."""
    x = tf.placeholder(tf.float32, 1, name="x")
    inputs = {"foo-input": utils.build_tensor_info(x)}

    y = tf.placeholder(tf.float32, name="y")
    outputs = {"foo-output": utils.build_tensor_info(y)}

    signature_def = utils.build_signature_def(
        inputs, outputs, "foo-method-name")
    self.assertEqual("foo-method-name", signature_def.method_name)

    # The single input is the rank-1 float placeholder "x" of size 1.
    self.assertEqual(1, len(signature_def.inputs))
    actual_x = signature_def.inputs["foo-input"]
    self.assertEqual("x:0", actual_x.name)
    self.assertEqual(types_pb2.DT_FLOAT, actual_x.dtype)
    self.assertEqual(1, len(actual_x.tensor_shape.dim))
    self.assertEqual(1, actual_x.tensor_shape.dim[0].size)

    # The single output is the float placeholder "y" with no recorded dims.
    self.assertEqual(1, len(signature_def.outputs))
    actual_y = signature_def.outputs["foo-output"]
    self.assertEqual("y:0", actual_y.name)
    self.assertEqual(types_pb2.DT_FLOAT, actual_y.dtype)
    self.assertEqual(0, len(actual_y.tensor_shape.dim))
开发者ID:ComeOnGetMe,项目名称:tensorflow,代码行数:29,代码来源:utils_test.py
示例5: build_inputs_and_outputs
def build_inputs_and_outputs(self):
    """Builds the TensorInfo dicts for the exported signature.

    A string placeholder is created for the serialized examples. When
    `self.frame_features` is truthy, `self.build_prediction_graph` is mapped
    over the placeholder with `tf.map_fn`; otherwise it is applied to the
    whole placeholder directly.

    Returns:
      A pair `(inputs, outputs)`: `inputs` has the key "example_bytes",
      `outputs` has the keys "video_id", "class_indexes" and "predictions".
    """
    frame_level = self.frame_features
    serialized_examples = tf.placeholder(tf.string, shape=(None,))

    if frame_level:
        prediction_tensors = tf.map_fn(
            self.build_prediction_graph, serialized_examples,
            dtype=(tf.string, tf.int32, tf.float32))
    else:
        prediction_tensors = self.build_prediction_graph(serialized_examples)
    video_id_output, top_indices_output, top_predictions_output = (
        prediction_tensors)

    to_info = saved_model_utils.build_tensor_info
    inputs = {"example_bytes": to_info(serialized_examples)}
    outputs = {
        "video_id": to_info(video_id_output),
        "class_indexes": to_info(top_indices_output),
        "predictions": to_info(top_predictions_output),
    }
    return inputs, outputs
开发者ID:lvaleriu,项目名称:Youtube-8M-WILLOW,代码行数:27,代码来源:export_model.py
示例6: regression_signature_def
def regression_signature_def(examples, predictions):
    """Builds a regression SignatureDef from examples and predictions.

    Args:
      examples: `Tensor`.
      predictions: `Tensor`.

    Returns:
      A regression-flavored signature_def.

    Raises:
      ValueError: If examples or predictions is `None`.
    """
    if examples is None:
        raise ValueError('examples cannot be None for regression.')
    if predictions is None:
        raise ValueError('predictions cannot be None for regression.')

    examples_info = utils.build_tensor_info(examples)
    predictions_info = utils.build_tensor_info(predictions)

    return build_signature_def(
        {signature_constants.REGRESS_INPUTS: examples_info},
        {signature_constants.REGRESS_OUTPUTS: predictions_info},
        signature_constants.REGRESS_METHOD_NAME)
开发者ID:Hwhitetooth,项目名称:tensorflow,代码行数:28,代码来源:signature_def_utils.py
示例7: predict_signature_def
def predict_signature_def(inputs, outputs):
    """Creates prediction signature from given inputs and outputs.

    Args:
      inputs: dict of string to `Tensor`.
      outputs: dict of string to `Tensor`.

    Returns:
      A prediction-flavored signature_def.

    Raises:
      ValueError: If inputs or outputs is `None` or empty.
    """
    # A falsiness test rejects both None and an empty dict.
    if not inputs:
        raise ValueError('Prediction inputs cannot be None or empty.')
    if not outputs:
        raise ValueError('Prediction outputs cannot be None or empty.')

    signature_inputs = {key: utils.build_tensor_info(tensor)
                        for key, tensor in inputs.items()}
    signature_outputs = {key: utils.build_tensor_info(tensor)
                         for key, tensor in outputs.items()}

    return build_signature_def(
        signature_inputs, signature_outputs,
        signature_constants.PREDICT_METHOD_NAME)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:28,代码来源:signature_def_utils_impl.py
示例8: export
def export(model_version, model_dir, sess, x, y_op):
    """Export a model in a form that TensorFlow Serving can load.

    SavedModel (tensorflow.python.saved_model) provides a language-neutral
    format for saving and restoring trained TensorFlow models. It uses method
    signatures to define the inputs and outputs of a Graph, so that
    higher-level systems can more easily produce, consume or transform
    TensorFlow models.

    The SavedModelBuilder class provides methods for saving Graphs, Variables
    and Assets. Saved Graphs must be tagged with their intended use. Here the
    model is meant for serving rather than training, so the tag predefined by
    SavedModel, tag_constants.SERVING, is used.

    To make signatures easy to build, SavedModel provides the
    signature_def_utils API. signature_def_utils.build_signature_def() is
    used to build the predict signature, which contains at least:

      * inputs = {'x': tensor_info_x}   the input tensor info
      * outputs = {'y': tensor_info_y}  the output tensor info
      * method_name = signature_constants.PREDICT_METHOD_NAME

    method_name names the method; its value should be one of
    tensorflow/serving/predict, tensorflow/serving/classify and
    tensorflow/serving/regress. Builder tags state how the Meta Graph is to
    be loaded, and only the two kinds serve and train are accepted.

    Args:
      model_version: Version number; must be positive. Its string form is
        the last component of the export path.
      model_dir: Base directory; the model is written to
        `<model_dir>/<model_version>`.
      sess: Session handed to `add_meta_graph_and_variables`.
      x: Tensor registered as the signature input 'x'.
      y_op: Tensor registered as the signature output 'y'.

    Raises:
      SystemExit: If `model_version` is not positive.
    """
    if model_version <= 0:
        logging.warning('Please specify a positive value for version number.')
        # Exit with a non-zero status: a bare sys.exit() would report success
        # to the calling process even though nothing was exported.
        sys.exit(1)

    # Make sure the parent directory of model_dir exists.
    path = os.path.dirname(os.path.abspath(model_dir))
    if not os.path.isdir(path):
        logging.warning('Path (%s) not exists, making directories...', path)
        os.makedirs(path)

    export_path = os.path.join(
        compat.as_bytes(model_dir),
        compat.as_bytes(str(model_version)))
    # A previous export of the same version is deleted before the builder
    # is created for that path.
    if os.path.isdir(export_path):
        logging.warning('Path (%s) exists, removing directories...', export_path)
        shutil.rmtree(export_path)

    builder = saved_model_builder.SavedModelBuilder(export_path)
    tensor_info_x = utils.build_tensor_info(x)
    tensor_info_y = utils.build_tensor_info(y_op)
    prediction_signature = signature_def_utils.build_signature_def(
        inputs={'x': tensor_info_x},
        outputs={'y': tensor_info_y},
        # signature_constants.CLASSIFY_METHOD_NAME = "tensorflow/serving/classify"
        # signature_constants.PREDICT_METHOD_NAME = "tensorflow/serving/predict"
        # signature_constants.REGRESS_METHOD_NAME = "tensorflow/serving/regress"
        # If method_name is missing, the following error is reported:
        # grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.INTERNAL, details="Expected prediction signature method_name to be one of {tensorflow/serving/predict, tensorflow/serving/classify, tensorflow/serving/regress}. Was: ")
        method_name=signature_constants.PREDICT_METHOD_NAME)

    builder.add_meta_graph_and_variables(
        sess,
        # tag_constants.SERVING = "serve"
        # tag_constants.TRAINING = "train"
        # With only the train tag, TensorFlow Serving fails at load time:
        # E tensorflow_serving/core/aspired_versions_manager.cc:351] Servable {name: default version: 2} cannot be loaded: Not found: Could not find meta graph def matching supplied tags.
        [tag_constants.SERVING],
        signature_def_map={
            'predict_text': prediction_signature,
            # If this key is missing, the following error is reported:
            # grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.FAILED_PRECONDITION, details="Default serving signature key not found.")
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature
        })
    builder.save()
开发者ID:lacatc,项目名称:text-antispam,代码行数:56,代码来源:rnn_classifier.py
示例9: _make_signature
def _make_signature(inputs, outputs, name=None):
    """Builds a SignatureDef from dicts of tensors.

    Args:
      inputs: dict mapping input names to tensors.
      outputs: dict mapping output names to tensors.
      name: Passed through as the third argument of `build_signature_def`.

    Returns:
      The result of `signature_def_utils_impl.build_signature_def`.
    """
    input_info = {}
    for input_name, tensor in inputs.items():
        input_info[input_name] = utils.build_tensor_info(tensor)

    output_info = {}
    for output_name, tensor in outputs.items():
        output_info[output_name] = utils.build_tensor_info(tensor)

    return signature_def_utils_impl.build_signature_def(
        input_info, output_info, name)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:11,代码来源:signature_def_utils_test.py
示例10: testGetSignatureDefByKey
def testGetSignatureDefByKey(self):
    """Checks SignatureDef lookup by key, for missing and existing keys."""
    x = array_ops.placeholder(dtypes.float32, 1, name="x")
    x_tensor_info = utils.build_tensor_info(x)

    y = array_ops.placeholder(dtypes.float32, name="y")
    y_tensor_info = utils.build_tensor_info(y)

    foo_signature_def = signature_def_utils.build_signature_def({
        "foo-input": x_tensor_info
    }, {"foo-output": y_tensor_info}, "foo-method-name")
    bar_signature_def = signature_def_utils.build_signature_def({
        "bar-input": x_tensor_info
    }, {"bar-output": y_tensor_info}, "bar-method-name")
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    self._add_to_signature_def_map(
        meta_graph_def, {"foo": foo_signature_def,
                         "bar": bar_signature_def})

    # Look up a key that does not exist in the SignatureDefMap.
    missing_key = "missing-key"
    with self.assertRaisesRegexp(
        ValueError,
        "No SignatureDef with key '%s' found in MetaGraphDef" % missing_key):
        signature_def_contrib_utils.get_signature_def_by_key(
            meta_graph_def, missing_key)

    # Look up the key, `foo` which exists in the SignatureDefMap.
    foo_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
        meta_graph_def, "foo")
    # assertEqual, not assertTrue: assertTrue(a, b) treats `b` as the failure
    # message and passes for any non-empty `a`, so it verified nothing.
    self.assertEqual("foo-method-name", foo_signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(foo_signature_def.inputs))
    self._check_tensor_info(foo_signature_def.inputs, "foo-input", "x:0")

    # Check outputs in signature def.
    self.assertEqual(1, len(foo_signature_def.outputs))
    self._check_tensor_info(foo_signature_def.outputs, "foo-output", "y:0")

    # Look up the key, `bar` which exists in the SignatureDefMap.
    bar_signature_def = signature_def_contrib_utils.get_signature_def_by_key(
        meta_graph_def, "bar")
    self.assertEqual("bar-method-name", bar_signature_def.method_name)

    # Check inputs in signature def.
    self.assertEqual(1, len(bar_signature_def.inputs))
    self._check_tensor_info(bar_signature_def.inputs, "bar-input", "x:0")

    # Check outputs in signature def.
    self.assertEqual(1, len(bar_signature_def.outputs))
    self._check_tensor_info(bar_signature_def.outputs, "bar-output", "y:0")
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:51,代码来源:signature_def_utils_test.py
示例11: _WriteInputSavedModel
def _WriteInputSavedModel(self, input_saved_model_dir):
    """Writes the graph from `_GetGraph` as a SavedModel used as test input.

    Args:
      input_saved_model_dir: Directory handed to `SavedModelBuilder`.
    """
    graph, var, inp, out = self._GetGraph()

    input_infos = {"myinput": utils.build_tensor_info(inp)}
    output_infos = {"myoutput": utils.build_tensor_info(out)}
    predict_signature = signature_def_utils.build_signature_def(
        inputs=input_infos,
        outputs=output_infos,
        method_name=signature_constants.PREDICT_METHOD_NAME)

    saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
    with self.session(graph=graph, config=self._GetConfigProto()) as sess:
        # The variable is initialized before the meta graph is added.
        sess.run(var.initializer)
        saved_model_builder.add_meta_graph_and_variables(
            sess, [tag_constants.SERVING],
            signature_def_map={"mypredict": predict_signature})
    saved_model_builder.save()
开发者ID:ThunderQi,项目名称:tensorflow,代码行数:14,代码来源:trt_convert_test.py
示例12: classification_signature_def
def classification_signature_def(examples, classes, scores):
    """Builds a classification SignatureDef from examples and predictions.

    This function produces signatures intended for use with the TensorFlow
    Serving Classify API (tensorflow_serving/apis/prediction_service.proto),
    and so constrains the input and output types to those allowed by
    TensorFlow Serving.

    Args:
      examples: A string `Tensor`, expected to accept serialized tf.Examples.
      classes: A string `Tensor`. Note that the ClassificationResponse message
        requires that class labels are strings, not integers or anything else.
      scores: a float `Tensor`.

    Returns:
      A classification-flavored signature_def.

    Raises:
      ValueError: If examples is `None` or not a string `Tensor`, if classes
        and scores are both `None`, if classes is given but is not a string
        `Tensor`, or if scores is given but is not a float `Tensor`.
    """
    if examples is None:
        raise ValueError('Classification examples cannot be None.')
    if not isinstance(examples, ops.Tensor):
        raise ValueError('Classification examples must be a string Tensor.')
    if classes is None and scores is None:
        raise ValueError('Classification classes and scores cannot both be None.')

    def _typed_info(tensor, expected_dtype, error_message):
        # Build the TensorInfo and reject it when its dtype is not expected.
        info = utils.build_tensor_info(tensor)
        if info.dtype != expected_dtype:
            raise ValueError(error_message)
        return info

    examples_info = _typed_info(
        examples, types_pb2.DT_STRING,
        'Classification examples must be a string Tensor.')
    signature_inputs = {signature_constants.CLASSIFY_INPUTS: examples_info}

    # Each output is optional, but at least one of them is present here.
    signature_outputs = {}
    if classes is not None:
        classes_info = _typed_info(
            classes, types_pb2.DT_STRING,
            'Classification classes must be a string Tensor.')
        signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
            classes_info)
    if scores is not None:
        scores_info = _typed_info(
            scores, types_pb2.DT_FLOAT,
            'Classification scores must be a float Tensor.')
        signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
            scores_info)

    return build_signature_def(
        signature_inputs, signature_outputs,
        signature_constants.CLASSIFY_METHOD_NAME)
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:50,代码来源:signature_def_utils_impl.py
示例13: build_graph_helper
def build_graph_helper():
    """Builds a small graph computing z = x + y, with two signatures.

    Returns:
      A triple `(graph, signature_def_map, y)`: the new graph, a dict with
      the keys "foo" and "bar", and the variable `y`.
    """
    graph = ops.Graph()
    with graph.as_default():
        x = variables.VariableV1(5, name="x")
        y = variables.VariableV1(11, name="y")
        z = x + y

        # "foo" maps x -> z.
        foo_inputs = {"foo_input": utils.build_tensor_info(x)}
        foo_outputs = {"foo_output": utils.build_tensor_info(z)}
        foo_sig_def = signature_def_utils.build_signature_def(
            foo_inputs, foo_outputs)

        # "bar" maps (x, y) -> z.
        bar_inputs = {
            "bar_x": utils.build_tensor_info(x),
            "bar_y": utils.build_tensor_info(y),
        }
        bar_outputs = {"bar_z": utils.build_tensor_info(z)}
        bar_sig_def = signature_def_utils.build_signature_def(
            bar_inputs, bar_outputs)

    signature_def_map = {"foo": foo_sig_def, "bar": bar_sig_def}
    return graph, signature_def_map, y
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:15,代码来源:loader_test.py
示例14: testBuildTensorInfoDense
def testBuildTensorInfoDense(self):
    """Checks name, dtype and shape recorded for a dense placeholder."""
    placeholder = array_ops.placeholder(dtypes.float32, 1, name="x")
    info = utils.build_tensor_info(placeholder)

    self.assertEqual("x:0", info.name)
    self.assertEqual(types_pb2.DT_FLOAT, info.dtype)
    # The shape argument 1 is recorded as one dimension of size 1.
    dims = info.tensor_shape.dim
    self.assertEqual(1, len(dims))
    self.assertEqual(1, dims[0].size)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:7,代码来源:utils_test.py
示例15: testGetTensorFromInfoSparse
def testGetTensorFromInfoSparse(self):
    """Checks that a sparse placeholder round-trips through TensorInfo."""
    expected = array_ops.sparse_placeholder(dtypes.float32, name="x")
    info = utils.build_tensor_info(expected)
    actual = utils.get_tensor_from_tensor_info(info)

    self.assertIsInstance(actual, sparse_tensor.SparseTensor)
    # All three component tensors must come back under the same names.
    self.assertEqual(expected.values.name, actual.values.name)
    self.assertEqual(expected.indices.name, actual.indices.name)
    self.assertEqual(expected.dense_shape.name, actual.dense_shape.name)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:8,代码来源:utils_test.py
示例16: classification_signature_def
def classification_signature_def(examples, classes, scores):
    """Builds a classification SignatureDef from examples and predictions.

    Args:
      examples: `Tensor`; must be an `ops.Tensor` whose TensorInfo dtype is
        DT_STRING.
      classes: `Tensor` or `None`; when given, its dtype must be DT_STRING.
      scores: `Tensor` or `None`; when given, its dtype must be DT_FLOAT.

    Returns:
      A classification-flavored signature_def.

    Raises:
      ValueError: If examples is `None` or not a string `Tensor`, if classes
        and scores are both `None`, or if a given output has the wrong dtype.
    """
    if examples is None:
        raise ValueError('Classification examples cannot be None.')
    if not isinstance(examples, ops.Tensor):
        raise ValueError('Classification examples must be a string Tensor.')
    if classes is None and scores is None:
        raise ValueError('Classification classes and scores cannot both be None.')

    examples_info = utils.build_tensor_info(examples)
    if examples_info.dtype != types_pb2.DT_STRING:
        raise ValueError('Classification examples must be a string Tensor.')

    # Outputs are collected one at a time; either may be absent.
    outputs_by_key = {}

    if classes is not None:
        classes_info = utils.build_tensor_info(classes)
        if classes_info.dtype != types_pb2.DT_STRING:
            raise ValueError('Classification classes must be a string Tensor.')
        outputs_by_key[signature_constants.CLASSIFY_OUTPUT_CLASSES] = classes_info

    if scores is not None:
        scores_info = utils.build_tensor_info(scores)
        if scores_info.dtype != types_pb2.DT_FLOAT:
            raise ValueError('Classification scores must be a float Tensor.')
        outputs_by_key[signature_constants.CLASSIFY_OUTPUT_SCORES] = scores_info

    return build_signature_def(
        {signature_constants.CLASSIFY_INPUTS: examples_info},
        outputs_by_key,
        signature_constants.CLASSIFY_METHOD_NAME)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:45,代码来源:signature_def_utils_impl.py
示例17: testGetTensorFromInfoRaisesErrors
def testGetTensorFromInfoRaisesErrors(self):
    """Checks the errors raised for unknown and malformed TensorInfos."""
    placeholder = array_ops.placeholder(dtypes.float32, 1, name="x")
    info = utils.build_tensor_info(placeholder)

    # Nonexistent name.
    info.name = "blah:0"
    with self.assertRaises(KeyError):
        utils.get_tensor_from_tensor_info(info)

    # Malformed (missing encoding).
    info.ClearField("name")
    with self.assertRaises(ValueError):
        utils.get_tensor_from_tensor_info(info)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:9,代码来源:utils_test.py
示例18: testGetTensorFromInfoInOtherGraph
def testGetTensorFromInfoInOtherGraph(self):
    """Checks lookup in an explicitly passed graph, not the default one."""
    with ops.Graph().as_default() as expected_graph:
        expected = array_ops.placeholder(dtypes.float32, 1, name="right")
        info = utils.build_tensor_info(expected)

    # Some other graph, with a differently named placeholder.
    with ops.Graph().as_default():
        array_ops.placeholder(dtypes.float32, 1, name="other")

    actual = utils.get_tensor_from_tensor_info(info, graph=expected_graph)
    self.assertIsInstance(actual, ops.Tensor)
    self.assertIs(actual.graph, expected_graph)
    self.assertEqual(expected.name, actual.name)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:11,代码来源:utils_test.py
示例19: test_load_saved_model_with_no_variables
def test_load_saved_model_with_no_variables(self, builder_cls):
    """Test that SavedModel runs saver when there appear to be no variables.

    When no variables are detected, this may mean that the variables were
    saved to different collections, or the collections weren't saved to the
    SavedModel. If the SavedModel MetaGraphDef contains a saver, it should
    still run in either of these cases.

    Args:
      builder_cls: SavedModelBuilder or _SavedModelBuilder class
    """
    path = _get_export_dir("no_variable_saved_model")
    with session.Session(graph=ops.Graph()) as sess:
        # Both variables are placed in a custom collection, so they are not
        # reported as saveable objects.
        x = variables.VariableV1(
            5, name="x", collections=["not_global_variable"])
        y = variables.VariableV1(
            11, name="y", collections=["not_global_variable"])
        self.assertFalse(variables._all_saveable_objects())
        z = x + y
        self.evaluate(variables.variables_initializer([x, y]))

        foo_sig_def = signature_def_utils.build_signature_def(
            {"foo_input": utils.build_tensor_info(x)},
            {"foo_output": utils.build_tensor_info(z)})

        # Instantiate the builder class under test. The original hard-coded
        # saved_model_builder.SavedModelBuilder and ignored `builder_cls`.
        builder = builder_cls(path)
        builder.add_meta_graph_and_variables(
            sess, ["foo_graph"], {"foo": foo_sig_def},
            saver=tf_saver.Saver([x, y]))
        builder.save()

    loader = loader_impl.SavedModelLoader(path)
    with self.session(graph=ops.Graph()) as sess:
        saver, _ = loader.load_graph(sess.graph, ["foo_graph"])
        self.assertFalse(variables._all_saveable_objects())
        self.assertIsNotNone(saver)

    with self.session(graph=ops.Graph()) as sess:
        loader.load(sess, ["foo_graph"])
        self.assertEqual(5, sess.graph.get_tensor_by_name("x:0").eval())
        self.assertEqual(11, sess.graph.get_tensor_by_name("y:0").eval())
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:41,代码来源:loader_test.py
示例20: _supervised_signature_def
def _supervised_signature_def(
method_name, inputs, loss=None, predictions=None,
metrics=None):
"""Creates a signature for training and eval data.
This function produces signatures that describe the inputs and outputs
of a supervised process, such as training or evaluation, that
results in loss, metrics, and the like. Note that this function only requires
inputs to be not None.
Args:
method_name: Method name of the SignatureDef as a string.
inputs: dict of string to `Tensor`.
loss: dict of string to `Tensor` representing computed loss.
predictions: dict of string to `Tensor` representing the output predictions.
metrics: dict of string to `Tensor` representing metric ops.
Returns:
A train- or eval-flavored signature_def.
Raises:
ValueError: If inputs or outputs is `None`.
"""
if inputs is None or not inputs:
raise ValueError('{} inputs cannot be None or empty.'.format(method_name))
signature_inputs = {key: utils.build_tensor_info(tensor)
for key, tensor in inputs.items()}
signature_outputs = {}
for output_set in (loss, predictions, metrics):
if output_set is not None:
sig_out = {key: utils.build_tensor_info(tensor)
for key, tensor in output_set.items()}
signature_outputs.update(sig_out)
signature_def = build_signature_def(
signature_inputs, signature_outputs, method_name)
return signature_def
开发者ID:AnishShah,项目名称:tensorflow,代码行数:40,代码来源:signature_def_utils_impl.py
注：本文中的tensorflow.python.saved_model.utils.build_tensor_info函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。
请发表评论