本文整理汇总了 Python 中 tensorflow.python.util.compat.as_str 函数的典型用法代码示例。如果您正苦于以下问题：Python as_str 函数的具体用法？Python as_str 怎么用？Python as_str 使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了 as_str 函数的 20 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。
示例1: meta_graph_transform
def meta_graph_transform(
    base_meta_graph_def, input_names, output_names, transforms, tags,
    checkpoint_path=None):
  """Apply the Graph Transform tool to a MetaGraphDef.

  The graph is rewritten by `_do_transforms`. The meta info is copied from the
  base with its tags replaced by `tags`. The saver, collections and signature
  defs are copied over with references to ops removed by the transforms pruned.

  Args:
    base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: A list of strings naming the graph transforms to be applied in
      order. These transform names are exactly those supported by the Graph
      Transform Tool, with the addition of the 'freeze_graph' transform.
    tags: A list of tags with which to annotate the transformed MetaGraphDef.
    checkpoint_path: A path to a checkpoint to restore during freezing,
      if needed (default None).

  Returns:
    A new transformed MetaGraphDef protocol buffer.
  """
  result = _meta_graph_pb2.MetaGraphDef()
  base_graph_def = base_meta_graph_def.graph_def

  # NOTE(review): presumably the ops (e.g. initializers) the transforms must
  # keep even when unreachable from `output_names` -- confirm in
  # _find_all_mandatory_retain_ops.
  mandatory_ops = _find_all_mandatory_retain_ops(base_meta_graph_def)
  result.graph_def.CopyFrom(_do_transforms(
      base_graph_def, input_names, output_names, mandatory_ops, transforms,
      base_meta_graph_def.saver_def, checkpoint_path))

  # Start from the base meta info, but replace its tags entirely.
  meta_info = result.meta_info_def
  meta_info.CopyFrom(base_meta_graph_def.meta_info_def)
  meta_info.ClearField('tags')
  meta_info.tags.extend(tags)

  # Names of ops that existed before the transforms but not after.
  original_names = {compat.as_str(node.name) for node in base_graph_def.node}
  surviving_names = {compat.as_str(node.name)
                     for node in result.graph_def.node}
  removed_op_names = original_names - surviving_names

  # Copy saver, collections and signature defs, excluding any pruned nodes.
  _add_pruned_saver(base_meta_graph_def, result, removed_op_names)
  for collection_name in base_meta_graph_def.collection_def:
    _add_pruned_collection(base_meta_graph_def, result, collection_name,
                           removed_op_names)
  for signature_name in base_meta_graph_def.signature_def:
    _add_pruned_signature(base_meta_graph_def, result, signature_name,
                          removed_op_names)
  return result
开发者ID:Dr4KK,项目名称:tensorflow,代码行数:60,代码来源:meta_graph_transform.py
示例2: _PopulateTFImportGraphDefOptions
def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements):
  """Populates the TF_ImportGraphDefOptions `options`.

  Sets the name prefix and enables name and prefix uniquification. Registers
  every `input_map` entry: keys starting with '^' remap a control dependency,
  all other keys remap a tensor output. Registers every entry of
  `return_elements`: names containing ':' are requested as outputs, all others
  as operations.

  Args:
    options: The TF_ImportGraphDefOptions handle to populate.
    prefix: Name prefix passed to TF_ImportGraphDefOptionsSetPrefix.
    input_map: Mapping from source names (str or bytes) to destination
      tensors/ops that provide `_as_tf_output()`.
    return_elements: Iterable of names to return, or None.
  """
  c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
  c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
  c_api.TF_ImportGraphDefOptionsSetUniquifyPrefix(options, True)

  for raw_src, dst in input_map.items():
    src = compat.as_str(raw_src)
    if not src.startswith('^'):
      tensor_name, output_index = _ParseTensorName(src)
      c_api.TF_ImportGraphDefOptionsAddInputMapping(
          options, compat.as_str(tensor_name), output_index,
          dst._as_tf_output())  # pylint: disable=protected-access
      continue
    # Control-dependency remapping: the name after the leading '^'.
    c_api.TF_ImportGraphDefOptionsRemapControlDependency(
        options, compat.as_bytes(src[1:]),
        dst._as_tf_output().oper)  # pylint: disable=protected-access

  for element in return_elements or []:
    if ':' not in element:
      c_api.TF_ImportGraphDefOptionsAddReturnOperation(
          options, compat.as_str(element))
      continue
    op_name, output_index = _ParseTensorName(element)
    c_api.TF_ImportGraphDefOptionsAddReturnOutput(
        options, compat.as_str(op_name), output_index)
开发者ID:andrewharp,项目名称:tensorflow,代码行数:28,代码来源:importer.py
示例3: _init_from_proto
def _init_from_proto(self, hparam_def):
  """Creates a new HParams from `HParamDef` protocol buffer.

  Each entry of `hparam_def.hparam` is added with `self.add_hparam`. int64
  values become `int` and bytes values become `str` (via `compat.as_str`,
  UTF-8 assumed), so the resulting types match under Python 2 and Python 3.
  Other kinds are stored unchanged. A oneof kind ending in '_value' is a
  single value; any other kind holds a list in its `.value` field.

  Args:
    hparam_def: `HParamDef` protocol buffer.
  """
  assert isinstance(hparam_def, hparam_pb2.HParamDef)
  for name, value in hparam_def.hparam.items():
    kind = value.WhichOneof('kind')
    is_single = kind.endswith('_value')
    if kind.startswith('int64'):
      convert = int
    elif kind.startswith('bytes'):
      convert = compat.as_str
    else:
      convert = lambda v: v  # Stored as-is.
    field = getattr(value, kind)
    if is_single:
      self.add_hparam(name, convert(field))
    else:
      self.add_hparam(name, [convert(v) for v in field.value])
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:34,代码来源:hparam.py
示例4: _ProcessReturnElementsParam
def _ProcessReturnElementsParam(return_elements):
"""Type-checks and possibly canonicalizes `return_elements`."""
if return_elements is None: return None
if not all(isinstance(x, compat.bytes_or_text_types)
for x in return_elements):
raise TypeError('return_elements must be a list of strings.')
return tuple(compat.as_str(x) for x in return_elements)
开发者ID:andrewharp,项目名称:tensorflow,代码行数:7,代码来源:importer.py
示例5: _clean_save_and_restore
def _clean_save_and_restore(graph_def, op, removed_op_names):
  """Clean the specified save and restore op.

  Updates the dtypes attribute of the save / restore op and the associated
  name and shape tensors to remove entries for variables that have been
  removed. The "<op>/tensor_names" and "<op>/shape_and_slices" constant ops
  are located in `graph_def`. Their string values, tensor shapes and
  `_output_shapes` sizes are rewritten in place to keep only the surviving
  entries.

  Args:
    graph_def: A GraphDef proto to be transformed.
    op: The save or restore op to update.
    removed_op_names: List of op names that have been removed.
  """
  name_op = _find_op(graph_def, op.name + '/tensor_names')
  shape_op = _find_op(graph_def, op.name + '/shape_and_slices')
  names_tensor = name_op.attr['value'].tensor
  shapes_tensor = shape_op.attr['value'].tensor

  # Positions whose variable name survived the transform, in original order.
  kept = [position
          for position, tensor_name in enumerate(names_tensor.string_val)
          if not _is_removed(compat.as_str(tensor_name), removed_op_names)]
  names = [names_tensor.string_val[position] for position in kept]
  shapes = [shapes_tensor.string_val[position] for position in kept]
  dtypes = [op.attr['dtypes'].list.type[position] for position in kept]

  names_tensor.string_val[:] = names
  names_tensor.tensor_shape.dim[0].size = len(names)
  shapes_tensor.string_val[:] = shapes
  shapes_tensor.tensor_shape.dim[0].size = len(shapes)
  op.attr['dtypes'].list.type[:] = dtypes
  name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(names)
  shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(shapes)
开发者ID:Crazyonxh,项目名称:tensorflow,代码行数:33,代码来源:meta_graph_transform.py
示例6: encode_arg
def encode_arg(arg, path):
  """A representation for this argument, for converting into signatures.

  Args:
    arg: The argument value.
    path: Sequence of path components locating the argument.

  Returns:
    A `TensorSpec` for tensors. Its name is the op's "_user_specified_name"
    attr when that is set and differs from `path[0]`; otherwise it is the
    "/"-joined path. Ints, floats, bools, None, DTypes and TensorSpecs are
    returned unchanged. Anything else becomes `UnknownArgument()`.
  """
  if not isinstance(arg, ops.Tensor):
    passthrough_types = (
        int,
        float,
        bool,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
    )
    return arg if isinstance(arg, passthrough_types) else UnknownArgument()

  try:
    user_specified_name = compat.as_str(
        arg.op.get_attr("_user_specified_name"))
  except ValueError:
    # A ValueError here is treated as "no user-specified name".
    user_specified_name = None

  # The user has explicitly named the argument differently than the name of
  # the function argument: prefer the explicit name.
  if path and user_specified_name and user_specified_name != path[0]:
    spec_name = user_specified_name
  else:
    spec_name = "/".join(str(component) for component in path)
  return tensor_spec.TensorSpec(arg.shape, arg.dtype, spec_name)
开发者ID:kylin9872,项目名称:tensorflow,代码行数:27,代码来源:func_graph.py
示例7: assert_equal_graph_def
def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering
  of nodes, attrs, and control inputs. Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent. The
  comparison is delegated to `pywrap_tensorflow.EqualGraphDefWrapper` on the
  serialized protos.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: boolean determining whether to ignore randomized attribute
      values that appear in V2 checkpoints. When True, both arguments are
      passed to `_strip_checkpoint_v2_randomized` before comparing, which
      presumably modifies them in place -- confirm.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for actual, got %s" %
                    type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError("Expected tf.GraphDef for expected, got %s" %
                    type(expected).__name__)

  if checkpoint_v2:
    for graph_def in (actual, expected):
      _strip_checkpoint_v2_randomized(graph_def)

  actual_bytes = actual.SerializeToString()
  expected_bytes = expected.SerializeToString()
  # A non-empty result describes the first difference found.
  diff = pywrap_tensorflow.EqualGraphDefWrapper(actual_bytes, expected_bytes)
  if diff:
    raise AssertionError(compat.as_str(diff))
开发者ID:LUTAN,项目名称:tensorflow,代码行数:32,代码来源:test_util.py
示例8: _create_new_tf_function
def _create_new_tf_function(func_graph):
  """Converts func_graph to a TF_Function and adds it to the current graph.

  The function is added to `func_graph._outer_graph`, with
  `func_graph._functions` attached as its sub-functions.

  Args:
    func_graph: function._FuncGraph

  Returns:
    The name of the new TF_Function.
  """
  # pylint: disable=protected-access
  c_func = c_api.TF_GraphToFunction_wrapper(
      func_graph._c_graph,
      compat.as_str(func_graph.name),
      False,  # append_hash_to_fn_name
      None,  # opers
      [tensor._as_tf_output() for tensor in func_graph.inputs],
      [tensor._as_tf_output() for tensor in func_graph.outputs],
      [],
      None,  # opts
      None)  # description
  # Held until this function returns. NOTE(review): presumably the wrapper
  # owns c_func and releases it when collected -- confirm in c_api_util.
  _scoped_c_func = c_api_util.ScopedTFFunction(c_func)

  # TODO(b/109833212): this sucks, we're serializing the TF_Function*,
  # deserializing it into a Python FunctionDef, then reserializing it to create
  # a new TF_Function that we add to the graph.
  function_def = _function.function_def_from_tf_function(c_func)
  defined = _function._from_definition(function_def)
  defined._sub_functions = func_graph._functions
  defined.add_to_graph(func_graph._outer_graph)
  # pylint: enable=protected-access
  return func_graph.name
开发者ID:godyd2702,项目名称:tensorflow,代码行数:30,代码来源:cond_v2_impl.py
示例9: _node_def
def _node_def(from_node_def, export_scope, unbound_inputs, clear_devices=False):
  """Create a `NodeDef` proto with export_scope stripped.

  Works on a deep copy of `from_node_def`:
  * Inputs outside `export_scope` get `_UNBOUND_INPUT_PREFIX` inserted after
    any leading "^" and are appended to `unbound_inputs`. All other inputs
    have the scope stripped.
  * The node name has the scope stripped.
  * "_class" entries outside the scope are dropped; the rest are stripped.
  * For Enter/RefEnter ops, an in-scope "frame_name" is stripped. An
    out-of-scope one keeps the deep-copied value.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist. It is
      appended to in place.
    clear_devices: Boolean which controls whether to clear device information
      from node_def. Default false.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)

  for idx, input_name in enumerate(node_def.input):
    outside_scope = (export_scope and
                     not input_name.lstrip("^").startswith(export_scope))
    if outside_scope:
      # Prefix the unbound name so such inputs are easily identifiable.
      node_def.input[idx] = re.sub(r"([\^]|^)(.*)",
                                   r"\1" + _UNBOUND_INPUT_PREFIX + r"\2",
                                   compat.as_str(input_name))
      unbound_inputs.append(node_def.input[idx])
    else:
      node_def.input[idx] = ops.strip_name_scope(input_name, export_scope)

  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))

  for attr_name, attr_value in six.iteritems(from_node_def.attr):
    if attr_name == "_class":
      # Colocation entries look like "<prefix>@<op name>". Keep only those
      # whose op name lies inside the export scope.
      scoped_entries = [
          compat.as_bytes(ops.strip_name_scope(entry, export_scope))
          for entry in attr_value.list.s
          if not export_scope or
          compat.as_str(entry).split("@")[1].startswith(export_scope)
      ]
      node_def.attr[attr_name].CopyFrom(attr_value_pb2.AttrValue(
          list=attr_value_pb2.AttrValue.ListValue(s=scoped_entries)))
    elif attr_name == "frame_name" and node_def.op in ("Enter", "RefEnter"):
      if (not export_scope or
          compat.as_str(attr_value.s).startswith(export_scope)):
        stripped = compat.as_bytes(
            ops.strip_name_scope(attr_value.s, export_scope))
        node_def.attr[attr_name].CopyFrom(
            attr_value_pb2.AttrValue(s=stripped))
    else:
      node_def.attr[attr_name].CopyFrom(attr_value)

  if clear_devices:
    node_def.device = ""
  return node_def
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:46,代码来源:meta_graph.py
示例10: __init__
def __init__(self, name, graph, operations, inputs, outputs, attrs):
  """Initializes an eager defined function.

  Builds a TF_Function from `operations` of `graph` and applies `attrs` to
  it. The FunctionDef is read back to populate `definition`, `name` and
  `signature`. When executing eagerly, the function is also passed to
  `_register`.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    operations: list of Operation; the subset of operations in the graph
      which will be in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs to the function
    attrs: dict mapping names of attributes to their AttrValue values
  """
  # pylint: disable=protected-access
  fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
      graph._c_graph,
      compat.as_str(name),
      False,
      [operation._c_op for operation in operations],
      [tensor._as_tf_output() for tensor in inputs],
      [tensor._as_tf_output() for tensor in outputs],
      [],
      None,
      compat.as_str(""))
  # pylint: enable=protected-access

  # A distinct loop variable avoids rebinding the `name` parameter.
  for attr_name, attr_value in attrs.items():
    serialized_attr = attr_value.SerializeToString()
    # TODO(iga): this creates and deletes a new TF_Status for every attr.
    # It might be worth creating a convenient way to re-use status.
    pywrap_tensorflow.TF_FunctionSetAttrValueProto(
        fn, compat.as_str(attr_name), serialized_attr)

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature, but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as buffer_:
    pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_)
    proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
  function_def = function_pb2.FunctionDef()
  function_def.ParseFromString(compat.as_bytes(proto_data))

  if context.executing_eagerly():
    _register(fn)

  signature = function_def.signature
  self.definition = function_def
  self.name = signature.name
  self.signature = signature
  self.grad_func_name = None
  self.python_grad_func = None
  self._c_func = c_api_util.ScopedTFFunction(fn)
  self._grad_func = None
开发者ID:KiaraStarlab,项目名称:tensorflow,代码行数:46,代码来源:function.py
示例11: save
def save(self, sess, save_path, global_step=None, latest_filename=None):
  """Saves variables.

  This method runs the ops added by the constructor for saving variables.
  It requires a session in which the graph was launched. The variables to
  save must also have been initialized.

  The method returns the path of the newly created checkpoint file. This
  path can be passed directly to a call to `restore()`.

  Args:
    sess: A Session to use to save the variables.
    save_path: String. Path to the checkpoint filename. If the saver is
      `sharded`, this is the prefix of the sharded checkpoint filename.
    global_step: If provided the global step number is appended to
      `save_path` to create the checkpoint filename. The optional argument
      can be a `Tensor`, a `Tensor` name or an integer.
    latest_filename: Optional name for the protocol buffer file that will
      contains the list of most recent checkpoint filenames. That file,
      kept in the same directory as the checkpoint files, is automatically
      managed by the saver to keep track of recent checkpoints. Defaults to
      'checkpoint'.

  Returns:
    A string: path at which the variables were saved. If the saver is
      sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
      is the number of shards created.

  Raises:
    TypeError: If `sess` is not a `Session`.
    ValueError: If `latest_filename` contains path components.
  """
  if latest_filename is None:
    latest_filename = "checkpoint"
  if os.path.split(latest_filename)[0]:
    raise ValueError("'latest_filename' must not contain path components")
  # Validate `sess` before it is used. A non-integer `global_step` is resolved
  # by passing `sess` to training_util.global_step, which would otherwise fail
  # with an unrelated error instead of the documented TypeError. The 1-tuple
  # keeps the message formatting safe for any `sess` value.
  if not isinstance(sess, session.SessionInterface):
    raise TypeError("'sess' must be a Session; %s" % (sess,))

  if global_step is not None:
    if not isinstance(global_step, compat.integral_types):
      global_step = training_util.global_step(sess, global_step)
    checkpoint_file = "%s-%d" % (save_path, global_step)
  else:
    checkpoint_file = save_path
  # The checkpoint-state file lives next to the checkpoints themselves.
  save_dir = os.path.dirname(save_path)

  model_checkpoint_path = compat.as_str(sess.run(
      self._save_tensor_name, {self._filename_tensor_name: checkpoint_file}))
  self._MaybeDeleteOldCheckpoints(model_checkpoint_path)
  update_checkpoint_state(save_dir, model_checkpoint_path,
                          self.last_checkpoints, latest_filename)
  return model_checkpoint_path
开发者ID:hessenh,项目名称:Human-Activity-Recognition,代码行数:55,代码来源:saver.py
示例12: __init__
def __init__(self, name, graph, operations, inputs, outputs):
  """Initializes an eager defined function.

  Builds a TF_Function from `operations` of `graph` and reads its FunctionDef
  back to populate `definition`, `name` and `signature`. Every C API call
  reports through a status checked by
  `errors.raise_exception_on_not_ok_status`. When executing eagerly, the
  function is also passed to `_register`.

  Args:
    name: str, the name for the created function.
    graph: Graph, the graph containing the operations in the function
    operations: list of Operation; the subset of operations in the graph
      which will be in the function
    inputs: the tensors in the graph to be used as inputs to the function
    outputs: the tensors in the graph which will be outputs to the function
  """
  # pylint: disable=protected-access
  with errors.raise_exception_on_not_ok_status() as convert_status:
    fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
        graph._c_graph,
        compat.as_str(name),
        False,
        [operation._c_op for operation in operations],
        [tensor._as_tf_output() for tensor in inputs],
        [tensor._as_tf_output() for tensor in outputs],
        [],
        None,
        compat.as_str(""),
        convert_status)
  # pylint: enable=protected-access

  # TODO(apassos) avoid creating a FunctionDef (specially to grab the
  # signature, but also in general it's nice not to depend on it.
  with c_api_util.tf_buffer() as buffer_:
    with errors.raise_exception_on_not_ok_status() as export_status:
      pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, export_status)
    proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
  function_def = function_pb2.FunctionDef()
  function_def.ParseFromString(compat.as_bytes(proto_data))

  if context.executing_eagerly():
    _register(fn)

  signature = function_def.signature
  self.definition = function_def
  self.name = signature.name
  self.signature = signature
  self.grad_func_name = None
  self.python_grad_func = None
  # NOTE(review): the raw handle is stored unwrapped; confirm its lifetime
  # is managed elsewhere.
  self._c_func = fn
  self._grad_func = None
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:40,代码来源:function.py
示例13: request_stop
def request_stop(self, ex=None):
  """Request that the threads stop.

  After this is called, calls to `should_stop()` will return `True`.
  Only the first reported exception is recorded. Later calls, and calls made
  after the stop event is set, record nothing.

  Args:
    ex: Optional `Exception`, or Python `exc_info` tuple as returned by
      `sys.exc_info()`. If this is the first call to `request_stop()` the
      corresponding exception is recorded and re-raised from `join()`.
  """
  # The `unicode` builtin only exists on Python 2. On Python 3, referencing
  # it raised NameError as soon as an error was reported, so fall back to
  # `str`.
  try:
    text_type = unicode  # pylint: disable=undefined-variable
  except NameError:
    text_type = str

  with self._lock:
    if not self._stop_event.is_set():
      if ex and self._exc_info_to_raise is None:
        if isinstance(ex, tuple):
          logging.info("Error reported to Coordinator: %s",
                       compat.as_str(text_type(ex[1])))
          self._exc_info_to_raise = ex
        else:
          logging.info("Error reported to Coordinator: %s",
                       compat.as_str(text_type(ex)))
          # NOTE(review): this records the exception currently being handled,
          # not `ex` itself. It presumes the caller is inside the `except`
          # block that caught `ex` -- confirm.
          self._exc_info_to_raise = sys.exc_info()
      self._stop_event.set()
开发者ID:peace195,项目名称:tensorflow,代码行数:22,代码来源:coordinator.py
示例14: _set_c_attrs
def _set_c_attrs(self, attrs):
"""Sets `attrs` as attributes of self._c_func.
Requires that self._c_func is not None.
Args:
attrs: a dictionary from attribute name to attribute proto value
"""
for name, attr_value in attrs.items():
serialized = attr_value.SerializeToString()
# TODO(skyewm): this creates and deletes a new TF_Status for every attr.
# It might be worth creating a convenient way to re-use the same status.
c_api.TF_FunctionSetAttrValueProto(self._c_func.func, compat.as_str(name),
serialized)
开发者ID:didukhle,项目名称:tensorflow,代码行数:14,代码来源:function.py
示例15: _ReadAndCheckRowsUsingFeatures
def _ReadAndCheckRowsUsingFeatures(self, num_rows):
  """Reads `num_rows` rows via parse_example and checks their contents.

  For each row, the int64 column indexes `_ROWS`. The string column must be
  "s_<int64>" when that entry's second field is truthy, otherwise the feature
  default "s_default". The ids seen must be exactly 0..num_rows-1. One more
  read must then fail with the queue closed and empty.
  """
  self.server.handler.num_rows = num_rows

  with self.test_session() as sess:
    feature_configs = {
        "int64_col":
            parsing_ops.FixedLenFeature([1], dtype=dtypes.int64),
        "string_col":
            parsing_ops.FixedLenFeature(
                [1], dtype=dtypes.string, default_value="s_default"),
    }
    end_point = "%s:%s" % (self.server.httpd.server_address[0],
                           self.server.httpd.server_address[1])
    reader = cloud.BigQueryReader(
        project_id=_PROJECT,
        dataset_id=_DATASET,
        table_id=_TABLE,
        num_partitions=4,
        features=feature_configs,
        timestamp_millis=1,
        test_end_point=end_point)
    key, value = _SetUpQueue(reader)
    features = parsing_ops.parse_example(
        array_ops.reshape(value, [1]), feature_configs)

    seen_rows = []
    for _ in range(num_rows):
      int_value, str_value = sess.run(
          [features["int64_col"], features["string_col"]])
      # One example per fetch, each feature holding a single element.
      self.assertEqual(int_value.shape, (1, 1))
      self.assertEqual(str_value.shape, (1, 1))
      row_id = int_value[0][0]
      row_str = str_value[0][0]
      seen_rows.append(row_id)

      expected_row = _ROWS[row_id]
      self.assertEqual(row_id, expected_row[0])
      if expected_row[1]:
        expected_str = "s_%d" % row_id
      else:
        expected_str = "s_default"
      self.assertEqual(compat.as_str(row_str), expected_str)

    self.assertItemsEqual(seen_rows, range(num_rows))
    with self.assertRaisesOpError("is closed and has insufficient elements "
                                  "\\(requested 1, current size 0\\)"):
      sess.run([key, value])
开发者ID:brainwy12,项目名称:tensorflow,代码行数:50,代码来源:bigquery_reader_ops_test.py
示例16: _node_def
def _node_def(from_node_def, export_scope, unbound_inputs):
  """Create a `NodeDef` proto with export_scope stripped.

  Works on a deep copy of `from_node_def`:
  * Inputs outside `export_scope` get "$unbound_inputs_" inserted after any
    leading "^" and are appended to `unbound_inputs`. All other inputs have
    the scope stripped.
  * The node name has the scope stripped.
  * "_class" entries outside the scope are dropped; the rest are stripped.
  * Every other attr is copied unchanged.

  Args:
    from_node_def: A `node_def_pb2.NodeDef` protocol buffer.
    export_scope: A `string` representing the name scope to remove.
    unbound_inputs: An array of unbound input names if they exist. It is
      appended to in place.

  Returns:
    A `node_def_pb2.NodeDef` protocol buffer.
  """
  node_def = copy.deepcopy(from_node_def)

  for idx, input_name in enumerate(node_def.input):
    outside_scope = (export_scope and
                     not input_name.lstrip("^").startswith(export_scope))
    if outside_scope:
      # Prefix the unbound name so such inputs are easily identifiable.
      node_def.input[idx] = re.sub(r"([\^]|^)(.*)", r"\1$unbound_inputs_\2",
                                   compat.as_str(input_name))
      unbound_inputs.append(node_def.input[idx])
    else:
      node_def.input[idx] = ops.strip_name_scope(input_name, export_scope)

  node_def.name = compat.as_bytes(
      ops.strip_name_scope(from_node_def.name, export_scope))

  for attr_name, attr_value in six.iteritems(from_node_def.attr):
    if attr_name != "_class":
      node_def.attr[attr_name].CopyFrom(attr_value)
      continue
    # Colocation entries look like "<prefix>@<op name>". Keep only those whose
    # op name lies inside the export scope.
    scoped_entries = [
        compat.as_bytes(ops.strip_name_scope(entry, export_scope))
        for entry in attr_value.list.s
        if not export_scope or
        compat.as_str(entry).split("@")[1].startswith(export_scope)
    ]
    node_def.attr[attr_name].CopyFrom(attr_value_pb2.AttrValue(
        list=attr_value_pb2.AttrValue.ListValue(s=scoped_entries)))
  return node_def
开发者ID:821760408-sp,项目名称:tensorflow,代码行数:36,代码来源:meta_graph.py
示例17: _GetColocationNames
def _GetColocationNames(op):
"""Returns names of the ops that `op` should be colocated with."""
colocation_names = []
try:
class_values = op.get_attr('_class')
except ValueError:
# No _class attr
return
for val in class_values:
val = compat.as_str(val)
if val.startswith('loc:@'):
colocation_node_name = val[len('loc:@'):]
if colocation_node_name != op.name:
colocation_names.append(colocation_node_name)
return colocation_names
开发者ID:andrewharp,项目名称:tensorflow,代码行数:15,代码来源:importer.py
示例18: lookup
def lookup(self, name):
  """Looks up "name".

  The name is normalized with `compat.as_str` before the registry is
  consulted, so bytes and text keys are treated alike.

  Args:
    name: a string specifying the registry key for the candidate.

  Returns:
    Registered object if found (the `_TYPE_TAG` field of its entry).

  Raises:
    LookupError: if "name" has not been registered.
  """
  key = compat.as_str(name)
  if key not in self._registry:
    raise LookupError(
        "%s registry has no entry for: %s" % (self._name, key))
  return self._registry[key][_TYPE_TAG]
开发者ID:chengyang317,项目名称:information_pursuit,代码行数:15,代码来源:registry.py
示例19: canonicalize_signatures
def canonicalize_signatures(signatures):
  """Converts `signatures` into a dictionary of concrete functions.

  Args:
    signatures: None, a single function, or a mapping from signature key to
      function. A lone function is stored under
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.

  Returns:
    A dict mapping each signature key to a concrete function whose outputs
    are normalized by `_normalize_outputs`. The function takes one positional
    argument when the signature has exactly one input; otherwise it is
    keyword-only.

  Raises:
    ValueError: if `_get_signature` cannot produce a signature function for a
      value.
  """
  # `collections.Mapping` was removed in Python 3.10. The ABC lives in
  # `collections.abc` on Python 3; fall back to `collections` on Python 2.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    import collections as collections_abc  # pylint: disable=g-import-not-at-top

  if signatures is None:
    return {}
  if not isinstance(signatures, collections_abc.Mapping):
    signatures = {
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
  concrete_signatures = {}
  for signature_key, function in signatures.items():
    signature_function = _get_signature(function)
    if signature_function is None:
      raise ValueError(
          ("Expected a TensorFlow function to generate a signature for, but "
           "got {}. Only `tf.functions` with an input signature or "
           "concrete functions can be used as a signature.").format(function))

    # Re-wrap the function so that it returns a dictionary of Tensors. This
    # matches the format of 1.x-style signatures. The closure is traced
    # within this iteration, so capturing loop variables is safe.
    # pylint: disable=cell-var-from-loop
    @def_function.function
    def signature_wrapper(**kwargs):
      structured_outputs = signature_function(**kwargs)
      return _normalize_outputs(
          structured_outputs, signature_function.name, signature_key)

    # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names
    # always match keyword arguments.
    tensor_spec_signature = {}
    for keyword, tensor in zip(
        signature_function._arg_keywords,  # pylint: disable=protected-access
        signature_function.inputs):
      keyword = compat.as_str(keyword)
      tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
          tensor, name=keyword)
    final_concrete = signature_wrapper.get_concrete_function(
        **tensor_spec_signature)
    # pylint: disable=protected-access
    if len(final_concrete._arg_keywords) == 1:
      # If there is only one input to the signature, a very common case, then
      # ordering is unambiguous and we can let people pass a positional
      # argument. Since SignatureDefs are unordered (protobuf "map") multiple
      # arguments means we need to be keyword-only.
      final_concrete._num_positional_args = 1
    else:
      final_concrete._num_positional_args = 0
    # pylint: enable=protected-access
    concrete_signatures[signature_key] = final_concrete
    # pylint: enable=cell-var-from-loop
  return concrete_signatures
开发者ID:aritratony,项目名称:tensorflow,代码行数:48,代码来源:signature_serialization.py
示例20: initialize_tpu_system
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices in a separate session and graph.

  In eager mode, the initialization function is traced and run through
  `TPUPartitionedCall`. Otherwise it runs in a fresh graph and session
  targeting the resolver's master.

  Args:
    cluster_resolver: A tf.contrib.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster. Defaults to
      `TPUClusterResolver("")`.

  Returns:
    The tf.contrib.tpu.Topology object for the topology of the TPU cluster.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  master = cluster_resolver.master()

  logging.info("Initializing the TPU system.")

  if not context.executing_eagerly():
    soft_placement_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=soft_placement_config,
                               target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    # tpu.initialize_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. That pass adds the real ops that
    # initialize the TPU system. So it cannot simply run eagerly: it is
    # wrapped in a defun, and the rewrite is triggered by running that
    # function with TPUPartitionedCallOp.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    # _tpu_init_fn is not called directly (it contains just a dummy op).
    # Defining its concrete function adds it to the eager context and gives
    # us its assigned name.
    # pylint: disable=protected-access
    concrete_init = _tpu_init_fn._get_concrete_function_internal()
    init_fn_name = compat.as_str(concrete_init._inference_function.name)
    # pylint: enable=protected-access
    outputs = tpu_functional_ops.TPUPartitionedCall(
        args=[], device_ordinal=0, Tout=[dtypes.string], f=init_fn_name)
    serialized_topology = outputs[0].numpy()

  logging.info("Finished initializing TPU system.")
  return topology.Topology(serialized=serialized_topology)
开发者ID:jackd,项目名称:tensorflow,代码行数:46,代码来源:tpu_strategy.py
注:本文中的tensorflow.python.util.compat.as_str函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。 |
请发表评论