• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python op_def_registry.get_registered_ops函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.framework.op_def_registry.get_registered_ops函数的典型用法代码示例。如果您正苦于以下问题:Python get_registered_ops函数的具体用法?Python get_registered_ops怎么用?Python get_registered_ops使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了get_registered_ops函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: stripped_op_list_for_graph

def stripped_op_list_for_graph(graph_def):
    """Build the `OpList` of registered OpDefs for the ops a graph uses.

    The result is the value stored in the `stripped_op_list` field of
    `MetaGraphDef` and similar protos.  A producer can hand it to a consumer,
    which can then call the C++ function `RemoveNewDefaultAttrsFromGraphDef`
    to improve forwards compatibility.

    Args:
      graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

    Returns:
      An `OpList` holding the registered OpDef of each used op, ordered by op
      name.  Whitelisted internal ops are not included.

    Raises:
      ValueError: If a used op is neither registered nor whitelisted.
    """
    # Python counterpart of the C++ StrippedOpListForGraph.  The Python op
    # registry can differ from the C++ one, so the duplication cannot simply
    # be removed with swig.
    # TODO(irving): Support taking graphs directly.
    graph_op_names = ops_used_by_graph_def(graph_def)
    registry = op_def_registry.get_registered_ops()

    # Internal ops used by functions are never registered, so they are
    # tolerated explicitly.  TODO(irving): Do something better here.
    tolerated_unregistered = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    for op_name in graph_op_names:
        if op_name in registry or op_name in tolerated_unregistered:
            continue
        raise ValueError(
            "Op %s is used by the graph, but is not registered" % op_name)

    # Emit OpDefs in sorted name order; tolerated ops have no OpDef to emit.
    stripped = []
    for op_name in sorted(graph_op_names):
        if op_name in registry:
            stripped.append(registry[op_name])
    return op_def_pb2.OpList(op=stripped)
开发者ID:rhuangq,项目名称:tensorflow,代码行数:35,代码来源:meta_graph.py


示例2: testStripDefaultAttrsInconsistentConsumerDefaults

  def testStripDefaultAttrsInconsistentConsumerDefaults(self):
    """Checks loading a default-stripped SavedModel under changed op defaults.

    A graph using the "Complex" op is exported with `strip_default_attrs=True`
    and then loaded twice after editing the registered "Complex" OpDef: once
    with every attr default cleared (expects `ValueError`) and once with the
    "Tout" default switched to complex128 (expects
    `errors.InvalidArgumentError`).  The registered OpDef is restored before
    the test returns, whether or not the assertions pass.
    """
    if ops._USE_C_API: return  # TODO(skyewm): get this working

    export_dir = self._get_export_dir(
        "test_strip_default_attrs_no_consumer_defaults")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Add a graph with two float32 variables and a Complex Op composing them
    # with strip_default_attrs enabled. This must remove the following
    # defaults for the "Complex" Op:
    #   o "T"    : float32.   (input type)
    #   o "Tout" : complex64. (output type)
    with session.Session(graph=ops.Graph()) as sess:
      real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
      imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
      math_ops.complex(real_num, imag_num, name="complex")
      sess.run(variables.global_variables_initializer())
      builder.add_meta_graph_and_variables(
          sess, ["foo"], strip_default_attrs=True)

    # Save the SavedModel to disk in text format.
    builder.save(as_text=True)

    # Keep a pristine copy of the "Complex" OpDef before editing the registry
    # entry in place.
    complex_op_def = op_def_registry.get_registered_ops()["Complex"]
    original_complex_op_def = op_def_pb2.OpDef()
    original_complex_op_def.CopyFrom(complex_op_def)
    try:
      # Update the Op registry to remove defaults for all attrs("T", "Tout")
      # from the "Complex" OpDef.
      for attr_def in complex_op_def.attr:
        attr_def.ClearField("default_value")

      # Loading the SavedModel via the loader must fail because the SavedModel
      # does not have any attr values for the "Complex" node and the current
      # op registry does not have any default values for the "Complex" op.
      sess = session.Session(graph=ops.Graph())
      with self.assertRaisesRegexp(
          ValueError,
          "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
        loader.load(sess, ["foo"], export_dir)

      # Update the Op registry to change the defaults for attr "Tout"
      # (complex64 -> complex128).
      complex_op_def.CopyFrom(original_complex_op_def)
      for attr_def in complex_op_def.attr:
        if attr_def.name == "Tout":
          attr_def.default_value.type = types_pb2.DT_COMPLEX128

      # Loading the SavedModel via the loader must set "Tout" attr_value for
      # the "Complex" node according to the latest defaults (complex128). This
      # is expected to fail the model import as there is no OpKernel registered
      # to handle attrs "T" (float32) and "Tout" (complex128).
      sess = session.Session(graph=ops.Graph())
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          ".*No OpKernel was registered to support Op \'Complex\' with these "
          "attrs..*"):
        loader.load(sess, ["foo"], export_dir)
    finally:
      # The edited OpDef is the registry's own entry, so without this the
      # complex128 default would stay visible to whatever runs after this test.
      complex_op_def.CopyFrom(original_complex_op_def)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:57,代码来源:saved_model_test.py


示例3: _is_array_type_input

def _is_array_type_input(op, i):
    """Report whether input `i` of the registered op `op` is an array input.

    Args:
      op: Key looked up in the registered-ops mapping.
      i: Index into the op definition's `input_arg` list.

    Returns:
      False when `op` is not in the registry; otherwise True exactly when the
      selected input arg has a truthy `number_attr`.

    Raises:
      TypeError: If `i` is not in `[0, len(op_def.input_arg))`.
    """
    registered_ops = op_def_registry.get_registered_ops()
    if op not in registered_ops:
        return False
    op_def = registered_ops[op]
    num_inputs = len(op_def.input_arg)
    # `range` is a builtin on every Python version, whereas `xrange` only
    # exists on Python 2 or through a compatibility shim.
    if i not in range(num_inputs):
        raise TypeError("Expected arg index to be in [0, %d)" % num_inputs)
    return bool(op_def.input_arg[i].number_attr)
开发者ID:yuanms2,项目名称:tensorflow,代码行数:9,代码来源:function.py


示例4: _register_function_ops

def _register_function_ops(func_list):
  """Adds each function's signature to the registered-ops mapping.

  This is needed because our checkpoint is saved with ops that are not part
  of TensorFlow.  Every name added this way also gets
  `common_shapes.unknown_shape` registered as its shape function.

  Args:
    func_list: Iterable of function objects exposing
      `_create_definition_if_needed()` and `_definition.signature`.
  """
  registry = op_def_registry.get_registered_ops()
  for fn in func_list:
    #pylint: disable=W0212
    # The definition is built on demand, so force it before reading it.
    fn._create_definition_if_needed()
    signature = fn._definition.signature
    op_name = signature.name
    registry[op_name] = signature
    RegisterShape(op_name)(common_shapes.unknown_shape)
开发者ID:AbhinavJain13,项目名称:seq2seq,代码行数:10,代码来源:profile.py


示例5: _add_op_node

def _add_op_node(op, func):
  """Builds a `FunctionDef.Node` describing `op` and appends it to `func`.

  Args:
    op: The operation to convert; its type, name, inputs, outputs and control
      inputs are read.
    func: The function def whose `node` list receives the new node.  It is
      also passed to the array/list helpers.
  """
  fn_node = function_pb2.FunctionDef.Node()
  fn_node.op = op.type
  # pylint: disable=protected-access
  # A `_sig` attribute on the op takes precedence over the registry entry.
  signature = (op._sig if hasattr(op, "_sig")
               else op_def_registry.get_registered_ops()[op.type])
  # pylint: enable=protected-access
  node_attrs = _get_node_def_attr(op)

  # Return values: `cursor` walks the op's flat output list, and each output
  # arg consumes `width` consecutive entries of it.
  if signature.output_arg:
    cursor = 0
    for out_arg in signature.output_arg:
      if out_arg.number_attr:
        elem_type = out_arg.type or node_attrs[out_arg.type_attr].type
        width = node_attrs[out_arg.number_attr].i
        ret_entry = _add_output_array(
            op, cursor, cursor + width, elem_type, func)
      elif out_arg.type_list_attr:
        elem_types = node_attrs[out_arg.type_list_attr].list.type
        width = len(elem_types)
        ret_entry = _add_output_list(
            op, cursor, cursor + width, elem_types, func)
      else:
        width = 1
        ret_entry = _make_argname_from_tensor_name(op.outputs[cursor].name)
      fn_node.ret.append(ret_entry)
      cursor += width
  else:
    # With no declared outputs, the op's own name is recorded as the return.
    fn_node.ret.append(_make_argname_from_tensor_name(op.name))

  # Arguments: same flat-index bookkeeping, this time over the op's inputs.
  cursor = 0
  for in_arg in signature.input_arg:
    if in_arg.number_attr:
      elem_type = in_arg.type or node_attrs[in_arg.type_attr].type
      width = node_attrs[in_arg.number_attr].i
      fn_node.arg.append(
          _add_input_array(op, cursor, cursor + width, elem_type, func))
    elif in_arg.type_list_attr:
      width = len(node_attrs[in_arg.type_list_attr].list.type)
      fn_node.arg.extend([
          _make_argname_from_tensor_name(op.inputs[position].name)
          for position in range(cursor, cursor + width)
      ])
    else:
      width = 1
      fn_node.arg.append(
          _make_argname_from_tensor_name(op.inputs[cursor].name))
    cursor += width

  # Control dependencies, then a copy of every node attr.
  fn_node.dep.extend(
      [_make_argname_from_tensor_name(ctrl.name) for ctrl in op.control_inputs])
  for attr_name, attr_value in _get_node_def_attr(op).items():
    fn_node.attr[attr_name].CopyFrom(attr_value)
  func.node.extend([fn_node])
开发者ID:KalraA,项目名称:tensorflow,代码行数:55,代码来源:function.py


示例6: _stripped_op_list_for_graph

def _stripped_op_list_for_graph(graph_def):
  """Builds an `OpList` of the registered OpDefs referenced by `graph_def`.

  Op names are gathered from the top-level nodes and from the nodes of every
  library function.  Names missing from the registry are ignored, and the
  resulting list is ordered by op name.
  """
  registry = op_def_registry.get_registered_ops()

  # Flatten top-level nodes and function-body nodes into one sequence.
  all_nodes = list(graph_def.node)
  for library_fn in graph_def.library.function:
    all_nodes.extend(library_fn.node)

  known_op_names = {node.op for node in all_nodes if node.op in registry}
  return op_def_pb2.OpList(
      op=[registry[op_name] for op_name in sorted(known_op_names)])
开发者ID:sherrym,项目名称:tensorflow,代码行数:12,代码来源:saver.py


示例7: list_registered_stateful_ops_without_inputs

def list_registered_stateful_ops_without_inputs():
  """Collects the names of registered stateful ops that take no inputs.

  The result is used to identify the ops to be included in the state-graph
  and that are subsequently fed into the apply-graphs.

  Returns:
    A set of strings: the registry names whose OpDef has `is_stateful` set
    and an empty `input_arg`.
  """
  registry = op_def_registry.get_registered_ops()
  return {
      op_name
      for op_name, op_def in registry.items()
      if op_def.is_stateful and not op_def.input_arg
  }
开发者ID:nicolas-ivanov,项目名称:hub,代码行数:14,代码来源:native_module.py


示例8: _get_ref_args

  def _get_ref_args(self, node):
    """Lists the ref-type outputs of the op that `node` instantiates.

    Args:
      node: A `NodeDef`.

    Returns:
      A list of strs, one per `output_arg` of the node's registered OpDef
      whose `is_ref` is set: `node.name` for output 0 and `"<name>:<i>"` for
      output `i > 0`.  Empty when the op type has no registry entry.
    """
    op_def = op_def_registry.get_registered_ops().get(node.op)
    if not op_def:
      return []
    return [
        node.name if index == 0 else ("%s:%d" % (node.name, index))
        for index, output_arg in enumerate(op_def.output_arg)
        if output_arg.is_ref
    ]
开发者ID:AbhinavJain13,项目名称:tensorflow,代码行数:17,代码来源:debug_graphs.py


示例9: _strip_graph_default_valued_attrs

def _strip_graph_default_valued_attrs(meta_graph_def):
  """Removes default-valued attrs from every node def of a MetaGraphDef.

  Both the top-level graph nodes and the node defs inside library functions
  are processed in place.  Afterwards `meta_info_def.stripped_default_attrs`
  is set to True on the given proto.

  Args:
    meta_graph_def: `MetaGraphDef` protocol buffer

  Returns:
    None.
  """
  graph_def = meta_graph_def.graph_def

  # Library functions keyed by signature name; nodes that call one of these
  # are left untouched.
  op_name_to_function = {
      fn_def.signature.name: fn_def for fn_def in graph_def.library.function
  }
  registered_ops = op_def_registry.get_registered_ops()

  def _strip_node_default_valued_attrs(node_def):
    """Drops the attrs of one node def that carry their default value."""
    op_name = node_def.op
    if op_name in op_name_to_function:
      return
    if op_name not in registered_ops:
      return
    op_def = registered_ops[op_name]

    # Decide first, delete afterwards, so the attr map is never modified
    # while it is being iterated.
    default_valued = [
        attr_name for attr_name, attr_value in node_def.attr.items()
        if _is_default_attr_value(op_def, attr_name, attr_value)
    ]
    for attr_name in default_valued:
      del node_def.attr[attr_name]

  # Top-level nodes.
  for node_def in graph_def.node:
    _strip_node_default_valued_attrs(node_def)

  # Nodes inside library functions.
  for fn_def in graph_def.library.function:
    for fn_node_def in fn_def.node_def:
      _strip_node_default_valued_attrs(fn_node_def)

  # Tell consumers of this graph that default valued attrs have been stripped.
  meta_graph_def.meta_info_def.stripped_default_attrs = True
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:45,代码来源:meta_graph.py


示例10: register_ops_if_needed

def register_ops_if_needed(graph_ops):
  """Copies ops missing from the python registry over from the c++ registry.

  Args:
    graph_ops: set with graph op names to register.

  Raises:
    RuntimeError: if `graph_ops` contains ops that are not in either python or
      c++ registry.
  """
  python_op_names = set(op_def_registry.get_registered_ops().keys())
  missing_ops = graph_ops - python_op_names
  if not missing_ops:
    return

  # Fetch and parse the full c++ op list, then index it by op name.
  p_buffer = c_api.TF_GetAllOpList()
  cpp_op_list = op_def_pb2.OpList()
  cpp_op_list.ParseFromString(c_api.TF_GetBuffer(p_buffer))
  cpp_registry_ops = {op.name: op for op in cpp_op_list.op}

  ops_to_register = op_def_pb2.OpList()
  for op_name in missing_ops:
    if op_name in cpp_registry_ops:
      ops_to_register.op.extend([cpp_registry_ops[op_name]])
      tf.logging.info(
          "Adding op %s from c++ registry to python registry.",
          op_name)
    else:
      tf.logging.info(
          "Op %s is missing from both the python and C++ registry.",
          op_name)

  op_def_registry.register_op_list(ops_to_register)

  # Note: the RuntimeError is raised only after the registration attempt.
  # This allows the test to exercise all the calls into TensorFlow
  # without having to write a C + python test.
  unknown_ops = missing_ops.difference(set(cpp_registry_ops.keys()))
  if unknown_ops:
    raise RuntimeError(
        "Graph ops missing from the python registry (%s) are also absent from "
        "the c++ registry."
        % unknown_ops)
开发者ID:nicolas-ivanov,项目名称:hub,代码行数:42,代码来源:native_module.py


示例11: _stripped_op_list_for_graph

def _stripped_op_list_for_graph(graph_def):
  """Builds an `OpList` with the registered OpDef of each op in `graph_def.node`.

  Only top-level nodes are consulted and entries are ordered by op name.  An
  op name absent from the registry is not filtered out: the registry lookup
  error propagates to the caller.
  """
  registry = op_def_registry.get_registered_ops()
  op_names = set()
  for node in graph_def.node:
    op_names.add(node.op)
  return op_def_pb2.OpList(
      op=[registry[op_name] for op_name in sorted(op_names)])
开发者ID:CdricGmd,项目名称:tensorflow,代码行数:5,代码来源:saver.py


示例12: import_graph_def

def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided, attrs
      for ops in `graph_def` that are not in `op_dict` that have their default
      value according to `producer_op_list` will be removed. This will allow
      some more `GraphDef`s produced by later binaries to be accepted by
      earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is None:
    producer_op_dict = None
  else:
    producer_op_dict = {op.name: op for op in producer_op_list.op}

  g = ops.get_default_graph()

  # Add any functions defined in `graph_def` to `g`
  if graph_def.library and graph_def.library.function:
    # Copy op_dict so we don't clobber the original
    op_dict = copy.copy(op_dict)
    # pylint: disable=protected-access
    # Note that we do not prepend `name` to the function name. The reasoning is
    # that function names are similar to op definition names, which currently do
    # not have a scoped name or namespace scheme.
    functions = function._from_library(graph_def.library)
    for f in functions:
      g._add_function(f)
      op_dict[f.name] = f.definition.signature
    # pylint: enable=protected-access

#.........这里部分代码省略.........
开发者ID:Immexxx,项目名称:tensorflow,代码行数:101,代码来源:importer.py


示例13: _get_op_def

def _get_op_def(op):
  return op.op_def or op_def_registry.get_registered_ops()[op.type]
开发者ID:cameronphchen,项目名称:tensorflow,代码行数:2,代码来源:function.py


示例14: _get_op_def

def _get_op_def(op):
  # pylint: disable=protected-access
  if hasattr(op, "_sig"):
    return getattr(op, "_sig")
  else:
    return op_def_registry.get_registered_ops()[op.type]
开发者ID:Hwhitetooth,项目名称:tensorflow,代码行数:6,代码来源:function.py


示例15: import_graph_def

def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None, producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  op_dict = op_def_registry.get_registered_ops()

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()

  if graph._c_graph:  # pylint: disable=protected-access
    with ops.name_scope(name, 'import', input_map.values()) as scope:
      # Save unique prefix generated by name_scope
      if scope:
        assert scope.endswith('/')
        prefix = scope[:-1]
      else:
        prefix = ''

      # Generate any input map tensors inside name scope
      input_map = _ConvertInputMapValues(name, input_map)

    scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
    options = scoped_options.options
    _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements)

    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with errors.raise_exception_on_not_ok_status() as status:
          results = c_api.TF_GraphImportGraphDefWithResults(
              graph._c_graph, serialized, options, status)  # pylint: disable=protected-access
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    _ProcessNewOps(graph)

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    if graph_def.library and graph_def.library.function:
      # pylint: disable=protected-access
      functions = function._from_library(graph_def.library)
      for f in functions:
        f.add_to_graph(graph)
      # pylint: enable=protected-access

    # Treat input mappings that don't appear in the graph as an error, because
#.........这里部分代码省略.........
开发者ID:andrewharp,项目名称:tensorflow,代码行数:101,代码来源:importer.py


示例16: import_graph_def

def import_graph_def(graph_def, input_map=None, return_elements=None,
                     name=None, op_dict=None):
  """Imports the TensorFlow graph in `graph_def` into the Python `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  [`Tensor`](#Tensor) and [`Operation`](#Operation) objects. See
  [`Graph.as_graph_def()`](#Graph.as_graph_def) for a way to create a
  `GraphDef` proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Defaults to `"import"`.
    op_dict: (Optional.) A dictionary mapping op type names to `OpDef` protos.
      Must contain an `OpDef` proto for each op type named in `graph_def`.
      If omitted, uses the `OpDef` protos registered in the global registry.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # Type checks for inputs.
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach
    try:
      old_graph_def = graph_def
      graph_def = graph_pb2.GraphDef()
      graph_def.MergeFrom(old_graph_def)
    except TypeError:
      raise TypeError('graph_def must be a GraphDef proto.')
  if input_map is None:
    input_map = {}
  else:
    if not (isinstance(input_map, dict)
            and all(isinstance(k, compat.bytes_or_text_types)
                    for k in input_map.keys())):
      raise TypeError('input_map must be a dictionary mapping strings to '
                      'Tensor objects.')
  if return_elements is not None:
    return_elements = tuple(return_elements)
    if not all(isinstance(x, compat.bytes_or_text_types)
               for x in return_elements):
      raise TypeError('return_elements must be a list of strings.')

  # Use a canonical representation for all tensor names.
  input_map = {_CanonicalInputName(k): v for k, v in input_map.items()}
  used_input_keys = set()

  name_to_op = {}

  if op_dict is None:
    op_dict = op_def_registry.get_registered_ops()

  with ops.op_scope(input_map.values(), name, 'import'):
    g = ops.get_default_graph()
    g.graph_def_versions.CopyFrom(graph_def.versions)

    with ops.name_scope('_inputs'):
      input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}

    # NOTE(mrry): We do this in two passes, because there may be a cycle in
    # `graph_def`.

    # 1. Add operations without their inputs.
    for node in graph_def.node:
      # Set any default attr values that aren't present.
      op_def = op_dict[node.op]
      for attr_def in op_def.attr:
        key = attr_def.name
        if attr_def.HasField('default_value'):
          value = node.attr[key]
          if value is None or value.WhichOneof('value') is None:
            node.attr[key].CopyFrom(attr_def.default_value)

      output_types = _OutputTypes(node, op_dict)
      name_to_op[node.name] = g.create_op(
          node.op, [], output_types, name=node.name, attrs=node.attr,
          compute_shapes=False, compute_device=False,
          op_def=op_def)

    # 2. Add inputs to the operations.
    for node in graph_def.node:
#.........这里部分代码省略.........
开发者ID:01bui,项目名称:tensorflow,代码行数:101,代码来源:importer.py


示例17: import_graph_def

def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  @{tf.Tensor} and @{tf.Operation} objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  @{tf.Graph.as_graph_def} for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  op_dict = op_def_registry.get_registered_ops()

  graph_def = _ProcessGraphDefParam(graph_def, op_dict)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(op_dict, producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                   return_elements)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating the them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        results = c_api.TF_GraphImportGraphDefWithResults(
            graph._c_graph, serialized, options)  # pylint: disable=protected-access
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility.
        raise ValueError(str(e))

    # Create _DefinedFunctions for any imported functions.
    #
    # We do this by creating _DefinedFunctions directly from `graph_def`, and
    # adding them to `graph`. Adding an existing function to a TF_Graph is a
    # no-op, so this only has the effect of updating the Python state (usually
    # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
    #
    # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
    # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
    # TODO(b/74620627): move this after _ProcessNewOps outside the lock once
    # _USE_C_SHAPES is removed.
    if graph_def.library and graph_def.library.function:
#.........这里部分代码省略.........
开发者ID:StephenOman,项目名称:tensorflow,代码行数:101,代码来源:importer.py


示例18: function_def_to_graph_def

def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs.

  Control inputs (names of the form "^node_name") that refer to an input arg
  or a body node are left unchanged by step 3.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args or if the FunctionDef is invalid.
    NotImplementedError: If a body node uses an op that is not in the op
      registry, or an op that declares an attr of type "func" or "list(func)".
    KeyError: If a node input names a tensor or control dependency that is
      not produced by any input arg or body node.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of input_shapes must match the number of " +
                     "input_args. len(input_shapes): {} len(input_arg): {}".
                     format(len(input_shapes), len(fdef.signature.input_arg)))

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      node_def.attr["shape"].shape.CopyFrom(input_shapes[i].as_proto())

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  # Control inputs ("^node_name") are tracked separately so that the returned
  # mapping keeps containing tensor names only.
  control_input_names = set()

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
    control_input_names.add("^" + arg_def.name)

  # Fetch the registry once; it does not change while the nodes are scanned.
  registered_ops = op_def_registry.get_registered_ops()

  for node_def in fdef.node_def:
    op_def = registered_ops.get(node_def.op)
    if not op_def:
      # TODO(b/80470245): Support functions which refer other functions.
      raise NotImplementedError(
          "No op registered for {},".format(node_def.op) +
          " it may be a function. function_def_to_graph_def " +
          "currently does not support converting functions with " +
          "references to other graph functions.")

    for attr in op_def.attr:
      if attr.type in ("func", "list(func)"):
        # TODO(b/80470245): Support functions which refer other functions.
        raise NotImplementedError("Unsupported attr {} ".format(attr.name) +
                                  " with type {}".format(attr.type) +
                                  " in op {}. ".format(op_def.name) +
                                  "function_def_to_graph_def currently does " +
                                  "not support converting functions with " +
                                  "references to other graph functions.")

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1

    control_input_names.add("^" + node_def.name)

  # Update inputs of all nodes in graph. Entries are rewritten by index so the
  # `input` field is modified in place.
  for node_def in graph_def.node:
    for i in range(len(node_def.input)):
      input_name = node_def.input[i]
      if input_name in control_input_names:
        # A control dependency on a known node is spelled the same way in a
        # GraphDef, so there is nothing to rename.
        continue
      # Unknown tensor names (and unknown control targets) still raise
      # KeyError here.
      node_def.input[i] = nested_to_flat_tensor_name[input_name]

  return graph_def, nested_to_flat_tensor_name
开发者ID:LiuCKind,项目名称:tensorflow,代码行数:99,代码来源:function_def_to_graph.py


示例19: enable_jit_nonstateful

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gradients
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
# pylint: enable=g-import-not-at-top


# Op registry keyed by op name, fetched once when this module is imported.
# enable_jit_nonstateful indexes it with `node_def.op` to read `is_stateful`.
_REGISTERED_OPS = op_def_registry.get_registered_ops()


def enable_jit_nonstateful(node_def):
  """Return True when the op of `node_def` is registered as non-stateful.

  Args:
    node_def: Object whose `op` attribute names the op to look up in
      `_REGISTERED_OPS`.

  Returns:
    The negation of the registered op's `is_stateful` flag.

  Raises:
    ValueError: If `node_def.op` is not present in `_REGISTERED_OPS`.
  """
  try:
    registered_op = _REGISTERED_OPS[node_def.op]
  except KeyError:
    raise ValueError("Unregistered op being created: %s" % node_def)
  else:
    stateful = registered_op.is_stateful
  return not stateful


class JITTest(test.TestCase):

  def compute(self, use_jit, compute_fn):
    random_seed.set_random_seed(1234)
    with self.test_session(graph=ops.Graph()) as sess:
      with jit.experimental_jit_scope(use_jit):
开发者ID:1000sprites,项目名称:tensorflow,代码行数:31,代码来源:jit_test.py


示例20: import_graph_def

该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python op_def_registry.register_op_list函数代码示例发布时间:2022-05-27
下一篇:
Python meta_graph.import_scoped_meta_graph函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap