• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python ops.register_tensor_conversion_function函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.framework.ops.register_tensor_conversion_function函数的典型用法代码示例。如果您想了解register_tensor_conversion_function函数的具体用法、调用方式或实际使用示例,这里精选的代码示例或许能为您提供帮助。



在下文中一共展示了register_tensor_conversion_function函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: test_works_with_registered

  def test_works_with_registered(self):
    """Registered custom types count as symbolic tensors only in graph mode."""

    class CustomClass(object):

      def value(self):
        return ops.convert_to_tensor(42.)

    ops.register_tensor_conversion_function(
        CustomClass, lambda value, **_: value.value())

    tf_utils.register_symbolic_tensor_type(CustomClass)

    # Each candidate is built lazily, right before it is checked, so objects
    # are created in the same order as the assertions run.
    candidate_factories = (
        lambda: variables.Variable(name='blah', initial_value=0.),
        lambda: ops.convert_to_tensor(0.),
        lambda: sparse_tensor.SparseTensor(
            indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4]),
        CustomClass,
    )
    # Eagerly, none of the candidates is symbolic; in graph mode all are.
    if context.executing_eagerly():
      check = self.assertFalse
    else:
      check = self.assertTrue
    for make_candidate in candidate_factories:
      check(tf_utils.is_symbolic_tensor(make_candidate()))
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:30,代码来源:tf_utils_test.py


示例2: test_enables_nontensor_plumbing

  def test_enables_nontensor_plumbing(self):
    """Checks that a non-Tensor object can flow through a Keras model.

    `Foo` is registered both as convertible to a Tensor and as a symbolic
    tensor type. The test builds a model around `PlumbingLayer(Foo)` and
    asserts that calling the model returns a `Foo` instance, not a `Tensor`.
    """
    # Setup.

    class Foo(object):
      # Stores its input (not read elsewhere in this test) and a constant
      # tensor that serves as its tensor conversion.

      def __init__(self, input_):
        self._input = input_
        self.value = ops.convert_to_tensor(42.)

    # Converting a `Foo` yields its stored `value` tensor.
    ops.register_tensor_conversion_function(
        Foo, lambda x, *args, **kwargs: x.value)
    tf_utils.register_symbolic_tensor_type(Foo)

    class PlumbingLayer(keras.layers.Lambda):
      # Lambda layer whose wrapped function returns the pair (raw object,
      # converted tensor). `__call__` hands back only the raw object; `call`
      # returns the pair only while `__call__` is on the stack.

      def __init__(self, fn, **kwargs):
        def _fn(*fargs, **fkwargs):
          d = fn(*fargs, **fkwargs)
          x = ops.convert_to_tensor(d)
          # Copy shape accessors onto the raw object, presumably so Keras can
          # query its shape like a tensor's.
          d.shape = x.shape
          d.get_shape = x.get_shape
          return d, x
        super(PlumbingLayer, self).__init__(_fn, **kwargs)
        # True only while `__call__` is executing.
        self._enter_dunder_call = False

      def __call__(self, inputs, *args, **kwargs):
        self._enter_dunder_call = True
        d, _ = super(PlumbingLayer, self).__call__(inputs, *args, **kwargs)
        self._enter_dunder_call = False
        return d

      def call(self, inputs, *args, **kwargs):
        d, v = super(PlumbingLayer, self).call(inputs, *args, **kwargs)
        if self._enter_dunder_call:
          return d, v
        return d

    # User-land.
    model = keras.Sequential([
        keras.layers.InputLayer([]),
        PlumbingLayer(Foo),  # Makes a `Foo` object.
    ])
    # Let's ensure Keras graph history is preserved by composing the models.
    model = keras.Model(model.inputs, model(model.outputs))
    # Now we instantiate the model and verify we have a `Foo` object, not a
    # `Tensor`.
    y = model(ops.convert_to_tensor(7.))
    self.assertIsInstance(y, Foo)
开发者ID:aeverall,项目名称:tensorflow,代码行数:48,代码来源:tf_utils_test.py


示例3: testFullDelegationControlUsingRegistry

  def testFullDelegationControlUsingRegistry(self):
    """A conversion function that raises TypeError lets `__radd__` take over."""

    class NumpyArraySubclass(np.ndarray):

      def __radd__(self, lhs):
        return "Works!"

    def raise_to_delegate(value, dtype=None, name=None, as_ref=False):
      # Refuse every conversion. The final assertion expects this refusal to
      # make `tensor + array` fall through to NumpyArraySubclass.__radd__.
      del value, dtype, name, as_ref  # Unused.
      raise TypeError

    ops.register_tensor_conversion_function(
        NumpyArraySubclass, raise_to_delegate, priority=0)
    lhs_tensor = ops.convert_to_tensor([[10.0, 20.0]])
    rhs_array = NumpyArraySubclass(shape=(1, 2), buffer=np.array([1.0, 2.0]))
    self.assertEqual(lhs_tensor + rhs_array, "Works!")
开发者ID:1000sprites,项目名称:tensorflow,代码行数:17,代码来源:tensor_priority_test.py


示例4: _register_variable_read

     collections: any collections in which this operation should be inserted.
     trainable: whether this read is to be used for training.

    Returns:
     the read operation.
    """
    with ops.name_scope("Read"):
      value = gen_resource_variable_ops.read_variable_op(
          self._handle, dtype=self._dtype)
    _register_variable_read(value, collections=collections, trainable=trainable)
    return value

  def sparse_read(self, indices, collections=None, trainable=True, name=None):
    """Reads the entries of this variable selected by `indices`.

    The `resource_gather` op is created under the name scope `name`, or
    "Gather" when no name is given. The result is then passed to
    `_register_variable_read` with the given `collections` and `trainable`.
    """
    scope_name = name if name is not None else "Gather"
    with ops.name_scope(scope_name):
      gathered = gen_resource_variable_ops.resource_gather(
          self._handle, indices, dtype=self._dtype)
    _register_variable_read(
        gathered, collections=collections, trainable=trainable)
    return gathered


# pylint: disable=unused-argument
def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
  """Tensor conversion function for resource variables.

  Args:
    var: the variable being converted; its `value` attribute is returned.
    dtype: optional requested dtype. If it differs from `var.value.dtype`,
      the conversion is declined by returning `NotImplemented`.
    name: unused.
    as_ref: unused.

  Returns:
    `var.value`, or `NotImplemented` when a different dtype was requested.
  """
  value = var.value  # Read once; used for both the check and the result.
  if dtype is not None and dtype != value.dtype:
    # Decline quietly: the conversion registry treats NotImplemented as
    # "try another conversion function", so there is nothing to report here.
    return NotImplemented
  return value

# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
# `_dense_var_to_tensor` returns `var.value`, or NotImplemented when a
# different dtype is requested.
ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
开发者ID:HKUST-SING,项目名称:tensorflow,代码行数:30,代码来源:resource_variable_ops.py


示例5:

            # Return an empty tensor so we only need to check for returned tensor
            # size being 0 as an indication of model ready.
            return array_ops.constant([], dtype=dtypes.string)
        else:
            # Get a 1-D boolean tensor listing whether each variable is initialized.
            variables_mask = math_ops.logical_not(
                array_ops.pack([state_ops.is_variable_initialized(v) for v in var_list])
            )
            # Get a 1-D string tensor containing all the variable names.
            variable_names_tensor = array_ops.constant([s.op.name for s in var_list])
            # Return a 1-D tensor containing all the names of uninitialized variables.
            return array_ops.boolean_mask(variable_names_tensor, variables_mask)


# pylint: disable=protected-access
# Let Variable and PartitionedVariable instances be used wherever a Tensor is
# expected, each via its class's own _TensorConversionFunction.
ops.register_tensor_conversion_function(Variable, Variable._TensorConversionFunction)
# NOTE(review): presumably installs Tensor operator overloads on Variable —
# confirm against Variable._OverloadAllOperators.
Variable._OverloadAllOperators()

ops.register_tensor_conversion_function(PartitionedVariable, PartitionedVariable._TensorConversionFunction)
# pylint: enable=protected-access

ops.register_dense_tensor_like_type(Variable)
# Associate the GLOBAL_VARIABLES collection with VariableDef protos, using
# Variable.to_proto / Variable.from_proto (presumably for graph export and
# import — confirm in ops.register_proto_function).
ops.register_proto_function(
    ops.GraphKeys.GLOBAL_VARIABLES,
    proto_type=variable_pb2.VariableDef,
    to_proto=Variable.to_proto,
    from_proto=Variable.from_proto,
)
ops.register_proto_function(
    ops.GraphKeys.TRAINABLE_VARIABLES,
    proto_type=variable_pb2.VariableDef,
开发者ID:shakamunyi,项目名称:tensorflow,代码行数:31,代码来源:variables.py


示例6: _should_act_as_resource_variable

          'of type {!r}'.format(dtype.name, self.dtype.name))
    val = ops.internal_convert_to_tensor(self._variable,
                                         self._variable.dtype, name,
                                         as_ref=False)
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(val.device):
        return math_ops.cast(val, self.dtype)

  def _should_act_as_resource_variable(self):
    """Pass resource_variable_ops.is_resource_variable check.

    The method intentionally does nothing; defining it is what the check is
    meant to observe.
    """

  # TODO(reedwm): Define operator overloads.


# Let AutoCastVariable instances be used wherever a Tensor is expected; the
# conversion goes through AutoCastVariable._dense_var_to_tensor.
ops.register_tensor_conversion_function(
    AutoCastVariable, AutoCastVariable._dense_var_to_tensor)  # pylint:disable=protected-access
ops.register_dense_tensor_like_type(AutoCastVariable)


# We have DistributedVariable subclass to pass
# isinstance(..., DistributedVariable) checks when wrapping a
# DistributedVariable.
# TODO(reedwm): We should not wrap DistributedVariable, but instead have
# DistributedVariable wrap AutoCastVariable. Subclassing DistributedVariable is
# messy, because we do not fully implement the interface of DistributedVariable.
class AutoCastDistributedVariable(AutoCastVariable,
                                  distribute_values.DistributedVariable):
  """Version of AutoCastVariable that subclasses DistributedVariable."""

  def __init__(self, variable):
    if not isinstance(variable, distribute_values.DistributedValues):
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:32,代码来源:autocast_variable.py


示例7: devices

  def devices(self):
    """Returns the set of distinct `.device` values of the component tensors."""
    return {component.device for component in self.tensors}

  def __str__(self):
    """Short description naming only the first component tensor."""
    head_name = self.tensors[0].name
    dtype_name = self.dtype.name
    dims = tuple(self.shape.as_list())
    return "PartitionedTensor([%s, ...], dtype=%s, shape=%s)" % (
        head_name, dtype_name, dims)

  def __hash__(self):
    """Hashes the (ordered) tuple of component tensors."""
    components = tuple(self.tensors)
    return hash(components)

  def as_tensor(self, dtype=None, name=None, as_ref=False):
    """Returns the component tensors concatenated along axis 0.

    The result is cached per device in `self._concats`. If an earlier call
    already cached a concatenation for the same device, that cached Tensor is
    returned instead of the one built by this call.

    Args:
      dtype: requested dtype; must be None or `self.dtype`.
      name: optional name passed to `ops.name_scope`.
      as_ref: must be False; reference conversion is not supported.

    Returns:
      A Tensor concatenating `self.tensors` along axis 0.

    Raises:
      ValueError: if `as_ref` is True or `dtype` differs from `self.dtype`.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # They run before the name scope is entered, so no ops are created when
    # the arguments are rejected.
    if as_ref:
      raise ValueError("PartitionedTensor does not support ref conversion.")
    if dtype not in [None, self.dtype]:
      raise ValueError(
          "Cannot convert PartitionedTensor of dtype %s to dtype %s." %
          (self.dtype, dtype))
    with ops.name_scope(name, "PartitionedTensor.as_tensor", self.tensors):
      result = array_ops.concat(self.tensors, axis=0)

      # Cache 'result' if we haven't already cached a value for this device.
      if result.device not in self._concats:
        self._concats[result.device] = result
      return self._concats[result.device]


# Convert a PartitionedTensor through its as_tensor() method, forwarding the
# dtype, name and as_ref arguments.
# NOTE(review): the lambda's parameters have no defaults, so this relies on
# the registry always passing dtype/name/as_ref — confirm.
ops.register_tensor_conversion_function(
    PartitionedTensor,
    lambda val, dtype, name, as_ref: val.as_tensor(dtype, name, as_ref))


# TODO(b/69623235): Add a function for finding tensors that share gradients
# to eliminate redundant fisher factor computations.
开发者ID:DILASSS,项目名称:tensorflow,代码行数:29,代码来源:utils.py


示例8: _ConstantShape

  tensor_value.tensor.CopyFrom(
      tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
  dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
  const_tensor = g.create_op(
      "Const", [], [dtype_value.type],
      attrs={"value": tensor_value, "dtype": dtype_value}, name=name).outputs[0]
  return const_tensor


@ops.RegisterShape("Const")
def _ConstantShape(op):
  """Shape function for "Const": the shape stored in the `value` attr proto."""
  dims = op.get_attr("value").tensor_shape.dim
  return [tensor_shape.TensorShape([dim.size for dim in dims])]


# Python sequences, NumPy arrays and NumPy scalars, and (as a catch-all) any
# object are converted by building a constant. The `object` entry uses a
# larger priority number (200 vs 100); presumably larger numbers are tried
# later, so it applies only when nothing more specific matched — confirm
# against the register_tensor_conversion_function documentation.
ops.register_tensor_conversion_function((list, tuple), constant, 100)
ops.register_tensor_conversion_function(np.ndarray, constant, 100)
ops.register_tensor_conversion_function(np.generic, constant, 100)
ops.register_tensor_conversion_function(object, constant, 200)

def _tensor_shape_tensor_conversion_function(s, dtype=None, name=None):
  if not s.is_fully_defined():
    raise ValueError(
        "Cannot convert a partially known TensorShape to a Tensor: %s" % s)
  if dtype is not None:
    if dtype not in (types.int32, types.int64):
      raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
  else:
    dtype = types.int32
  if name is None:
    name = "shape_as_tensor"
开发者ID:bradg19,项目名称:tensor,代码行数:31,代码来源:constant_op.py


示例9: _convert_labeled_tensor_to_tensor

# Register short names for these types with the `tc` type-checking helpers
# (presumably used when formatting type-check messages — confirm in tc).
tc.register_type_abbreviation(ops.Tensor, 'tensorflow.Tensor')
tc.register_type_abbreviation(dtypes.DType, 'tensorflow.DType')
# core LabeledTensor types
tc.register_type_abbreviation(Axis, 'labeled_tensor.Axis')
tc.register_type_abbreviation(Axes, 'labeled_tensor.Axes')
tc.register_type_abbreviation(LabeledTensor, 'labeled_tensor.LabeledTensor')


@tc.returns(ops.Tensor)
@tc.accepts(LabeledTensor)
def _convert_labeled_tensor_to_tensor(value, *args, **kwargs):
  """Converts a LabeledTensor by converting its underlying `.tensor`.

  All extra positional and keyword arguments are forwarded unchanged to
  `ops.convert_to_tensor`, which handles the optional arguments.
  NOTE(review): the conversion registry passes `as_ref` as a keyword; confirm
  that `ops.convert_to_tensor` accepts it in this TensorFlow version.
  """
  unwrapped = value.tensor
  return ops.convert_to_tensor(unwrapped, *args, **kwargs)


ops.register_tensor_conversion_function(
    LabeledTensor, _convert_labeled_tensor_to_tensor)


# tc class for anything that can be coerced into a LabeledTensor:
# a LabeledTensor, a Tensor, a NumPy array, or a Scalar.
# pylint: disable=invalid-name
LabeledTensorLike = tc.Union(LabeledTensor, ops.Tensor, np.ndarray, Scalar)
# pylint: enable=invalid-name


@tc.returns(LabeledTensor)
@tc.accepts(LabeledTensorLike, object, tc.Optional(string_types))
def convert_to_labeled_tensor(value, dtype=None, name=None):
  """Converts the given `value` to a `LabeledTensor`.

  This function accepts `LabeledTensor` objects, 0-dimensional `Tensor` objects
  and numpy arrays, and Python scalars. Higher dimensional unlabeled tensors
开发者ID:HKUST-SING,项目名称:tensorflow,代码行数:32,代码来源:core.py


示例10: _saveable_factory

    def _saveable_factory(name=self._common_name):
      return _MirroredSaveable(self, self._primary_var, name)
    return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion_mirrored(var, dtype=None, name=None, as_ref=False):
  """Tensor conversion function for MirroredVariable.

  Converts the value returned by `var.get()` with the requested `dtype` and
  `name`.

  Raises:
    ValueError: if `as_ref` is True; reference conversion is rejected.
  """
  # Try to avoid assignments to and other mutations of MirroredVariable
  # state except through a DistributionStrategy.update() call. This is an
  # explicit check rather than an `assert`, so it still applies under -O.
  if as_ref:
    raise ValueError(
        "MirroredVariable does not support conversion to a ref tensor.")
  return ops.internal_convert_to_tensor(
      var.get(), dtype=dtype, name=name, as_ref=False)


# MirroredVariable instances are converted by `_tensor_conversion_mirrored`,
# which converts the value returned by `var.get()`.
ops.register_tensor_conversion_function(MirroredVariable,
                                        _tensor_conversion_mirrored)


class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
  """Class for defining how to restore a TowerLocalVariable."""

  def __init__(self, tower_local_variable, name):
    self._tower_local_variable = tower_local_variable
    # We use a callable so that we don't have to evaluate this expression
    # in the case where we are trying to restore instead of save.
    def tensor():
      return distribute_lib.get_distribution_strategy().read_var(
          tower_local_variable)
    spec = saver.BaseSaverBuilder.SaveSpec(
        tensor=tensor,
        slice_spec="",
开发者ID:sonnyhu,项目名称:tensorflow,代码行数:32,代码来源:values.py


示例11: graph_placeholder

  captured_value = tensor_map.get(ops.tensor_id(value), None)
  if captured_value is None:
    captured_value = graph_placeholder(
        dtype=dtype or value.dtype, shape=value.shape, name=name)
    if captured_value.dtype == dtypes.resource:
      captured_value._handle_data = value._handle_data  # pylint: disable=protected-access
    tensor_map[ops.tensor_id(value)] = (value, captured_value)
  else:
    captured_value = captured_value[1]
  return captured_value


# TODO(apassos): it'd be really nice if we could scope this registration.
# Note that we register this at a higher priority than ops.Tensor since we want
# to handle subclass specific conversion before a superclass conversion.
# (Here a smaller priority number, -1, means the function is tried earlier.)
ops.register_tensor_conversion_function(
    tensor.Tensor, _convert_to_graph_constant, priority=-1)


class _CapturingContext(object):
  """Tracks references to Tensors outside this context while it is active."""

  def __init__(self):
    # known_ops are ops which are created while this context is active
    self.known_ops = set()

    # captured_tensors are all tensors referenced to by ops in this context but
    # not produced in it
    self.captured_tensors = set()

  def AddOp(self, op):  # pylint: disable=invalid-name
    if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
开发者ID:chdinh,项目名称:tensorflow,代码行数:32,代码来源:function.py


示例12: test_enables_nontensor_plumbing

  def test_enables_nontensor_plumbing(self):
    """Checks that a non-Tensor object flows through a Keras model and loss.

    `Foo` is registered both as convertible to a Tensor and as a symbolic
    tensor type. The test asserts two things: a model built around
    `PlumbingLayer(Foo)` returns a `Foo` instance, and a custom loss passed to
    `compile` receives a `Foo` as its prediction argument.
    """
    # Setup.

    class Foo(object):
      # Stores its input (not read elsewhere in this test) and a constant
      # tensor that serves as its tensor conversion.

      def __init__(self, input_):
        self._input = input_
        self.value = ops.convert_to_tensor(42.)

      # Exposes the dtype of the stored constant tensor.
      @property
      def dtype(self):
        return self.value.dtype

    # Converting a `Foo` yields its stored `value` tensor.
    ops.register_tensor_conversion_function(
        Foo, lambda x, *args, **kwargs: x.value)
    tf_utils.register_symbolic_tensor_type(Foo)

    class PlumbingLayer(keras.layers.Lambda):
      # Lambda layer whose wrapped function returns the pair (raw object,
      # converted tensor). `__call__` hands back only the raw object; `call`
      # returns the pair only while `__call__` is on the stack.

      def __init__(self, fn, **kwargs):
        def _fn(*fargs, **fkwargs):
          d = fn(*fargs, **fkwargs)
          x = ops.convert_to_tensor(d)
          # Copy shape accessors onto the raw object, presumably so Keras can
          # query its shape like a tensor's.
          d.shape = x.shape
          d.get_shape = x.get_shape
          return d, x
        super(PlumbingLayer, self).__init__(_fn, **kwargs)
        # True only while `__call__` is executing.
        self._enter_dunder_call = False

      def __call__(self, inputs, *args, **kwargs):
        self._enter_dunder_call = True
        d, _ = super(PlumbingLayer, self).__call__(inputs, *args, **kwargs)
        self._enter_dunder_call = False
        return d

      def call(self, inputs, *args, **kwargs):
        d, v = super(PlumbingLayer, self).call(inputs, *args, **kwargs)
        if self._enter_dunder_call:
          return d, v
        return d

    # User-land.
    model = keras.Sequential([
        keras.layers.InputLayer([]),
        PlumbingLayer(Foo),  # Makes a `Foo` object.
    ])
    # Let's ensure Keras graph history is preserved by composing the models.
    model = keras.Model(model.inputs, model(model.outputs))
    # Now we instantiate the model and verify we have a `Foo` object, not a
    # `Tensor`.
    y = model(ops.convert_to_tensor(7.))
    self.assertIsInstance(y, Foo)
    # Confirm that (custom) loss sees `Foo` instance, not Tensor.
    # A one-element list lets the nested loss function record its argument.
    obtained_prediction_box = [None]
    def custom_loss(y_obs, y_pred):
      del y_obs
      obtained_prediction_box[0] = y_pred
      return y_pred
    # Apparently `compile` calls the loss function enough to trigger the
    # side-effect.
    model.compile('SGD', loss=custom_loss)
    self.assertIsInstance(obtained_prediction_box[0], Foo)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:62,代码来源:tf_utils_test.py


示例13: _lazy_zero_tensor

      tensor.Tensor(shape, dtype=dtypes.int32), tensor.Tensor(1, dtype=dtype))


def _lazy_zero_tensor(zero):
  """Builds the tensor for a LazyZero by calling `_zeros` with its shape/dtype."""
  zero_shape = zero.shape
  zero_dtype = zero.dtype
  return _zeros(zero_shape, zero_dtype)


tensor.LazyZero.tensor = _lazy_zero_tensor


def _lazy_zero_to_tensor(lazy_zero, dtype=None, name=None, as_ref=False):
  """Tensor conversion function for LazyZero.

  The requested dtype, name and as_ref are ignored; the result is built by
  `_zeros` from the LazyZero's own shape and dtype.
  """
  del dtype, name, as_ref  # Intentionally ignored.
  zero_shape = lazy_zero.shape
  zero_dtype = lazy_zero.dtype
  return _zeros(zero_shape, zero_dtype)


ops.register_tensor_conversion_function(tensor.LazyZero, _lazy_zero_to_tensor)


def _indexed_slices_to_tensor(value):
  """Converts an IndexedSlices object `value` to a Tensor.

  Args:
    value: An ops.IndexedSlices object.

  Returns:
    A dense Tensor representing the values in the given IndexedSlices.

  Raises:
    ValueError: If the IndexedSlices does not have the same dtype.
  """
  if value.dense_shape is None:
开发者ID:keveman,项目名称:tensorflow,代码行数:31,代码来源:tensor_node.py


示例14: _TensorConversionFunction

      delta_get_op = delta_staging_area.get()[0]
    # Return the actual updates. The colocation constraint will be reapplied.
    return self.real_var.assign_sub(delta_get_op)

  @staticmethod
  # pylint: disable=bad-staticmethod-argument,invalid-name
  def _TensorConversionFunction(self, dtype=None, name=None, as_ref=False):
    """Converts a StagedModelVariable to a Tensor.

    `dtype` and `name` are ignored. Returns `self._ref()` when `as_ref` is
    truthy and `self._value()` otherwise.
    """
    del dtype, name  # unused: this function returns the cached ref or value.
    return self._ref() if as_ref else self._value()


# Let StagedModelVariable instances be converted to Tensors; the conversion
# returns the variable's ref or value depending on `as_ref`.
ops.register_tensor_conversion_function(
    StagedModelVariable, StagedModelVariable._TensorConversionFunction)  # pylint: disable=protected-access


class StagedVariableGetter(object):
  """A variable getter through staging buffers on devices.

  Instead of a caching device, this getter tracks where the variable is used.
  And on each device, it goes through a staging buffer.
  """

  def __init__(self, device_num, devices, cpu_device, variable_mgr):
    """Initializer for StagedVariableGetter.

    Args:
      device_num: the current device index.
      devices: a list of all the devices to build towers.
开发者ID:Ericyuanhui,项目名称:Build_learning,代码行数:32,代码来源:variable_mgr.py


示例15: __init__

  """Base class for asset files which need to be tracked."""

  def __init__(self, path):
    """Record the full path to the asset.

    When executing eagerly, the path is stored in a string ResourceVariable
    (wrapped with `self._no_dependency`); otherwise `path` is stored as is.
    """
    # A variable keeps @tf.functions from capturing a literal value. Creating
    # it under init_scope prevents functions from capturing `path` in an
    # initialization graph, since it is transient and should not end up in a
    # serialized function body. When serialized in a SavedModel, the variable
    # is set during loading to its location in the assets/ directory.
    with ops.init_scope():
      if not context.executing_eagerly():
        # Adding a variable is too disruptive when v1-style graph building,
        # since things may get fed and local variable initializers would then
        # need to be run.
        self._path = path
      else:
        path_variable = resource_variable_ops.ResourceVariable(
            path, dtype=dtypes.string,
            name="asset_path")
        self._path = self._no_dependency(path_variable)

  @property
  def asset_path(self):
    """Fetch the current asset path (the value stored in `self._path`)."""
    stored_path = self._path
    return stored_path

# Converting a TrackableAsset converts its `asset_path`, forwarding every
# conversion keyword argument (dtype, name, as_ref, ...) unchanged.
ops.register_tensor_conversion_function(
    TrackableAsset,
    lambda asset, **kw: ops.internal_convert_to_tensor(asset.asset_path, **kw))
开发者ID:aeverall,项目名称:tensorflow,代码行数:30,代码来源:tracking.py


示例16: _read_variable_op

    return self._read_variable_op()

  def _read_variable_op(self):
    """Reads the variable with a control dependency on `self._parent_op`."""
    parent_deps = [self._parent_op]
    with ops.control_dependencies(parent_deps):
      read_value = gen_resource_variable_ops.read_variable_op(
          self._handle, self._dtype)
    return read_value

  def set_shape(self, shape):
    """Overwrites the recorded shape with `shape` (no compatibility check)."""
    self._shape = shape

  @property
  def op(self):
    """The op for this variable, i.e. the stored `_parent_op`."""
    parent = self._parent_op
    return parent

# _UnreadVariable is converted with the same _dense_var_to_tensor function as
# ResourceVariable below.
ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor)
ops.register_dense_tensor_like_type(_UnreadVariable)

# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.

# Note: registering for Variable after ResourceVariable because inheritance will
# otherwise lead to the wrong behavior.
ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
ops.register_tensor_conversion_function(
    variables.Variable, variables.Variable._TensorConversionFunction)  # pylint: disable=protected-access

# pylint: disable=protected-access
# (No matching `enable` follows within this excerpt.)
ResourceVariable._OverloadAllOperators()
ops.register_dense_tensor_like_type(ResourceVariable)
开发者ID:keithc61,项目名称:tensorflow,代码行数:30,代码来源:resource_variable_ops.py


示例17: _tensor_conversion

      return self.read_value()


# Register a conversion function which reads the value of the variable,
# allowing instances of the class to be used as tensors.
def _tensor_conversion(var, dtype=None, name=None, as_ref=False):
  """Delegates conversion to `var._dense_var_to_tensor` with the same args."""
  convert = var._dense_var_to_tensor  # pylint: disable=protected-access
  return convert(dtype=dtype, name=name, as_ref=as_ref)


def replicated_fetch_function(var):
  """Session-run fetch function for a replicated variable.

  Returns a one-element fetch list holding `var._dense_var_to_tensor()` and a
  function that extracts the single fetched value from the results.
  """
  fetches = [var._dense_var_to_tensor()]  # pylint: disable=protected-access
  return fetches, lambda fetched: fetched[0]


# Tensor conversion and Session.run fetching of ReplicatedVariable both go
# through the variable's `_dense_var_to_tensor`.
ops.register_tensor_conversion_function(ReplicatedVariable, _tensor_conversion)
ops.register_dense_tensor_like_type(ReplicatedVariable)
session_lib.register_session_run_conversion_functions(
    ReplicatedVariable, replicated_fetch_function)


def replicated_scope(num_replicas):
  """Variable scope for constructing replicated variables."""

  def _replicated_variable_getter(getter, name, *args, **kwargs):
    """Getter that constructs replicated variables."""
    collections = kwargs.pop("collections", None)
    if collections is None:
      collections = [ops.GraphKeys.GLOBAL_VARIABLES]
    kwargs["collections"] = []
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:30,代码来源:keras_tpu_variables.py


示例18: NotImplementedError

    raise NotImplementedError("surrogate_loss not implemented")

  @staticmethod
  def _tensor_conversion_function(v, dtype=None, name=None, as_ref=False):
    """Converts a StochasticTensor to `v.value()`.

    Raises:
      ValueError: if `dtype` is given and not compatible with `v.dtype`, or
        if `as_ref` is requested.
    """
    del name  # Unused.
    if dtype and not dtype.is_compatible_with(v.dtype):
      raise ValueError(
          "Incompatible type conversion requested to type '%s' for variable "
          "of type '%s'" % (dtype.name, v.dtype.name))
    if not as_ref:
      return v.value()
    raise ValueError("%s: Ref type is not supported." % v)


# pylint: disable=protected-access
# StochasticTensor converts to its `value()`, after dtype-compatibility and
# as_ref checks in StochasticTensor._tensor_conversion_function.
ops.register_tensor_conversion_function(
    StochasticTensor, StochasticTensor._tensor_conversion_function)
# pylint: enable=protected-access


class _StochasticValueType(object):
  """Interface for the ValueType classes.

  This is the base class for MeanValue, SampleValue, SampleAndReshapeValue,
  and their descendants.
  """

  def pushed_above(self, unused_value_type):
    pass

  def popped_above(self, unused_value_type):
    pass
开发者ID:10imaging,项目名称:tensorflow,代码行数:32,代码来源:stochastic_graph.py


示例19: TODO

      compute_device=True):
    # TODO(apassos) this should do some form of alias analysis as ops which
    # forward the resources such as Identity and Switch can cause serialization
    # to fail.
    for i, inp in enumerate(inputs):
      if inp.graph is not self:
        inputs[i] = capture_value(self.captures, inp, inp.dtype, inp.op.name)
    return super(CapturingGraph, self).create_op(
        op_type, inputs, dtypes, input_types, name, attrs, op_def,
        compute_shapes, compute_device)


# TODO(apassos): it'd be really nice if we could scope this registration.
# Note that we register this at a higher priority than ops.Tensor since we want
# to handle subclass specific conversion before a superclass conversion.
# (Here a smaller priority number, -1, means the function is tried earlier.)
ops.register_tensor_conversion_function(
    ops.EagerTensor, _convert_to_graph_tensor, priority=-1)


# pylint: disable=invalid-name
class HelperContext(object):
  """ControlFlowContext with a customizable AddOp method."""

  def __init__(self, add_op_internal):
    self._add_op_internal = add_op_internal
    self._values = set()  # control flow code sometimes updates this.

  def _AddOpInternal(self, op):
    self._add_op_internal(op)

  @property
  def outer_context(self):
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:32,代码来源:function.py


示例20: _MarkReachedOps

  dense_shape_value = tensor_util.constant_value(value.dense_shape)
  if dense_shape_value is not None:
    num_elements = np.prod(dense_shape_value)
    if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
      warnings.warn(
          "Converting sparse IndexedSlices to a dense Tensor with %d elements. "
          "This may consume a large amount of memory." % num_elements)
  else:
    warnings.warn(
        "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
        "This may consume a large amount of memory.")
  return math_ops.unsorted_segment_sum(
      value.values, value.indices, value.dense_shape[0], name=name)


# Let IndexedSlices be passed where a Tensor is expected; the conversion
# densifies them via _IndexedSlicesToTensor.
ops.register_tensor_conversion_function(ops.IndexedSlices,
                                        _IndexedSlicesToTensor)


def _MarkReachedOps(from_ops, reached_ops):
  """Mark all ops reached from "from_ops".

  Args:
    from_ops: list of Operations.
    reached_ops: list of booleans, indexed by operation id.
  """
  queue = collections.deque()
  queue.extend(from_ops)
  while queue:
    op = queue.popleft()
    if not reached_ops[op._id]:
      reached_ops[op._id] = True
开发者ID:kdavis-mozilla,项目名称:tensorflow,代码行数:32,代码来源:gradients_impl.py



注:本文中的tensorflow.python.framework.ops.register_tensor_conversion_function函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python ops.reset_default_graph函数代码示例发布时间:2022-05-27
下一篇:
Python ops.register_proto_function函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap