def stripped_op_list_for_graph(graph_def):
  """Collect the stripped OpDefs for ops used by a graph.

  This function computes the `stripped_op_list` field of `MetaGraphDef` and
  similar protos. The result can be communicated from the producer to the
  consumer, which can then use the C++ function
  `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

  Args:
    graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

  Returns:
    An `OpList` of ops used by the graph, sorted by op name. The whitelisted
    internal function ops ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    are accepted but omitted, because they have no registered `OpDef`.

  Raises:
    ValueError: If an unregistered op is used.
  """
  # This is the Python equivalent of StrippedOpListForGraph in C++.
  # Unfortunately, since the Python op registry can differ from that in C++, we
  # can't remove the duplication using swig (at least naively).
  # TODO(irving): Support taking graphs directly.
  used_ops = ops_used_by_graph_def(graph_def)

  # Verify that all used ops are registered.
  registered_ops = op_def_registry.get_registered_ops()
  # These internal ops used by functions are not registered, so we need to
  # whitelist them.  # TODO(irving): Do something better here.
  op_whitelist = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
  for op in used_ops:
    if op not in registered_ops and op not in op_whitelist:
      raise ValueError("Op %s is used by the graph, but is not registered" % op)

  # Build the stripped op list in sorted order, skipping the whitelisted
  # internal ops that have no registered OpDef.
  return op_def_pb2.OpList(
      op=[registered_ops[op] for op in sorted(used_ops) if op in registered_ops])
# Test: a SavedModel exported with strip_default_attrs=True must fail to load
# when the consumer's op registry has no defaults for the stripped "Complex"
# attrs, and must also fail when the registry's "Tout" default disagrees with
# the producer's (complex128 vs. the kernels registered for float32 inputs).
# NOTE(review): this copy of the test is garbled. Specifically:
#   - the `self._get_export_dir(` call is unclosed;
#   - the line `sess, ["foo"], strip_default_attrs=True)` has lost its call head;
#   - the "Save the SavedModel" comment has no following statement;
#   - the first `for attr_def in complex_op_def.attr:` loop has no body;
#   - the first assertRaisesRegexp call lacks its exception-class argument;
#   - `if == "Tout":` has lost its left operand;
#   - the second assertRaisesRegexp regex is unterminated.
# Restore from upstream rather than reconstructing by hand.
# NOTE(review): the test mutates the global "Complex" OpDef obtained from
# op_def_registry; `original_complex_op_def` is created but never used in
# the visible lines, so no restore is visible here. Confirm that upstream
# restores the OpDef (ideally in a `finally`) so other tests are unaffected.
def testStripDefaultAttrsInconsistentConsumerDefaults(self):
if ops._USE_C_API: return # TODO(skyewm): get this working
export_dir = self._get_export_dir(
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Add a graph with two float32 variables and a Complex Op composing them
# with strip_default_attrs enabled. This must remove the following
# defaults for the "Complex" Op:
# o "T" : float32. (input type)
# o "Tout" : complex64. (output type)
with session.Session(graph=ops.Graph()) as sess:
real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
sess, ["foo"], strip_default_attrs=True)
# Save the SavedModel to disk in text format.
# Update the Op registry to remove defaults for all attrs("T", "Tout") from
# the "Complex" OpDef.
complex_op_def = op_def_registry.get_registered_ops()["Complex"]
original_complex_op_def = op_def_pb2.OpDef()
for attr_def in complex_op_def.attr:
# Loading the SavedModel via the loader must fail because the SavedModel
# does not have any attr values for the "Complex" node and the current
# op registry does not have any default values for the "Complex" op.
sess = session.Session(graph=ops.Graph())
with self.assertRaisesRegexp(
"Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
loader.load(sess, ["foo"], export_dir)
# Update the Op registry to change the defaults for attr "Tout"
# (complex64 -> complex128).
for attr_def in complex_op_def.attr:
if == "Tout":
attr_def.default_value.type = types_pb2.DT_COMPLEX128
# Loading the SavedModel via the loader must set "Tout" attr_value for the
# "Complex" node according to the latest defaults (complex128). This is
# expected to fail the model import as there is no OpKernel registered to
# handle attrs "T" (float32) and "Tout" (complex128).
sess = session.Session(graph=ops.Graph())
with self.assertRaisesRegexp(
".*No OpKernel was registered to support Op \'Complex\' with these "
loader.load(sess, ["foo"], export_dir)
def _is_array_type_input(op, i):
  """Returns whether input `i` of registered op `op` has a `number_attr`.

  Args:
    op: Op type name, looked up in the global op registry.
    i: Index into the op's `input_arg` list.

  Returns:
    False if `op` is not registered. Otherwise, True iff the `i`-th input arg
    of the op's `OpDef` has a non-empty `number_attr`.

  Raises:
    TypeError: If `i` is outside `[0, len(op_def.input_arg))`.
  """
  registered_ops = op_def_registry.get_registered_ops()
  if op not in registered_ops:
    return False
  op_def = registered_ops[op]
  num_inputs = len(op_def.input_arg)
  # A chained comparison avoids depending on Python 2's `xrange` and does not
  # build a range object just to test membership.
  if not 0 <= i < num_inputs:
    raise TypeError("Expected arg index %s to be in [0, %d)" % (i, num_inputs))
  return bool(op_def.input_arg[i].number_attr)
def _register_function_ops(func_list):
  """Registers custom ops in the default graph.

  This is needed because our checkpoint is saved with ops that are not part of
  TensorFlow.

  Each function's `_definition.signature` (an `OpDef`) is stored, keyed by its
  op name, in the dict returned by `op_def_registry.get_registered_ops()`.
  NOTE(review): this only takes effect globally if that call returns the live
  registry dict rather than a copy; confirm.

  Args:
    func_list: Iterable of defined functions exposing `_definition.signature`.
  """
  op_dict = op_def_registry.get_registered_ops()
  for func in func_list:
    # pylint: disable=protected-access
    op_def = func._definition.signature
    # pylint: enable=protected-access
    # Key by op name, which is how ops are looked up (e.g. `registry[node.op]`).
    op_dict[op_def.name] = op_def
# Converts `op` into a `function_pb2.FunctionDef.Node` with `node.op` set to
# `op.type`. Its output and input arg names are derived from the op's OpDef:
# `op._sig` when present, otherwise the global registry entry for `op.type`.
# Output args with `number_attr` / `type_list_attr` are expanded through
# `_add_output_array` / `_add_output_list`. Array-typed inputs go through
# `_add_input_array`.
# NOTE(review): this copy is garbled. It is missing:
#   - the `else:` between the two `op_def = ...` lines, so the registry
#     lookup unconditionally overrides `op._sig`;
#   - the statement under `if not op_def.output_arg:`;
#   - the call heads that the trailing `_add_*(...))` arguments belonged to
#     (note the unbalanced closing parens);
#   - the plain-arg `else:` branches;
#   - the list-comprehension heads;
#   - the body of the final `for k, v in ...` loop;
#   - any statement attaching `node` to `func`.
# Restore from upstream rather than reconstructing by hand.
def _add_op_node(op, func):
"""Converts an op to a function def node and add it to `func`."""
node = function_pb2.FunctionDef.Node()
node.op = op.type
# pylint: disable=protected-access
if hasattr(op, "_sig"):
op_def = getattr(op, "_sig")
op_def = op_def_registry.get_registered_ops()[op.type]
# pylint: enable=protected-access
attrs = _get_node_def_attr(op)
if not op_def.output_arg:
out_index = 0
for arg_def in op_def.output_arg:
if arg_def.number_attr:
dtype = arg_def.type or attrs[arg_def.type_attr].type
num = attrs[arg_def.number_attr].i
_add_output_array(op, out_index, out_index + num, dtype, func))
out_index += num
elif arg_def.type_list_attr:
dtype_lst = attrs[arg_def.type_list_attr].list.type
num = len(dtype_lst)
_add_output_list(op, out_index, out_index + num, dtype_lst, func))
out_index += num
out_index += 1
inp_index = 0
for arg_def in op_def.input_arg:
if arg_def.number_attr:
dtype = arg_def.type or attrs[arg_def.type_attr].type
num = attrs[arg_def.number_attr].i
_add_input_array(op, inp_index, inp_index + num, dtype, func))
inp_index += num
elif arg_def.type_list_attr:
num = len(attrs[arg_def.type_list_attr].list.type)
for i in range(inp_index, inp_index + num)
inp_index += num
inp_index += 1
[_make_argname_from_tensor_name( for x in op.control_inputs])
for k, v in _get_node_def_attr(op).items():
def _stripped_op_list_for_graph(graph_def):
  """Returns OpDefs of ops used in graph_def.

  Both the top-level nodes of `graph_def` and the nodes of every function in
  `graph_def.library` are scanned. Op types that are not in the Python op
  registry are skipped silently.

  Args:
    graph_def: A `GraphDef` proto.

  Returns:
    An `OpList` of the registered `OpDef`s used, sorted by op name.
  """
  op_set = set()
  registered_ops = op_def_registry.get_registered_ops()
  for n in graph_def.node:
    if n.op in registered_ops:
      op_set.add(n.op)
  for func in graph_def.library.function:
    for n in func.node:
      if n.op in registered_ops:
        op_set.add(n.op)
  return op_def_pb2.OpList(op=[registered_ops[x] for x in sorted(op_set)])
def list_registered_stateful_ops_without_inputs():
  """Returns set of registered stateful ops that do not expect inputs.

  This list is used to identify the ops to be included in the state-graph and
  that are subsequently fed into the apply-graphs.

  Returns:
    A set of strings: the names of registered ops whose `OpDef` has
    `is_stateful` set and an empty `input_arg` list.
  """
  return {
      name
      for name, op in op_def_registry.get_registered_ops().items()
      if op.is_stateful and not op.input_arg
  }
def _get_ref_args(self, node):
  """Determine which outputs of a node's op are ref-type.

  Args:
    node: A `NodeDef`.

  Returns:
    A list of the arg names (as strs) that are ref-type. Output 0 is named
    `node.name`; output `i > 0` is named `"<node.name>:<i>"`. An empty list is
    returned if `node.op` is not in the op registry.
  """
  op_def = op_def_registry.get_registered_ops().get(node.op)
  ref_args = []
  if op_def:
    for i, output_arg in enumerate(op_def.output_arg):
      if output_arg.is_ref:
        # NOTE(review): `i` indexes OpDef output args, not flattened output
        # tensors; for ops with list-typed outputs the ":i" suffix may not
        # match the tensor index. Confirm this is acceptable for callers.
        arg_name = node.name if i == 0 else ("%s:%d" % (node.name, i))
        ref_args.append(arg_name)
  return ref_args
def _strip_graph_default_valued_attrs(meta_graph_def):
  """Strips default valued attributes for node defs in given MetaGraphDef.

  This method also sets `meta_info_def.stripped_default_attrs` in the given
  `MetaGraphDef` proto to True.

  Nodes are skipped, and left untouched, if their op is a function defined in
  the graph's library or is not in the Python op registry. Both top-level
  nodes and the nodes inside library functions are processed.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer, modified in place.

  Returns:
    None.
  """
  # Names of functions defined in the library; nodes calling them use the
  # function name as their op.
  function_names = {
      function_def.signature.name
      for function_def in meta_graph_def.graph_def.library.function
  }

  # Get all registered ops.
  registered_ops = op_def_registry.get_registered_ops()

  def _strip_node_default_valued_attrs(node_def):
    """Removes default valued attributes from a single node def."""
    # Function calls and unregistered ops have no OpDef to compare against.
    if node_def.op in function_names or node_def.op not in registered_ops:
      return
    op_def = registered_ops[node_def.op]

    # Collect first, then delete, so the attr map is not mutated while it is
    # being iterated.
    attrs_to_strip = set()
    for attr_name, attr_value in node_def.attr.items():
      if _is_default_attr_value(op_def, attr_name, attr_value):
        attrs_to_strip.add(attr_name)

    for attr in attrs_to_strip:
      del node_def.attr[attr]

  # Process all NodeDef instances in graph_def.
  for node_def in meta_graph_def.graph_def.node:
    _strip_node_default_valued_attrs(node_def)

  # Process all NodeDef instances in graph_def.library.function.
  for function_def in meta_graph_def.graph_def.library.function:
    for function_node_def in function_def.node_def:
      _strip_node_default_valued_attrs(function_node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
# Ensures every op name in `graph_ops` is in the Python op registry, copying
# OpDefs from the C++ registry (via c_api.TF_GetAllOpList) when available.
# Raises RuntimeError for ops absent from both registries; this is raised only
# after attempting the load.
# NOTE(review): this copy is garbled. It is missing:
#   - the docstring's Args:/Raises: headers and closing triple quote;
#   - the `return` under `if not missing_ops:`;
#   - any statement that parses `p_buffer` into `cpp_op_list` (as shown,
#     cpp_op_list stays empty);
#   - the key of the `{ op for op in cpp_op_list.op}` comprehension, which as
#     written builds a set, yet `.keys()` is called on it below;
#   - the logging calls, the `else:` branch and the registration of
#     `missing_op_list`, leaving only orphaned message strings.
# Restore from upstream rather than reconstructing by hand.
def register_ops_if_needed(graph_ops):
"""Register graph ops absent in op_def_registry, if present in c++ registry.
graph_ops: set with graph op names to register.
RuntimeError: if `graph_ops` contains ops that are not in either python or
c++ registry.
missing_ops = graph_ops - set(op_def_registry.get_registered_ops().keys())
if not missing_ops:
p_buffer = c_api.TF_GetAllOpList()
cpp_op_list = op_def_pb2.OpList()
cpp_registry_ops = { op for op in cpp_op_list.op}
missing_op_list = op_def_pb2.OpList()
for missing_op in missing_ops:
if missing_op not in cpp_registry_ops:
"Op %s is missing from both the python and C++ registry.",
"Adding op %s from c++ registry to python registry.",
# Note: Only raise missing op ValueError after trying to load ops.
# This allows the test to exercise all the calls into TensorFlow
# without having to write a C + python test.
if not missing_ops <= set(cpp_registry_ops.keys()):
raise RuntimeError(
"Graph ops missing from the python registry (%s) are also absent from "
"the c++ registry."
% missing_ops.difference(set(cpp_registry_ops.keys())))
def _stripped_op_list_for_graph(graph_def):
  """Returns the registered OpDefs for every op type used by `graph_def.node`.

  Only top-level nodes are inspected; function library bodies are not. An op
  type missing from the Python op registry raises `KeyError`.

  Args:
    graph_def: A `GraphDef` proto.

  Returns:
    An `OpList` whose ops are sorted by op name, one entry per distinct op.
  """
  all_ops = op_def_registry.get_registered_ops()
  op_names = sorted(set(node.op for node in graph_def.node))
  return op_def_pb2.OpList(op=[all_ops[op_name] for op_name in op_names])
# Pure-Python importer of `graph_def` into the current default graph. The
# visible part:
#   - validates argument types;
#   - canonicalizes `input_map` keys;
#   - falls back to the global registry when `op_dict` is None;
#   - indexes `producer_op_list`;
#   - adds library function signatures to a copy of `op_dict`.
# NOTE(review): this copy is garbled and truncated:
#   - The docstring's Args:/Returns:/Raises: headers and its closing triple
#     quote are missing.
#   - The duck-typing fallback after `graph_def = graph_pb2.GraphDef()` lost
#     its `try:` line and the statement that (presumably) merges
#     `old_graph_def` in, so `except TypeError:` is orphaned.
#   - `producer_op_dict = { op for op in ...}` lost its dict key and the
#     `else:` before it.
#   - `op_dict[] = ...` lost its subscript.
#   - The body stops after the function loop; no import or return is visible.
# Restore from upstream rather than reconstructing by hand.
def import_graph_def(graph_def, input_map=None, return_elements=None,
name=None, op_dict=None, producer_op_list=None):
"""Imports the graph from `graph_def` into the current default `Graph`.
This function provides a way to import a serialized TensorFlow
protocol buffer, and extract individual objects in the `GraphDef` as
@{tf.Tensor} and @{tf.Operation} objects. Once extracted,
these objects are placed into the current default `Graph`. See
@{tf.Graph.as_graph_def} for a way to create a `GraphDef`
graph_def: A `GraphDef` proto containing operations to be imported into
the default graph.
input_map: A dictionary mapping input names (as strings) in `graph_def`
to `Tensor` objects. The values of the named input tensors in the
imported graph will be re-mapped to the respective `Tensor` values.
return_elements: A list of strings containing operation names in
`graph_def` that will be returned as `Operation` objects; and/or
tensor names in `graph_def` that will be returned as `Tensor` objects.
name: (Optional.) A prefix that will be prepended to the names in
`graph_def`. Note that this does not apply to imported function names.
Defaults to `"import"`.
op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
Must contain an `OpDef` proto for each op type named in `graph_def`.
If omitted, uses the `OpDef` protos registered in the global registry.
producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
list of `OpDef`s used by the producer of the graph. If provided, attrs
for ops in `graph_def` that are not in `op_dict` that have their default
value according to `producer_op_list` will be removed. This will allow
some more `GraphDef`s produced by later binaries to be accepted by
earlier binaries.
A list of `Operation` and/or `Tensor` objects from the imported graph,
corresponding to the names in `return_elements`.
TypeError: If `graph_def` is not a `GraphDef` proto,
`input_map` is not a dictionary mapping strings to `Tensor` objects,
or `return_elements` is not a list of strings.
ValueError: If `input_map`, or `return_elements` contains names that
do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
it refers to an unknown tensor).
# Type checks for inputs.
if not isinstance(graph_def, graph_pb2.GraphDef):
# `graph_def` could be a dynamically-created message, so try a duck-typed
# approach
old_graph_def = graph_def
graph_def = graph_pb2.GraphDef()
except TypeError:
raise TypeError('graph_def must be a GraphDef proto.')
if input_map is None:
input_map = {}
if not (isinstance(input_map, dict)
and all(isinstance(k, compat.bytes_or_text_types)
for k in input_map.keys())):
raise TypeError('input_map must be a dictionary mapping strings to '
'Tensor objects.')
if return_elements is not None:
return_elements = tuple(return_elements)
if not all(isinstance(x, compat.bytes_or_text_types)
for x in return_elements):
raise TypeError('return_elements must be a list of strings.')
# Use a canonical representation for all tensor names.
input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
used_input_keys = set()
name_to_op = {}
if op_dict is None:
op_dict = op_def_registry.get_registered_ops()
if producer_op_list is None:
producer_op_dict = None
producer_op_dict = { op for op in producer_op_list.op}
g = ops.get_default_graph()
# Add any functions defined in `graph_def` to `g`
if graph_def.library and graph_def.library.function:
# Copy op_dict so we don't clobber the original
op_dict = copy.copy(op_dict)
# pylint: disable=protected-access
# Note that we do not prepend `name` to the function name. The reasoning is
# that function names are similar to op definition names, which currently do
# not have a scoped name or namespace scheme.
functions = function._from_library(graph_def.library)
for f in functions:
op_dict[] = f.definition.signature
# pylint: enable=protected-access
# C-API importer of `graph_def` into the current default graph. When the
# default graph has a `_c_graph`, it serializes `graph_def` and calls
# `c_api.TF_GraphImportGraphDefWithResults`, converting InvalidArgumentError
# to ValueError. The `op_dict` argument is ignored: it is overwritten from the
# global registry. If `producer_op_list` is given, `_RemoveDefaultAttrs` is
# applied to `graph_def` in place (see the TODO below).
# NOTE(review): this copy is garbled and truncated:
#   - The docstring lacks its Args:/Returns:/Raises: headers and closing quote.
#   - The `else:` before `prefix = ''` is missing, so `prefix` is
#     unconditionally reset as written.
#   - The `_PopulateTFImportGraphDefOptions(` call is unclosed.
#   - The `try:` paired with `except errors.InvalidArgumentError` is missing.
#   - `for f in functions:` has no body.
#   - The function ends mid-comment.
# Restore from upstream rather than reconstructing by hand.
def import_graph_def(graph_def, input_map=None, return_elements=None,
name=None, op_dict=None, producer_op_list=None):
"""Imports the graph from `graph_def` into the current default `Graph`.
This function provides a way to import a serialized TensorFlow
protocol buffer, and extract individual objects in the `GraphDef` as
@{tf.Tensor} and @{tf.Operation} objects. Once extracted,
these objects are placed into the current default `Graph`. See
@{tf.Graph.as_graph_def} for a way to create a `GraphDef`
graph_def: A `GraphDef` proto containing operations to be imported into
the default graph.
input_map: A dictionary mapping input names (as strings) in `graph_def`
to `Tensor` objects. The values of the named input tensors in the
imported graph will be re-mapped to the respective `Tensor` values.
return_elements: A list of strings containing operation names in
`graph_def` that will be returned as `Operation` objects; and/or
tensor names in `graph_def` that will be returned as `Tensor` objects.
name: (Optional.) A prefix that will be prepended to the names in
`graph_def`. Note that this does not apply to imported function names.
Defaults to `"import"`.
op_dict: (Optional.) Deprecated, do not use.
producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
list of `OpDef`s used by the producer of the graph. If provided,
unrecognized attrs for ops in `graph_def` that have their default value
according to `producer_op_list` will be removed. This will allow some more
`GraphDef`s produced by later binaries to be accepted by earlier binaries.
A list of `Operation` and/or `Tensor` objects from the imported graph,
corresponding to the names in `return_elements`.
TypeError: If `graph_def` is not a `GraphDef` proto,
`input_map` is not a dictionary mapping strings to `Tensor` objects,
or `return_elements` is not a list of strings.
ValueError: If `input_map`, or `return_elements` contains names that
do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
it refers to an unknown tensor).
graph_def = _ProcessGraphDefParam(graph_def)
input_map = _ProcessInputMapParam(input_map)
return_elements = _ProcessReturnElementsParam(return_elements)
op_dict = op_def_registry.get_registered_ops()
if producer_op_list is not None:
# TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
_RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
graph = ops.get_default_graph()
if graph._c_graph: # pylint: disable=protected-access
with ops.name_scope(name, 'import', input_map.values()) as scope:
# Save unique prefix generated by name_scope
if scope:
assert scope.endswith('/')
prefix = scope[:-1]
prefix = ''
# Generate any input map tensors inside name scope
input_map = _ConvertInputMapValues(name, input_map)
scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
options = scoped_options.options
_PopulateTFImportGraphDefOptions(options, prefix, input_map,
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
with errors.raise_exception_on_not_ok_status() as status:
results = c_api.TF_GraphImportGraphDefWithResults(
graph._c_graph, serialized, options, status) # pylint: disable=protected-access
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
# Create _DefinedFunctions for any imported functions.
# We do this by creating _DefinedFunctions directly from `graph_def`, and
# adding them to `graph`. Adding an existing function to a TF_Graph is a
# no-op, so this only has the effect of updating the Python state (usually
# _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
# TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
# TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
if graph_def.library and graph_def.library.function:
# pylint: disable=protected-access
functions = function._from_library(graph_def.library)
for f in functions:
# pylint: enable=protected-access
# Treat input mappings that don't appear in the graph as an error, because
# Older pure-Python importer. After validating its arguments it works inside
# `ops.op_scope`:
#   - converts `input_map` values to tensors under '_inputs';
#   - first pass: creates every node's op without inputs, filling in attrs
#     that have a default value but are absent;
#   - second pass: adds inputs. Two passes are used because `graph_def` may
#     contain cycles.
# NOTE(review): this copy is garbled and truncated:
#   - The docstring lacks its Args:/Returns:/Raises: headers and closing quote.
#   - The duck-typing `try:` and merge lines are missing before
#     `except TypeError:`.
#   - `key =` lost its right-hand side and `name_to_op[] =` lost its subscript.
#   - The default-filling `if value is None or ...:` has no body.
#   - `g.create_op(` has a stray `,,` and is unclosed.
#   - The second pass has no body.
# Restore from upstream rather than reconstructing by hand.
def import_graph_def(graph_def, input_map=None, return_elements=None,
name=None, op_dict=None):
"""Imports the TensorFlow graph in `graph_def` into the Python `Graph`.
This function provides a way to import a serialized TensorFlow
protocol buffer, and extract individual objects in the `GraphDef` as
[`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
[`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
`GraphDef` proto.
graph_def: A `GraphDef` proto containing operations to be imported into
the default graph.
input_map: A dictionary mapping input names (as strings) in `graph_def`
to `Tensor` objects. The values of the named input tensors in the
imported graph will be re-mapped to the respective `Tensor` values.
return_elements: A list of strings containing operation names in
`graph_def` that will be returned as `Operation` objects; and/or
tensor names in `graph_def` that will be returned as `Tensor` objects.
name: (Optional.) A prefix that will be prepended to the names in
`graph_def`. Defaults to `"import"`.
op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
Must contain an `OpDef` proto for each op type named in `graph_def`.
If omitted, uses the `OpDef` protos registered in the global registry.
A list of `Operation` and/or `Tensor` objects from the imported graph,
corresponding to the names in `return_elements`.
TypeError: If `graph_def` is not a `GraphDef` proto,
`input_map` is not a dictionary mapping strings to `Tensor` objects,
or `return_elements` is not a list of strings.
ValueError: If `input_map`, or `return_elements` contains names that
do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
it refers to an unknown tensor).
# Type checks for inputs.
if not isinstance(graph_def, graph_pb2.GraphDef):
# `graph_def` could be a dynamically-created message, so try a duck-typed
# approach
old_graph_def = graph_def
graph_def = graph_pb2.GraphDef()
except TypeError:
raise TypeError('graph_def must be a GraphDef proto.')
if input_map is None:
input_map = {}
if not (isinstance(input_map, dict)
and all(isinstance(k, compat.bytes_or_text_types)
for k in input_map.keys())):
raise TypeError('input_map must be a dictionary mapping strings to '
'Tensor objects.')
if return_elements is not None:
return_elements = tuple(return_elements)
if not all(isinstance(x, compat.bytes_or_text_types)
for x in return_elements):
raise TypeError('return_elements must be a list of strings.')
# Use a canonical representation for all tensor names.
input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
used_input_keys = set()
name_to_op = {}
if op_dict is None:
op_dict = op_def_registry.get_registered_ops()
with ops.op_scope(input_map.values(), name, 'import'):
g = ops.get_default_graph()
with ops.name_scope('_inputs'):
input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}
# NOTE(mrry): We do this in two passes, because there may be a cycle in
# `graph_def`.
# 1. Add operations without their inputs.
for node in graph_def.node:
# Set any default attr values that aren't present.
op_def = op_dict[node.op]
for attr_def in op_def.attr:
key =
if attr_def.HasField('default_value'):
value = node.attr[key]
if value is None or value.WhichOneof('value') is None:
output_types = _OutputTypes(node, op_dict)
name_to_op[] = g.create_op(
node.op, [], output_types,, attrs=node.attr,
compute_shapes=False, compute_device=False,
# 2. Add inputs to the operations.
for node in graph_def.node:
# C-API importer (variant without a `_c_graph` check). The import and the
# subsequent `_ProcessNewOps` run under `graph._mutation_lock()`, so that no
# other call can run between TF_Operation creation in
# TF_GraphImportGraphDefWithResults and their mutation in `_ProcessNewOps`.
# The `op_dict` argument is deprecated and is overwritten from the global
# registry.
# NOTE(review): this copy is garbled and truncated:
#   - The signature stops after `graph_def,`.
#   - The docstring lacks its Args:/Returns:/Raises: headers and closing quote.
#   - The `else:` before `prefix = ''` is missing.
#   - `_PopulateTFImportGraphDefOptions(` is unclosed.
#   - The `try:` paired with `except errors.InvalidArgumentError` is missing.
#   - The function ends at an `if` with no body.
# Restore from upstream rather than reconstructing by hand.
def import_graph_def(graph_def,
"""Imports the graph from `graph_def` into the current default `Graph`.
This function provides a way to import a serialized TensorFlow
protocol buffer, and extract individual objects in the `GraphDef` as
@{tf.Tensor} and @{tf.Operation} objects. Once extracted,
these objects are placed into the current default `Graph`. See
@{tf.Graph.as_graph_def} for a way to create a `GraphDef`
graph_def: A `GraphDef` proto containing operations to be imported into
the default graph.
input_map: A dictionary mapping input names (as strings) in `graph_def`
to `Tensor` objects. The values of the named input tensors in the
imported graph will be re-mapped to the respective `Tensor` values.
return_elements: A list of strings containing operation names in
`graph_def` that will be returned as `Operation` objects; and/or
tensor names in `graph_def` that will be returned as `Tensor` objects.
name: (Optional.) A prefix that will be prepended to the names in
`graph_def`. Note that this does not apply to imported function names.
Defaults to `"import"`.
op_dict: (Optional.) Deprecated, do not use.
producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
list of `OpDef`s used by the producer of the graph. If provided,
unrecognized attrs for ops in `graph_def` that have their default value
according to `producer_op_list` will be removed. This will allow some more
`GraphDef`s produced by later binaries to be accepted by earlier binaries.
A list of `Operation` and/or `Tensor` objects from the imported graph,
corresponding to the names in `return_elements`.
TypeError: If `graph_def` is not a `GraphDef` proto,
`input_map` is not a dictionary mapping strings to `Tensor` objects,
or `return_elements` is not a list of strings.
ValueError: If `input_map`, or `return_elements` contains names that
do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
it refers to an unknown tensor).
op_dict = op_def_registry.get_registered_ops()
graph_def = _ProcessGraphDefParam(graph_def, op_dict)
input_map = _ProcessInputMapParam(input_map)
return_elements = _ProcessReturnElementsParam(return_elements)
if producer_op_list is not None:
# TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
_RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)
graph = ops.get_default_graph()
with ops.name_scope(name, 'import', input_map.values()) as scope:
# Save unique prefix generated by name_scope
if scope:
assert scope.endswith('/')
prefix = scope[:-1]
prefix = ''
# Generate any input map tensors inside name scope
input_map = _ConvertInputMapValues(name, input_map)
scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
options = scoped_options.options
_PopulateTFImportGraphDefOptions(options, prefix, input_map,
# _ProcessNewOps mutates the new operations. _mutation_lock ensures a
# call cannot occur between creating the TF_Operations in the
# TF_GraphImportGraphDefWithResults call and mutating the them in
# _ProcessNewOps.
with graph._mutation_lock(): # pylint: disable=protected-access
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
results = c_api.TF_GraphImportGraphDefWithResults(
graph._c_graph, serialized, options) # pylint: disable=protected-access
results = c_api_util.ScopedTFImportGraphDefResults(results)
except errors.InvalidArgumentError as e:
# Convert to ValueError for backwards compatibility.
raise ValueError(str(e))
# Create _DefinedFunctions for any imported functions.
# We do this by creating _DefinedFunctions directly from `graph_def`, and
# adding them to `graph`. Adding an existing function to a TF_Graph is a
# no-op, so this only has the effect of updating the Python state (usually
# _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
# TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
# TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
# TODO(b/74620627): move this after _ProcessNewOps outside the lock once
# _USE_C_SHAPES is removed.
if graph_def.library and graph_def.library.function:
# Converts a FunctionDef into a GraphDef in three steps:
#   1. signature input args become "Placeholder" nodes;
#   2. body node_defs are copied;
#   3. node inputs are rewritten from FunctionDef naming
#      ("node:output_arg:index") to GraphDef naming ("node:flat_index"), using
#      each op's registered OpDef.
# Raises NotImplementedError for ops that are not registered and for ops with
# func / list(func) attrs.
# NOTE(review): this copy is garbled:
#   - The docstring lacks its Args:/Returns:/Raises: headers and closing quote.
#   - `node_def = graph_def.node.add() =` merges two statements.
#   - The `if input_shapes and input_shapes[i] is not None:` branch has no
#     body, so input shapes are never applied as written.
#   - Steps 2 and 3 have comments but no statements.
#   - Several `format(` calls and mapping assignments have lost their
#     arguments (`"{}:0".format(`, `"{}:{}:{}".format(,, i)`, etc.).
# Restore from upstream rather than reconstructing by hand.
# NOTE(review): the final renaming loop looks up every node input in
# nested_to_flat_tensor_name. An input not present in the map (e.g. a control
# input, if any use a "^name" form) would raise KeyError. Confirm how control
# inputs are handled.
def function_def_to_graph_def(fdef, input_shapes=None):
"""Convert a FunctionDef to a GraphDef.
1. Creates placeholder nodes corresponding to inputs in
2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
3. Renames inputs of all nodes to use the convention of GraphDef instead of
FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
in FunctionDefs is different from GraphDefs.
fdef: FunctionDef.
input_shapes: Optional. A list of TensorShape objects of the shapes of
function inputs. If specified, its length must match length of
`fdef.signature.input_arg`. If a shape is None, the corresponding input
placeholder will have unknown shape.
A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
from nested tensor names (in FunctionDef) to flattened names (in GraphDef).
ValueError: If the length of input_shapes does not match the number of
input_args or if the FunctionDef is invalid.
graph_def = graph_pb2.GraphDef()
if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
raise ValueError("Length of input_shapes must match the number of " +
"input_args. len(input_shapes): {} len(input_arg): {}".
format(len(input_shapes), len(fdef.signature.input_arg)))
# 1. Create placeholders for input nodes.
for i, arg_def in enumerate(fdef.signature.input_arg):
node_def = graph_def.node.add() =
node_def.op = "Placeholder"
node_def.attr["dtype"].type = arg_def.type
if input_shapes and input_shapes[i] is not None:
# 2. Copy all body NodeDefs to the GraphDef.
# 3. Perform the renaming.
# Build the tensor name mapping then flatten the tensor names.
# See comment on `FunctionDef.node_def` on how the tensor naming in
# FunctionDefs is different from GraphDefs.
nested_to_flat_tensor_name = {}
for arg_def in fdef.signature.input_arg:
nested_to_flat_tensor_name[] = "{}:0".format(
for node_def in fdef.node_def:
op_def = op_def_registry.get_registered_ops().get(node_def.op)
if not op_def:
# TODO(b/80470245): Support functions which refer other functions.
raise NotImplementedError(
"No op registered for {},".format(node_def.op) +
" it may be a function. function_def_to_graph_def " +
"currently does not support converting functions with " +
"references to other graph functions.")
for attr in op_def.attr:
if attr.type in ("func", "list(func)"):
# TODO(b/80470245): Support functions which refer other functions.
raise NotImplementedError("Unsupported attr {} ".format( +
" with type {}".format(attr.type) +
" in op {}. ".format( +
"function_def_to_graph_def currently does " +
"not support converting functions with " +
"references to other graph functions.")
# Iterate over output_args in op_def to build the map.
# Index of the output tensor in the flattened list of *all* output
# tensors of the op.
flattened_index = 0
for arg_def in op_def.output_arg:
num_args = _get_num_args(arg_def, node_def)
for i in range(num_args):
# Map tensor names from "node_name:output_arg_name:index" to
# "node_name:flattened_index".
nested_name = "{}:{}:{}".format(,, i)
flat_name = "{}:{}".format(, flattened_index)
nested_to_flat_tensor_name[nested_name] = flat_name
flattened_index += 1
# Update inputs of all nodes in graph.
for node_def in graph_def.node:
for i in range(len(node_def.input)):
node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]
return graph_def, nested_to_flat_tensor_name
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top
# Op registry dict, fetched once at import time and consulted by
# enable_jit_nonstateful below.
_REGISTERED_OPS = op_def_registry.get_registered_ops()


def enable_jit_nonstateful(node_def):
  """Returns True iff the registered op of `node_def` is not stateful.

  NOTE(review): presumably used as the predicate for
  `jit.experimental_jit_scope`, so only stateless ops are JIT-compiled; confirm
  against callers.

  Args:
    node_def: A `NodeDef` whose `op` is looked up in `_REGISTERED_OPS`.

  Returns:
    `not is_stateful` of the op's registered `OpDef`.

  Raises:
    ValueError: If `node_def.op` is not a registered op.
  """
  # Keep the try body to the lookup only, so a KeyError raised elsewhere is
  # never misreported as an unregistered op.
  try:
    op_def = _REGISTERED_OPS[node_def.op]
  except KeyError:
    raise ValueError("Unregistered op being created: %s" % node_def)
  return not op_def.is_stateful
class JITTest(test.TestCase):
def compute(self, use_jit, compute_fn):
with self.test_session(graph=ops.Graph()) as sess:
with jit.experimental_jit_scope(use_jit):