Python resource_variable_ops.is_resource_variable function code examples


This article collects typical usage examples of the Python function tensorflow.python.ops.resource_variable_ops.is_resource_variable. If you are wondering what is_resource_variable does, how to call it, or what real-world uses look like, the curated examples below may help.



Below are 20 code examples of the is_resource_variable function, drawn from open-source projects and ordered by popularity.
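
As a quick orientation before the excerpts, here is a minimal standalone sketch of the check itself. This is a hypothetical snippet rather than code from any of the projects below; it assumes the same TensorFlow internal modules that the examples import.

from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops

# A ResourceVariable is a variable backed by a resource handle; the
# predicate reports True for it.
v = resource_variable_ops.ResourceVariable(1.0)
print(resource_variable_ops.is_resource_variable(v))  # True

# Plain tensors (and legacy ref variables) are not resource variables.
t = constant_op.constant(1.0)
print(resource_variable_ops.is_resource_variable(t))  # False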

Example 1: update

    def update(v, g):
      """Apply gradients to a replica variable."""
      assert v is not None

      try:
        # Convert the grad to Tensor or IndexedSlices if necessary.
        g = ops.convert_to_tensor_or_indexed_slices(g)
      except TypeError:
        raise TypeError("Gradient must be convertible to a Tensor"
                        " or IndexedSlices, or None: %s" % g)
      if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
        raise TypeError(
            "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
      p = _get_processor(v)

      if context.executing_eagerly() or (
          resource_variable_ops.is_resource_variable(v) and
          not v._in_graph_mode):  # pylint: disable=protected-access
        scope_name = v.name.split(":")[0]
      else:
        scope_name = v.op.name

      # device_policy is set because non-mirrored tensors will be read in
      # `update_op`. `_resource_apply_dense`, `lr_t`, `beta1_t` and `beta2_t`
      # are examples.
      with ops.name_scope("update_" + scope_name):
        return p.update_op(self, g)
Source: adit-chandra/tensorflow, optimizer.py (27 lines)


Example 2: __init__

  def __init__(self, var, slice_spec, name):
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      self.handle_op = var.op.inputs[0]
      tensor = var
    elif resource_variable_ops.is_resource_variable(var):

      def _read_variable_closure(v):
        def f():
          with ops.device(v.device):
            x = v.read_value()
            # To allow variables placed on non-CPU devices to be checkpointed,
            # we copy them to CPU on the same machine first.
            with ops.device("/device:CPU:0"):
              return array_ops.identity(x)
        return f

      self.handle_op = var.handle
      tensor = _read_variable_closure(var)
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          " Got: %s" % repr(var))
    spec = saveable_object.SaveSpec(tensor, slice_spec, name,
                                    dtype=var.dtype, device=var.device)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)
Source: adit-chandra/tensorflow, saveable_object_util.py (27 lines)


Example 3: _write_object_proto

def _write_object_proto(obj, proto, asset_file_def_index):
  """Saves an object into SavedObject proto."""
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    if not obj.name.endswith(":0"):
      raise ValueError("Cowardly refusing to save variable %s because of"
                       " unexpected suffix which won't be restored.")
    proto.variable.name = meta_graph._op_name(obj.name)  # pylint: disable=protected-access
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.synchronization = obj.synchronization.value
    proto.variable.aggregation = obj.aggregation.value
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(
        function_serialization.serialize_function(obj))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.bare_concrete_function.CopyFrom(
        function_serialization.serialize_bare_concrete_function(obj))
  elif isinstance(obj, _CapturedConstant):
    proto.constant.operation = obj.graph_tensor.op.name
  elif isinstance(obj, tracking.CapturableResource):
    proto.resource.device = obj._resource_device  # pylint: disable=protected-access
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier="_generic_user_object",
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(registered_type_proto)
Source: aritratony/tensorflow, save.py (35 lines)


Example 4: _write_object_proto

def _write_object_proto(obj, proto, asset_file_def_index, node_ids):
  """Saves an object into SavedObject proto."""
  if isinstance(obj, tracking.TrackableAsset):
    proto.asset.SetInParent()
    proto.asset.asset_file_def_index = asset_file_def_index[obj]
  elif resource_variable_ops.is_resource_variable(obj):
    proto.variable.SetInParent()
    proto.variable.trainable = obj.trainable
    proto.variable.dtype = obj.dtype.as_datatype_enum
    proto.variable.shape.CopyFrom(obj.shape.as_proto())
  elif isinstance(obj, def_function.Function):
    proto.function.CopyFrom(
        function_serialization.serialize_function(obj, node_ids))
  elif isinstance(obj, defun.ConcreteFunction):
    proto.concrete_function.CopyFrom(
        function_serialization.serialize_concrete_function(obj, node_ids))
  else:
    registered_type_proto = revived_types.serialize(obj)
    if registered_type_proto is None:
      # Fallback for types with no matching registration
      registered_type_proto = saved_object_graph_pb2.SavedUserObject(
          identifier="_generic_user_object",
          version=versions_pb2.VersionDef(
              producer=1, min_consumer=1, bad_consumers=[]))
    proto.user_object.CopyFrom(registered_type_proto)
Source: rmlarsen/tensorflow, save.py (25 lines)


Example 5: _create_non_slot_variable

  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    """Add an extra variable, not associated with a slot."""
    # Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
    eager = context.executing_eagerly()
    graph = None if eager else colocate_with.graph

    key = (name, graph)
    v = self._non_slot_dict.get(key, None)
    if v is None:
      self._maybe_initialize_trackable()
      distribution_strategy = distribute_ctx.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(colocate_with):
        if eager:
          restored_initial_value = self._preload_simple_restoration(
              name=name, shape=None)
          if restored_initial_value is not None:
            initial_value = restored_initial_value
        v = variable_scope.variable(
            initial_value, name=name, trainable=False,
            use_resource=resource_variable_ops.is_resource_variable(
                colocate_with))
      # Restore this variable by name if necessary, but don't add a
      # Trackable dependency. Optimizers return the current graph's
      # non-slot variables from _checkpoint_dependencies explicitly rather
      # than unconditionally adding dependencies (since there may be multiple
      # non-slot variables with the same name in different graphs, trying to
      # save all of them would result in errors).
      self._handle_deferred_dependencies(name=name, trackable=v)
      self._non_slot_dict[key] = v

    return v
Source: adit-chandra/tensorflow, optimizer.py (31 lines)


Example 6: _create_slot_var

def _create_slot_var(primary, val, scope, validate_shape, shape, dtype):
  """Helper function for creating a slot variable."""

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  current_partitioner = variable_scope.get_variable_scope().partitioner
  variable_scope.get_variable_scope().set_partitioner(None)
  # When init from val instead of callable initializer, the shape is expected to
  # be None, not <unknown> or any fully defined shape.
  shape = shape if callable(val) else None
  slot = variable_scope.get_variable(
      scope, initializer=val, trainable=False,
      use_resource=resource_variable_ops.is_resource_variable(primary),
      shape=shape, dtype=dtype,
      validate_shape=validate_shape)
  variable_scope.get_variable_scope().set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable.  Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weight". We want to get 'Adam' as real_slot_name, so we
    # remove "'linear//weight' + '/'" and ':0'.
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
        slice_info.full_name + "/" + real_slot_name,
        slice_info.full_shape[:],
        slice_info.var_offset[:],
        slice_info.var_shape[:]))
  # pylint: enable=protected-access
  return slot
Source: AnishShah/tensorflow, slot_creator.py (35 lines)


Example 7: _write_object_graph

def _write_object_graph(saveable_view, export_dir, asset_file_def_index):
  """Save a SavedObjectGraph proto for `root`."""
  # SavedObjectGraph is similar to the CheckpointableObjectGraph proto in the
  # checkpoint. It will eventually go into the SavedModel.
  proto = saved_object_graph_pb2.SavedObjectGraph()
  saveable_view.fill_object_graph_proto(proto)

  node_ids = util.ObjectIdentityDictionary()
  for i, obj in enumerate(saveable_view.nodes):
    node_ids[obj] = i
    if resource_variable_ops.is_resource_variable(obj):
      node_ids[obj.handle] = i
    elif isinstance(obj, tracking.TrackableAsset):
      node_ids[obj.asset_path.handle] = i

  for obj, obj_proto in zip(saveable_view.nodes, proto.nodes):
    _write_object_proto(obj, obj_proto, asset_file_def_index, node_ids)

  extra_asset_dir = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.EXTRA_ASSETS_DIRECTORY))
  file_io.recursive_create_dir(extra_asset_dir)
  object_graph_filename = os.path.join(
      extra_asset_dir, compat.as_bytes("object_graph.pb"))
  file_io.write_string_to_file(object_graph_filename, proto.SerializeToString())
Source: rmlarsen/tensorflow, save.py (25 lines)


Example 8: _get_tensor_from_node

  def _get_tensor_from_node(self, node_id):
    obj = self._nodes[node_id]
    if resource_variable_ops.is_resource_variable(obj):
      return obj.handle
    elif isinstance(obj, tracking.TrackableAsset):
      return obj.asset_path.handle
    raise ValueError("Can't convert node %s to tensor" % (type(obj)))
Source: rmlarsen/tensorflow, load.py (7 lines)


Example 9: _map_resources

def _map_resources(accessible_objects):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the C++
  loader API to interact with variables.

  Args:
    accessible_objects: A list of objects, some of which may contain resources,
      to create replacements for.

  Returns:
    A tuple of (object_map, resource_map):
      object_map: A dictionary mapping from object in `accessible_objects` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
  """
  # TODO(allenl, rohanj): Map generic resources rather than just variables.
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = {}
  resource_map = {}
  for obj in accessible_objects:
    if resource_variable_ops.is_resource_variable(obj):
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
  return object_map, resource_map
Source: bunbutter/tensorflow, save.py (30 lines)


Example 10: _restore_checkpoint

  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects."""
    variables_path = saved_model_utils.get_variables_path(self._export_dir)
    # TODO(andresp): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    with ops.device("CPU"):
      saver._file_prefix_placeholder = constant_op.constant(variables_path)
    load_status = saver.restore(variables_path)
    load_status.assert_existing_objects_matched()
    checkpoint = load_status._checkpoint

    # When running in eager mode, the `restore` call above has already run and
    # restored the state of trackables, so calling `position.restore_ops()`
    # will return an empty list as there is nothing left to do. In graph mode,
    # it will return the list of ops that must run to restore the object at
    # that position. We have to wire them into the initializers of the objects
    # so that they get initialized properly when using common practices (e.g.
    # the ones used by ManagedSession) without further user action.
    for object_id, obj in dict(checkpoint.object_by_proto_id).items():
      position = base.CheckpointPosition(checkpoint=checkpoint,
                                         proto_id=object_id)
      restore_ops = position.restore_ops()
      if restore_ops:
        if resource_variable_ops.is_resource_variable(obj):
          obj._initializer_op = restore_ops
        else:
          raise NotImplementedError(
              ("Missing functionality to restore state of object "
               "%r from the checkpoint." % obj))
Source: aritratony/tensorflow, load.py (30 lines)


Example 11: map_resources

  def map_resources(self):
    """Makes new resource handle ops corresponding to existing resource tensors.

    Creates resource handle ops in the current default graph, whereas
    `accessible_objects` will be from an eager context. Resource mapping adds
    resource handle ops to the main GraphDef of a SavedModel, which allows the
    C++ loader API to interact with variables.

    Returns:
      A tuple of (object_map, resource_map, asset_info):
        object_map: A dictionary mapping from object in `accessible_objects` to
          replacement objects created to hold the new resource tensors.
        resource_map: A dictionary mapping from resource tensors extracted from
          `accessible_objects` to newly created resource tensors.
        asset_info: An _AssetInfo tuple describing external assets referenced
          from accessible_objects.
    """
    # Only makes sense when adding to the export Graph
    assert not context.executing_eagerly()
    # TODO(allenl): Handle MirroredVariables and other types of variables which
    # may need special casing.
    object_map = object_identity.ObjectIdentityDictionary()
    resource_map = {}
    asset_info = _AssetInfo(
        asset_defs=[],
        asset_initializers_by_resource={},
        asset_filename_map={},
        asset_index={})
    for node_id, obj in enumerate(self.nodes):
      if isinstance(obj, tracking.TrackableResource):
        new_resource = obj._create_resource()  # pylint: disable=protected-access
        resource_map[obj.resource_handle] = new_resource
        self.captured_tensor_node_ids[obj.resource_handle] = node_id
      elif resource_variable_ops.is_resource_variable(obj):
        new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
        object_map[obj] = new_variable
        resource_map[obj.handle] = new_variable.handle
        self.captured_tensor_node_ids[obj.handle] = node_id
      elif isinstance(obj, tracking.TrackableAsset):
        _process_asset(obj, asset_info, resource_map)
        self.captured_tensor_node_ids[obj.asset_path] = node_id

    for concrete_function in self.concrete_functions:
      for capture in concrete_function.captured_inputs:
        if (tensor_util.is_tensor(capture)
            and capture.dtype not in _UNCOPIABLE_DTYPES
            and capture not in self.captured_tensor_node_ids):
          copied_tensor = constant_op.constant(
              tensor_util.constant_value(capture))
          node_id = len(self.nodes)
          node = _CapturedConstant(
              eager_tensor=capture, graph_tensor=copied_tensor)
          self.nodes.append(node)
          self.node_ids[capture] = node_id
          self.node_ids[node] = node_id
          self.captured_tensor_node_ids[capture] = node_id
          resource_map[capture] = copied_tensor

    return object_map, resource_map, asset_info
Source: adit-chandra/tensorflow, save.py (59 lines)


Example 12: _get_tensor_from_node

  def _get_tensor_from_node(self, node_id):
    """Resolves a node id into a tensor to be captured for a function."""
    with ops.init_scope():
      obj = self._nodes[node_id]
      if resource_variable_ops.is_resource_variable(obj):
        return obj.handle
      elif isinstance(obj, tracking.TrackableAsset):
        return obj.asset_path
      elif tensor_util.is_tensor(obj):
        return obj
      elif isinstance(obj, tracking.CapturableResource):
        # Note: this executes restored functions in the CapturableResource.
        return obj.resource_handle
      raise ValueError("Can't convert node %s to tensor" % (type(obj)))
Source: aritratony/tensorflow, load.py (14 lines)


Example 13: _get_processor

def _get_processor(v):
  """The processor of v."""
  if context.executing_eagerly():
    if isinstance(v, ops.Tensor):
      return _TensorProcessor(v)
    else:
      return _DenseResourceVariableProcessor(v)
  if resource_variable_ops.is_resource_variable(v) and not v._in_graph_mode:  # pylint: disable=protected-access
    # True if and only if `v` was initialized eagerly.
    return _DenseResourceVariableProcessor(v)
  if v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type %s" % v)
Source: adit-chandra/tensorflow, optimizer.py (17 lines)


Example 14: _get_expanded_variable_list

def _get_expanded_variable_list(var_list):
  """Given a list of variables, expands them if they are partitioned.

  Args:
    var_list: A list of variables.

  Returns:
    A list of variables where each partitioned variable is expanded to its
    components.
  """
  returned_list = []
  for variable in var_list:
    if (isinstance(variable, variable_ops.Variable) or
        resource_variable_ops.is_resource_variable(variable)):
      returned_list.append(variable)  # Single variable case.
    else:  # Must be a PartitionedVariable, so convert into a list.
      returned_list.extend(list(variable))
  return returned_list
Source: ThunderQi/tensorflow, linear.py (18 lines)


Example 15: zero_initializer

def zero_initializer(ref, use_locking=True, name="zero_initializer"):
  """Initialize 'ref' with all zeros, ref tensor should be uninitialized.
  If already initialized, you will get ValueError. This op is intended to
  save memory during initialization.
  Args:
    ref: ref of the tensor need to be zero initialized.
    name: optional name for this operation.
  Returns:
    ref that initialized.
  Raises:
    ValueError: If ref tensor is initialized.
  """
  loader.load_op_library(
      resource_loader.get_path_to_datafile("_variable_ops.so"))
  if resource_variable_ops.is_resource_variable(ref):
    return gen_variable_ops.zero_var_initializer(
        ref.handle, shape=ref.shape, dtype=ref.dtype, name=name)
  else:
    return gen_variable_ops.zero_initializer(ref, name=name)
Source: StephenOman/tensorflow, variables.py (19 lines)
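
For context, a hedged usage sketch in TF 1.x graph mode, assuming the `_variable_ops.so` kernels load successfully and that the `zero_initializer` function above is in scope: running the returned op zero-fills the still-uninitialized variable without materializing an initial-value tensor.

import tensorflow as tf

v = tf.get_variable("big", shape=[1000, 1000], use_resource=True)
init_op = zero_initializer(v)  # zero-fills the uninitialized resource variable
with tf.Session() as sess:
  sess.run(init_op)
  print(sess.run(tf.reduce_sum(v)))  # 0.0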


Example 16: assert_global_step

def assert_global_step(global_step_tensor):
  """Asserts `global_step_tensor` is a scalar int `Variable` or `Tensor`.

  Args:
    global_step_tensor: `Tensor` to test.
  """
  if not (isinstance(global_step_tensor, variables.Variable) or
          isinstance(global_step_tensor, ops.Tensor) or
          resource_variable_ops.is_resource_variable(global_step_tensor)):
    raise TypeError('Existing "global_step" must be a Variable or Tensor: %s.' %
                    global_step_tensor)

  if not global_step_tensor.dtype.base_dtype.is_integer:
    raise TypeError('Existing "global_step" does not have integer type: %s' %
                    global_step_tensor.dtype)

  if (global_step_tensor.get_shape().ndims != 0 and
      global_step_tensor.get_shape().is_fully_defined()):
    raise TypeError('Existing "global_step" is not scalar: %s' %
                    global_step_tensor.get_shape())
Source: aritratony/tensorflow, training_util.py (20 lines)


Example 17: __init__

  def __init__(self, variable):
    """Creates an AutoCastVariable instance.

    Args:
      variable: A floating-point resource variable to wrap.

    Raises:
      ValueError: If `variable` is not a floating-point resource variable
    """
    if not resource_variable_ops.is_resource_variable(variable):
      raise ValueError('variable must be of type tf.ResourceVariable, but got: '
                       '%s' % variable)
    if not variable.dtype.is_floating:
      raise ValueError('variable must be a floating point variable but has '
                       'type: %s' % variable.dtype.name)
    self._variable = variable

    # Delegate to the underlying variable for checkpointing.
    self._gather_saveables_for_checkpoint = (
        self._variable._gather_saveables_for_checkpoint)  # pylint: disable=protected-access
Source: adit-chandra/tensorflow, autocast_variable.py (20 lines)


Example 18: _map_resources

def _map_resources(accessible_objects):
  """Makes new resource handle ops corresponding to existing resource tensors.

  Creates resource handle ops in the current default graph, whereas
  `accessible_objects` will be from an eager context. Resource mapping adds
  resource handle ops to the main GraphDef of a SavedModel, which allows the C++
  loader API to interact with variables.

  Args:
    accessible_objects: A list of objects, some of which may contain resources,
      to create replacements for.

  Returns:
    A tuple of (object_map, resource_map, asset_info):
      object_map: A dictionary mapping from object in `accessible_objects` to
        replacement objects created to hold the new resource tensors.
      resource_map: A dictionary mapping from resource tensors extracted from
        `accessible_objects` to newly created resource tensors.
      asset_info: An _AssetInfo tuple describing external assets referenced from
        accessible_objects.
  """
  # TODO(allenl): Handle MirroredVariables and other types of variables which
  # may need special casing.
  object_map = util.ObjectIdentityDictionary()
  resource_map = {}
  asset_info = _AssetInfo(
      asset_defs=[],
      asset_initializers_by_resource={},
      asset_filename_map={},
      asset_index={})
  for obj in accessible_objects:
    if isinstance(obj, tracking.TrackableResource):
      new_resource = obj.create_resource()
      resource_map[obj.resource_handle] = new_resource
    elif resource_variable_ops.is_resource_variable(obj):
      new_variable = resource_variable_ops.copy_to_graph_uninitialized(obj)
      object_map[obj] = new_variable
      resource_map[obj.handle] = new_variable.handle
    elif isinstance(obj, tracking.TrackableAsset):
      _process_asset(obj, asset_info, resource_map)
  return object_map, resource_map, asset_info
Source: rmlarsen/tensorflow, save.py (41 lines)


Example 19: _format_device

def _format_device(var):
  """Returns the device with an annotation specifying `ResourceVariable`.

  "legacy" means a normal tf.Variable while "resource" means a ResourceVariable.

  For example:
  `(legacy)`
  `(resource)`
  `/job:learner/task:0/device:CPU:* (legacy)`
  `/job:learner/task:0/device:CPU:* (resource)`

  Args:
    var: The TensorFlow Variable or `ResourceVariable` to print.
  """
  if resource_variable_ops.is_resource_variable(var):
    resource_var_annotation = "(resource)"
  else:
    resource_var_annotation = "(legacy)"

  if var.device:
    return "{} {}".format(var.device, resource_var_annotation)
  else:
    return resource_var_annotation
Source: ccchang0111/sonnet, util.py (23 lines)


Example 20: _graph_mode_decorator

def _graph_mode_decorator(f, *args, **kwargs):
  """Implement custom gradient decorator for graph mode."""
  # TODO(rsepassi): Add support for kwargs
  if kwargs:
    raise ValueError(
        "The custom_gradient decorator currently supports keyword "
        "arguments only when eager execution is enabled.")
  name = "CustomGradient-%s" % ops.uid()
  args = [ops.convert_to_tensor(x) for x in args]

  # Checking global and local variables attempts to ensure that no non-resource
  # Variables are added to the graph.
  current_var_scope = variable_scope.get_variable_scope()
  before_vars = set(current_var_scope.global_variables() +
                    current_var_scope.local_variables())
  with backprop.GradientTape() as tape:
    result, grad_fn = f(*args)
  after_vars = set(current_var_scope.global_variables() +
                   current_var_scope.local_variables())
  new_vars = after_vars - before_vars
  for v in new_vars:
    if not resource_variable_ops.is_resource_variable(v):
      raise TypeError(
          "All variables used by a function wrapped with @custom_gradient must "
          "be `ResourceVariable`s. Ensure that no `variable_scope` is created "
          "with `use_resource=False`.")
  # The variables that grad_fn needs to return gradients for are the set of
  # variables used that are *not* part of the inputs.
  variables = list(set(tape.watched_variables()) - set(args))
  grad_argspec = tf_inspect.getfullargspec(grad_fn)
  variables_in_signature = ("variables" in grad_argspec.args or
                            grad_argspec.varkw)
  if variables and not variables_in_signature:
    raise TypeError("If using @custom_gradient with a function that "
                    "uses variables, then grad_fn must accept a keyword "
                    "argument 'variables'.")
  if variables_in_signature and not variables:
    # User seems to intend to use variables but none were captured.
    if not variable_scope.get_variable_scope().use_resource:
      raise TypeError("If using @custom_gradient with a function that "
                      "uses variables, the enclosing variable scope must "
                      "have use_resource=True.")
    else:
      logging.warn("@custom_gradient grad_fn has 'variables' in signature, but "
                   "no ResourceVariables were used on the forward pass.")
  flat_result = nest.flatten(result)
  all_tensors = flat_result + args + variables

  def tape_grad_fn(*result_grads):
    """Custom grad fn wrapper."""
    result_grads = result_grads[:len(flat_result)]
    if variables:
      input_grads, variable_grads = grad_fn(*result_grads, variables=variables)
      if len(variable_grads) != len(variables):
        raise ValueError("Must return gradient for each variable from "
                         "@custom_gradient grad_fn.")
    else:
      input_grads = grad_fn(*result_grads)
      variable_grads = []

    # Need to return one value per input to the IdentityN, so pad the
    # gradients of the inputs of the custom_gradient function with the
    # gradients of the outputs as well.
    input_grads = nest.flatten(input_grads)
    return ([None] * len(flat_result)) + input_grads + variable_grads

  @ops.RegisterGradient(name)
  def internal_grad_fn(unused_op, *result_grads):  # pylint: disable=unused-variable
    """Custom grad fn wrapper."""
    return tape_grad_fn(*result_grads)

  original_tensors = all_tensors
  with ops.get_default_graph().gradient_override_map({"IdentityN": name}):
    all_tensors = array_ops.identity_n(all_tensors)

  original_tensors = [ops.convert_to_tensor(x) for x in original_tensors]

  # Propagate handle data for happier shape inference for resource variables.
  for i, t in enumerate(original_tensors):
    if t.dtype == dtypes.resource and hasattr(t, "_handle_data"):
      all_tensors[i]._handle_data = t._handle_data  # pylint: disable=protected-access
  tape_lib.record_operation(
      f.__name__, all_tensors, original_tensors, tape_grad_fn)
  for ot, t in zip(original_tensors, all_tensors):
    copy_handle_data(ot, t)
  return nest.pack_sequence_as(
      structure=result, flat_sequence=all_tensors[:len(flat_result)])
Source: adit-chandra/tensorflow, custom_gradient.py (87 lines)



Note: The tensorflow.python.ops.resource_variable_ops.is_resource_variable examples in this article were compiled by 纯净天空 from source code and documentation hosted on platforms such as GitHub and MSDocs. The snippets were selected from open-source projects contributed by their respective developers, and copyright remains with the original authors. When distributing or using them, refer to the License of the corresponding project; do not reproduce without permission.

