• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python common_shapes.broadcast_shape函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.framework.common_shapes.broadcast_shape函数的典型用法代码示例。如果您正苦于以下问题:Python broadcast_shape函数的具体用法?Python broadcast_shape怎么用?Python broadcast_shape使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了broadcast_shape函数的18个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: _assert_incompatible_broadcast

 def _assert_incompatible_broadcast(self, shape1, shape2):
   """Assert that `shape1` and `shape2` cannot be broadcast together.

   When both shapes have known dims, NumPy is used as a cross-check:
   broadcasting zero arrays of those shapes must raise `ValueError` in either
   argument order. `common_shapes.broadcast_shape` must then raise
   `ValueError` in either argument order regardless of rank knowledge.

   Args:
     shape1: Shape object exposing `dims` and `as_list()`.
     shape2: Shape object exposing `dims` and `as_list()`.
   """
   both_dims_known = not (shape1.dims is None or shape2.dims is None)
   if both_dims_known:
     first = np.zeros(shape1.as_list())
     second = np.zeros(shape2.as_list())
     # NumPy must reject the pair in both orders.
     for lhs, rhs in ((first, second), (second, first)):
       with self.assertRaises(ValueError):
         np.broadcast(lhs, rhs)
   # The shape function must reject the pair in both orders as well.
   for lhs_shape, rhs_shape in ((shape1, shape2), (shape2, shape1)):
     with self.assertRaises(ValueError):
       common_shapes.broadcast_shape(lhs_shape, rhs_shape)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:12,代码来源:common_shapes_test.py


示例2: _assert_broadcast

 def _assert_broadcast(self, expected, shape1, shape2):
   """Assert that broadcasting `shape1` with `shape2` yields `expected`.

   When both input shapes have known dims, the expectation is first checked
   against NumPy's own broadcasting of zero arrays (in both argument orders).
   `common_shapes.broadcast_shape` is then checked in both argument orders.

   Args:
     expected: The shape the broadcast is expected to produce.
     shape1: Shape object exposing `dims` and `as_list()`.
     shape2: Shape object exposing `dims` and `as_list()`.
   """
   if shape1.dims is not None and shape2.dims is not None:
     expected_dims = expected.as_list()
     first = np.zeros(shape1.as_list())
     second = np.zeros(shape2.as_list())
     for lhs, rhs in ((first, second), (second, first)):
       self.assertAllEqual(expected_dims, np.broadcast(lhs, rhs).shape)
   # These checks apply whether or not the dims are known.
   for lhs_shape, rhs_shape in ((shape1, shape2), (shape2, shape1)):
     self.assertEqual(
         expected, common_shapes.broadcast_shape(lhs_shape, rhs_shape))
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:14,代码来源:common_shapes_test.py


示例3: __init__

  def __init__(self,
               df,
               mu,
               sigma,
               validate_args=True,
               allow_nan_stats=False,
               name="StudentT"):
    """Build a batch of Student's t distributions.

    Each distribution has degrees of freedom `df`, mean `mu`, and scale
    `sigma`. The three parameters must be broadcast-compatible (e.g.
    `df + mu + sigma` is a valid operation); the batch shape stored on the
    instance is the broadcast of their static shapes.

    Args:
      df: Floating point tensor, the degrees of freedom of the
        distribution(s). Must contain only positive values.
      mu: Floating point tensor, the means of the distribution(s).
      sigma: Floating point tensor, the scaling factor for the
        distribution(s). Must contain only positive values. Note that `sigma`
        is not the standard deviation of this distribution.
      validate_args: When true, positivity assertions on `df` and `sigma` are
        attached as control dependencies of the parameter tensors. When
        false and inputs are invalid, correct behavior is not guaranteed.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: presumably raised by `assert_same_float_dtype` when `df`,
        `mu`, and `sigma` do not share a float dtype (all three are checked).
    """
    self._validate_args = validate_args
    self._allow_nan_stats = allow_nan_stats
    with ops.name_scope(name, values=[df, mu, sigma]) as scope:
      if validate_args:
        positivity_checks = [
            check_ops.assert_positive(df),
            check_ops.assert_positive(sigma),
        ]
      else:
        positivity_checks = []
      with ops.control_dependencies(positivity_checks):
        self._df = ops.convert_to_tensor(df, name="df")
        self._mu = ops.convert_to_tensor(mu, name="mu")
        self._sigma = ops.convert_to_tensor(sigma, name="sigma")
        contrib_tensor_util.assert_same_float_dtype(
            (self._df, self._mu, self._sigma))
      self._name = scope
      # Static batch shape: sigma broadcast against (df broadcast with mu).
      sigma_shape = self._sigma.get_shape()
      df_mu_shape = common_shapes.broadcast_shape(
          self._df.get_shape(), self._mu.get_shape())
      self._get_batch_shape = common_shapes.broadcast_shape(
          sigma_shape, df_mu_shape)
      self._get_event_shape = tensor_shape.TensorShape([])
开发者ID:alephman,项目名称:Tensorflow,代码行数:48,代码来源:student_t.py


示例4: __init__

  def __init__(self,
               a=0.0,
               b=1.0,
               validate_args=True,
               allow_nan_stats=False,
               name="Uniform"):
    """Build a batch of Uniform distributions on the intervals `[a, b]`.

    The parameters `a` and `b` must be broadcast-compatible (e.g. `b - a` is
    a valid operation); the stored batch shape is the broadcast of their
    static shapes.

    Examples without broadcasting:

    ```python
    u1 = Uniform(3.0, 4.0)  # a single uniform distribution [3, 4]
    u2 = Uniform([1.0, 2.0], [3.0, 4.0])  # 2 distributions [1, 3], [2, 4]
    u3 = Uniform([[1.0, 2.0],
                  [3.0, 4.0]],
                 [[1.5, 2.5],
                  [3.5, 4.5]])  # 4 distributions
    ```

    And with broadcasting:

    ```python
    u1 = Uniform(3.0, [5.0, 6.0, 7.0])  # 3 distributions
    ```

    Args:
      a: Floating point tensor, the minimum endpoint.
      b: Floating point tensor, the maximum endpoint. Must be > `a`.
      validate_args: Whether to assert that `a < b`. If `validate_args` is
        `False` and inputs are invalid, correct behavior is not guaranteed.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Raises:
      InvalidArgumentError: if `a >= b` and `validate_args=True`.
    """
    self._validate_args = validate_args
    self._allow_nan_stats = allow_nan_stats
    with ops.name_scope(name, values=[a, b]):
      if validate_args:
        ordering_checks = [check_ops.assert_less(
            a, b, message="uniform not defined when a > b.")]
      else:
        ordering_checks = []
      with ops.control_dependencies(ordering_checks):
        lower = array_ops.identity(a, name="a")
        upper = array_ops.identity(b, name="b")

    self._a = lower
    self._b = upper
    self._name = name
    lower_shape = self._a.get_shape()
    upper_shape = self._b.get_shape()
    self._batch_shape = common_shapes.broadcast_shape(lower_shape, upper_shape)
    self._event_shape = tensor_shape.TensorShape([])

    contrib_tensor_util.assert_same_float_dtype((lower, upper))
开发者ID:alephman,项目名称:Tensorflow,代码行数:60,代码来源:uniform.py


示例5: sample_n

  def sample_n(self, n, seed=None, name="sample_n"):
    """Draw `n` samples from each Normal distribution in the batch.

    Standard-normal noise is generated with leading dimension `n` followed by
    the dynamic shape of `self.mean()`, then scaled by `sigma` and shifted by
    `mu`.

    Args:
      n: `Scalar`, type int32, the number of observations to sample.
      seed: Python integer, the random seed.
      name: The name to give this op.

    Returns:
      samples: `[n, ...]`, a `Tensor` of `n` samples for each
        of the distributions determined by broadcasting the hyperparameters.
    """
    with ops.name_scope(self.name), ops.name_scope(
        name, values=[self._mu, self._sigma, n]):
      static_batch_shape = common_shapes.broadcast_shape(
          self._mu.get_shape(), self._sigma.get_shape())
      num_samples = ops.convert_to_tensor(n)
      dynamic_batch_shape = array_ops.shape(self.mean())
      sample_shape = array_ops.concat(0, ([num_samples], dynamic_batch_shape))
      noise = random_ops.random_normal(
          shape=sample_shape,
          mean=0,
          stddev=1,
          dtype=self._mu.dtype,
          seed=seed)

      # Give shape inference whatever is statically known about the result.
      static_n = tensor_util.constant_value(num_samples)
      noise.set_shape(
          tensor_shape.vector(static_n).concatenate(static_batch_shape))

      return noise * self._sigma + self._mu
开发者ID:alephman,项目名称:Tensorflow,代码行数:27,代码来源:normal.py


示例6: __init__

  def __init__(self, shape, dtype, minimum, maximum, name=None):
    """Create a `BoundedTensorSpec` with element-wise value bounds.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      minimum: Number or sequence specifying the minimum element bounds
        (inclusive). Must be broadcastable to `shape`.
      maximum: Number or sequence specifying the maximum element bounds
        (inclusive). Must be broadcastable to `shape`.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      ValueError: If `minimum` or `maximum` is `None`, or if the shape of
        either bound cannot be broadcast with `shape`.
      TypeError: If the shape is not an iterable or if the `dtype` is an invalid
        numpy dtype (per the original documentation; presumably raised by the
        parent initializer).
    """
    super(BoundedTensorSpec, self).__init__(shape, dtype, name)

    if minimum is None or maximum is None:
      raise ValueError("minimum and maximum must be provided; but saw "
                       "'%s' and '%s'" % (minimum, maximum))

    # Validate each bound in turn; `minimum` is checked before `maximum`.
    bound_checks = (
        (minimum, "minimum is not compatible with shape. Message: {!r}."),
        (maximum, "maximum is not compatible with shape. Message: {!r}."),
    )
    for bound, message_template in bound_checks:
      try:
        bound_shape = np.shape(bound)
        common_shapes.broadcast_shape(
            tensor_shape.TensorShape(bound_shape), self.shape)
      except ValueError as exception:
        raise ValueError(message_template.format(exception))

    # Store the bounds as read-only arrays of the spec's numpy dtype.
    lower = np.array(minimum, dtype=self.dtype.as_numpy_dtype())
    lower.setflags(write=False)
    self._minimum = lower

    upper = np.array(maximum, dtype=self.dtype.as_numpy_dtype())
    upper.setflags(write=False)
    self._maximum = upper
开发者ID:PuchatekwSzortach,项目名称:tensorflow,代码行数:46,代码来源:tensor_spec.py


示例7: _assert_broadcast_with_unknown_dims

  def _assert_broadcast_with_unknown_dims(self, expected, shape1, shape2):
    """Check a broadcast result dimension-by-dimension via `.value`.

    Verifies that `broadcast_shape` gives the same dims in both argument
    orders, and that those dims match `expected.dims`. Dims are compared
    through their `.value` attribute rather than by direct equality.

    Args:
      expected: Shape whose `dims` the broadcast result must match.
      shape1: First shape to broadcast.
      shape2: Second shape to broadcast.
    """

    def check_dims_match(first_dims, second_dims):
      # Either both dim lists are None, or they agree in length and values.
      if first_dims is None:
        self.assertIsNone(second_dims)
        return
      if second_dims is None:
        self.assertIsNone(first_dims)
        return
      self.assertEqual(len(first_dims), len(second_dims))
      for first_dim, second_dim in zip(first_dims, second_dims):
        self.assertEqual(first_dim.value, second_dim.value)

    forward_dims = common_shapes.broadcast_shape(shape1, shape2).dims
    swapped_dims = common_shapes.broadcast_shape(shape2, shape1).dims

    # The result must not depend on argument order.
    check_dims_match(forward_dims, swapped_dims)
    # The result must equal the expectation.
    check_dims_match(expected.dims, forward_dims)
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:23,代码来源:common_shapes_test.py


示例8: __init__

  def __init__(self,
               alpha,
               beta,
               validate_args=True,
               allow_nan_stats=False,
               name="Gamma"):
    """Build a batch of Gamma distributions from `alpha` and `beta`.

    The parameters `alpha` and `beta` must be broadcast-compatible (e.g.
    `alpha + beta` is a valid operation); the stored batch shape is the
    broadcast of their static shapes.

    Args:
      alpha: Floating point tensor, the shape params of the
        distribution(s). Must contain only positive values.
      beta: Floating point tensor, the inverse scale params of the
        distribution(s). Must contain only positive values.
      validate_args: Whether to assert that `alpha > 0, beta > 0`, and that
        `x > 0` in the methods `prob(x)` and `log_prob(x)`. If
        `validate_args` is `False` and the inputs are invalid, correct
        behavior is not guaranteed.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to prepend to all ops created by this distribution.

    Raises:
      TypeError: if `alpha` and `beta` are different dtypes.
    """
    self._validate_args = validate_args
    self._allow_nan_stats = allow_nan_stats
    with ops.name_scope(name, values=[alpha, beta]) as scope:
      self._name = scope
      if validate_args:
        positivity_checks = [
            check_ops.assert_positive(alpha),
            check_ops.assert_positive(beta),
        ]
      else:
        positivity_checks = []
      with ops.control_dependencies(positivity_checks):
        alpha_tensor = array_ops.identity(alpha, name="alpha")
        beta_tensor = array_ops.identity(beta, name="beta")

    alpha_shape = alpha_tensor.get_shape()
    beta_shape = beta_tensor.get_shape()
    self._get_batch_shape = common_shapes.broadcast_shape(
        alpha_shape, beta_shape)
    self._get_event_shape = tensor_shape.TensorShape([])

    self._alpha = alpha_tensor
    self._beta = beta_tensor
开发者ID:alephman,项目名称:Tensorflow,代码行数:45,代码来源:gamma.py


示例9: _shape

  def _shape(self):
    """Return the static shape of the block-diagonal operator.

    The matrix part sums the row counts (`range_dimension`) and column counts
    (`domain_dimension`) of every block; the batch part is the broadcast of
    every block's batch shape.

    Returns:
      The broadcast batch shape concatenated with
      `[range_dimension, domain_dimension]`.

    Raises:
      ValueError: presumably raised by `broadcast_shape` when the operators'
        batch shapes are incompatible.
    """
    # Get final matrix shape.
    range_dimension = self.operators[0].range_dimension
    domain_dimension = self.operators[0].domain_dimension
    for operator in self.operators[1:]:
      range_dimension += operator.range_dimension
      domain_dimension += operator.domain_dimension

    # A LinearOperator acts like a matrix of shape [..., M, N] where
    # M = range_dimension (rows) and N = domain_dimension (columns), so the
    # range dimension must come first. Listing the domain dimension first is
    # only correct when every block is square.
    matrix_shape = tensor_shape.TensorShape([range_dimension, domain_dimension])

    # Get broadcast batch shape.
    # broadcast_shape checks for compatibility.
    batch_shape = self.operators[0].batch_shape
    for operator in self.operators[1:]:
      batch_shape = common_shapes.broadcast_shape(
          batch_shape, operator.batch_shape)

    return batch_shape.concatenate(matrix_shape)
开发者ID:aritratony,项目名称:tensorflow,代码行数:18,代码来源:linear_operator_block_diag.py


示例10: __init__

  def __init__(self,
               loc,
               scale,
               validate_args=True,
               allow_nan_stats=False,
               name="Laplace"):
    """Build a batch of Laplace distributions from `loc` and `scale`.

    The parameters `loc` and `scale` must be broadcast-compatible (e.g.,
    `loc / scale` is a valid operation); the stored batch shape is the
    broadcast of their static shapes.

    Args:
      loc: Floating point tensor which characterizes the location (center)
        of the distribution.
      scale: Positive floating point tensor which characterizes the spread of
        the distribution.
      validate_args: When true, a positivity assertion on `scale` is attached
        as a control dependency of the parameter tensors. If `False` and the
        inputs are invalid, correct behavior is not guaranteed.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: if `loc` and `scale` are of different dtype.
    """
    self._validate_args = validate_args
    self._allow_nan_stats = allow_nan_stats
    with ops.name_scope(name, values=[loc, scale]):
      loc_tensor = ops.convert_to_tensor(loc)
      scale_tensor = ops.convert_to_tensor(scale)
      if validate_args:
        positivity_checks = [check_ops.assert_positive(scale_tensor)]
      else:
        positivity_checks = []
      with ops.control_dependencies(positivity_checks):
        self._name = name
        self._loc = array_ops.identity(loc_tensor, name="loc")
        self._scale = array_ops.identity(scale_tensor, name="scale")
        loc_shape = self._loc.get_shape()
        scale_shape = self._scale.get_shape()
        self._batch_shape = common_shapes.broadcast_shape(
            loc_shape, scale_shape)
        self._event_shape = tensor_shape.TensorShape([])

    contrib_tensor_util.assert_same_float_dtype((loc_tensor, scale_tensor))
开发者ID:alephman,项目名称:Tensorflow,代码行数:43,代码来源:laplace.py


示例11: __init__

  def __init__(self,
               mu,
               sigma,
               validate_args=True,
               allow_nan_stats=False,
               name="Normal"):
    """Build a batch of Normal distributions with mean `mu`, stddev `sigma`.

    The parameters `mu` and `sigma` must be broadcast-compatible (e.g.
    `mu + sigma` is a valid operation); the stored batch shape is the
    broadcast of their static shapes.

    Args:
      mu: Floating point tensor, the means of the distribution(s).
      sigma: Floating point tensor, the stddevs of the distribution(s).
        Must contain only positive values.
      validate_args: Whether to assert that `sigma > 0`. If `validate_args` is
        `False`, correct output is not guaranteed when input is invalid.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to give Ops created by the initializer.

    Raises:
      TypeError: if mu and sigma are different dtypes.
    """
    self._validate_args = validate_args
    self._allow_nan_stats = allow_nan_stats
    with ops.name_scope(name, values=[mu, sigma]):
      mu_tensor = ops.convert_to_tensor(mu)
      sigma_tensor = ops.convert_to_tensor(sigma)
      if validate_args:
        positivity_checks = [check_ops.assert_positive(sigma_tensor)]
      else:
        positivity_checks = []
      with ops.control_dependencies(positivity_checks):
        self._name = name
        self._mu = array_ops.identity(mu_tensor, name="mu")
        self._sigma = array_ops.identity(sigma_tensor, name="sigma")
        mu_shape = self._mu.get_shape()
        sigma_shape = self._sigma.get_shape()
        self._batch_shape = common_shapes.broadcast_shape(
            mu_shape, sigma_shape)
        self._event_shape = tensor_shape.TensorShape([])

    contrib_tensor_util.assert_same_float_dtype((mu_tensor, sigma_tensor))
开发者ID:alephman,项目名称:Tensorflow,代码行数:41,代码来源:normal.py


示例12: _get_batch_shape

 def _get_batch_shape(self):
   """Return the broadcast of the static shapes of `loc` and `scale`."""
   loc_shape = self.loc.get_shape()
   scale_shape = self.scale.get_shape()
   return common_shapes.broadcast_shape(loc_shape, scale_shape)
开发者ID:KalraA,项目名称:tensorflow,代码行数:3,代码来源:laplace.py


示例13: _get_batch_shape

 def _get_batch_shape(self):
   """Return the broadcast of the static shapes of `n` and `p`."""
   n_shape = self.n.get_shape()
   p_shape = self.p.get_shape()
   return common_shapes.broadcast_shape(n_shape, p_shape)
开发者ID:bsantanas,项目名称:tensorflow,代码行数:3,代码来源:binomial.py


示例14: _get_batch_shape

 def _get_batch_shape(self):
   """Return the broadcast of the static shapes of `alpha` and `beta`."""
   alpha_shape = self.alpha.get_shape()
   beta_shape = self.beta.get_shape()
   return common_shapes.broadcast_shape(alpha_shape, beta_shape)
开发者ID:Nishant23,项目名称:tensorflow,代码行数:3,代码来源:gamma.py


示例15: _get_batch_shape

 def _get_batch_shape(self):
   """Return the broadcast of the static shapes of `_mu` and `sigma`."""
   mu_shape = self._mu.get_shape()
   sigma_shape = self.sigma.get_shape()
   return common_shapes.broadcast_shape(mu_shape, sigma_shape)
开发者ID:curtiszimmerman,项目名称:tensorflow,代码行数:3,代码来源:normal.py


示例16: _matmul

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    # Here we heavily rely on Roth's column Lemma [1]:
    # (A x B) * vec X = vec BXA^T,
    # where vec stacks all the columns of the matrix under each other. In our
    # case, x represents a batch of vec X (i.e. we think of x as a batch of
    # column vectors, rather than a matrix). Each member of the batch can be
    # reshaped to a matrix (hence we get a batch of matrices).
    # We can iteratively apply this lemma by noting that if B is a Kronecker
    # product, then we can apply the lemma again.

    # [1] W. E. Roth, "On direct product matrices,"
    # Bulletin of the American Mathematical Society, vol. 40, pp. 461-468,
    # 1934

    # Efficiency

    # Naively doing the Kronecker product, by calculating the dense matrix and
    # applying it will can take cubic time in  the size of domain_dimension
    # (assuming a square matrix). The other issue is that calculating the dense
    # matrix can be prohibitively expensive, in that it can take a large amount
    # of memory.
    #
    # This implementation avoids this memory blow up by only computing matmuls
    # with the factors. In this way, we don't have to realize the dense matrix.
    # In terms of complexity, if we have Kronecker Factors of size:
    # (n1, n1), (n2, n2), (n3, n3), ... (nJ, nJ), with N = \prod n_i, and we
    # have as input a [N, M] matrix, the naive approach would take O(N^2 M).
    # With this approach (ignoring reshaping of tensors and transposes for now),
    # the time complexity can be O(M * (\sum n_i) * N). There is also the
    # benefit of batched multiplication (In this example, the batch size is
    # roughly M * N) so this can be much faster. However, not factored in are
    # the costs of the several transposing of tensors, which can affect cache
    # behavior.

    # Below we document the shape manipulation for adjoint=False,
    # adjoint_arg=False, but the general case of different adjoints is still
    # handled.

    if adjoint_arg:
      x = linalg.adjoint(x)

    # Always add a batch dimension to enable broadcasting to work.
    batch_shape = array_ops.concat(
        [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0)
    x += array_ops.zeros(batch_shape, dtype=x.dtype.base_dtype)

    # x has shape [B, R, C], where B represent some number of batch dimensions,
    # R represents the number of rows, and C represents the number of columns.
    # In order to apply Roth's column lemma, we need to operate on a batch of
    # column vectors, so we reshape into a batch of column vectors. We put it
    # at the front to ensure that broadcasting between operators to the batch
    # dimensions B still works.
    output = _rotate_last_dim(x, rotate_right=True)

    # Also expand the shape to be [A, C, B, R]. The first dimension will be
    # used to accumulate dimensions from each operator matmul.
    output = output[array_ops.newaxis, ...]

    # In this loop, A is going to refer to the value of the accumulated
    # dimension. A = 1 at the start, and will end up being self.range_dimension.
    # V will refer to the last dimension. V = R at the start, and will end up
    # being 1 in the end.
    for operator in self.operators[:-1]:
      # Reshape output from [A, C, B, V] to be
      # [A, C, B, V / op.domain_dimension, op.domain_dimension]
      if adjoint:
        operator_dimension = operator.range_dimension_tensor()
      else:
        operator_dimension = operator.domain_dimension_tensor()

      output = _unvec_by(output, operator_dimension)

      # We are computing (XA^T) = (AX^T)^T.
      # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
      # which is being converted to:
      # [A, C, B, V / op.domain_dimension, op.range_dimension]
      output = array_ops.matrix_transpose(output)
      output = operator.matmul(output, adjoint=adjoint, adjoint_arg=False)
      output = array_ops.matrix_transpose(output)
      # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
      output = _rotate_last_dim(output, rotate_right=False)
      output = _vec(output)
      output = _rotate_last_dim(output, rotate_right=True)

    # After the loop, we will have
    # A = self.range_dimension / op[-1].range_dimension
    # V = op[-1].domain_dimension

    # We convert that using matvec to get:
    # [A, C, B, op[-1].range_dimension]
    output = self.operators[-1].matvec(output, adjoint=adjoint)
    # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
    output = _rotate_last_dim(output, rotate_right=False)
    output = _vec(output)
    output = _rotate_last_dim(output, rotate_right=False)

    if x.shape.is_fully_defined():
      column_dim = x.shape[-1]
      broadcast_batch_shape = common_shapes.broadcast_shape(
          x.shape[:-2], self.batch_shape)
#.........这里部分代码省略.........
开发者ID:aritratony,项目名称:tensorflow,代码行数:101,代码来源:linear_operator_kronecker.py


示例17: _solve

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    """Solve against `rhs` one Kronecker factor at a time.

    Every operator except the last is applied through its `solve` method and
    the last through `solvevec`, so the dense Kronecker product is never
    formed here.

    Args:
      rhs: Tensor to solve against. It is replaced by `linalg.adjoint(rhs)`
        when `adjoint_arg` is true.
      adjoint: Forwarded to each factor's `solve` / `solvevec`; it also selects
        which of each factor's dimensions is used for reshaping and which of
        this operator's dimensions is used for the static output shape.
      adjoint_arg: Whether to take the adjoint of `rhs` before solving.

    Returns:
      The solution tensor. Its static shape is set only when `rhs` has a fully
      defined static shape.
    """
    # Here we follow the same use of Roth's column lemma as in `matmul`, with
    # the key difference that we replace all `matmul` instances with `solve`.
    # This follows from the property that inv(A x B) = inv(A) x inv(B).

    # Below we document the shape manipulation for adjoint=False,
    # adjoint_arg=False, but the general case of different adjoints is still
    # handled.

    if adjoint_arg:
      rhs = linalg.adjoint(rhs)

    # Always add a batch dimension to enable broadcasting to work.
    batch_shape = array_ops.concat(
        [array_ops.ones_like(self.batch_shape_tensor()), [1, 1]], 0)
    rhs += array_ops.zeros(batch_shape, dtype=rhs.dtype.base_dtype)

    # rhs has shape [B, R, C], where B represent some number of batch
    # dimensions,
    # R represents the number of rows, and C represents the number of columns.
    # In order to apply Roth's column lemma, we need to operate on a batch of
    # column vectors, so we reshape into a batch of column vectors. We put it
    # at the front to ensure that broadcasting between operators to the batch
    # dimensions B still works.
    output = _rotate_last_dim(rhs, rotate_right=True)

    # Also expand the shape to be [A, C, B, R]. The first dimension will be
    # used to accumulate dimensions from each operator solve.
    output = output[array_ops.newaxis, ...]

    # In this loop, A is going to refer to the value of the accumulated
    # dimension. A = 1 at the start, and will end up being self.range_dimension.
    # V will refer to the last dimension. V = R at the start, and will end up
    # being 1 in the end.
    for operator in self.operators[:-1]:
      # Reshape output from [A, C, B, V] to be
      # [A, C, B, V / op.domain_dimension, op.domain_dimension]
      if adjoint:
        operator_dimension = operator.range_dimension_tensor()
      else:
        operator_dimension = operator.domain_dimension_tensor()

      output = _unvec_by(output, operator_dimension)

      # We are computing (XA^-1^T) = (A^-1 X^T)^T.
      # output has [A, C, B, V / op.domain_dimension, op.domain_dimension],
      # which is being converted to:
      # [A, C, B, V / op.domain_dimension, op.range_dimension]
      output = array_ops.matrix_transpose(output)
      output = operator.solve(output, adjoint=adjoint, adjoint_arg=False)
      output = array_ops.matrix_transpose(output)
      # Rearrange it to [A * op.range_dimension, C, B, V / op.domain_dimension]
      output = _rotate_last_dim(output, rotate_right=False)
      output = _vec(output)
      output = _rotate_last_dim(output, rotate_right=True)

    # After the loop, we will have
    # A = self.range_dimension / op[-1].range_dimension
    # V = op[-1].domain_dimension

    # We convert that using solvevec to get:
    # [A, C, B, op[-1].range_dimension]
    output = self.operators[-1].solvevec(output, adjoint=adjoint)
    # Rearrange shape to be [B1, ... Bn, self.range_dimension, C]
    output = _rotate_last_dim(output, rotate_right=False)
    output = _vec(output)
    output = _rotate_last_dim(output, rotate_right=False)

    # Attach a static shape when everything needed is known statically.
    if rhs.shape.is_fully_defined():
      column_dim = rhs.shape[-1]
      broadcast_batch_shape = common_shapes.broadcast_shape(
          rhs.shape[:-2], self.batch_shape)
      if adjoint:
        matrix_dimensions = [self.domain_dimension, column_dim]
      else:
        matrix_dimensions = [self.range_dimension, column_dim]

      output.set_shape(broadcast_batch_shape.concatenate(
          matrix_dimensions))

    return output
开发者ID:aritratony,项目名称:tensorflow,代码行数:81,代码来源:linear_operator_kronecker.py


示例18: __init__

  def __init__(self,
               n,
               logits=None,
               p=None,
               validate_args=True,
               allow_nan_stats=False,
               name="Binomial"):
    """Initialize a batch of Binomial distributions.

    The stored batch shape is the broadcast of the static shapes of `n` and
    of the probabilities returned by `get_logits_and_prob`.

    Args:
      n: Non-negative floating point tensor with shape broadcastable to
        `[N1,..., Nm]` with `m >= 0` and the same dtype as `p` or `logits`.
        Defines this as a batch of `N1 x ... x Nm` different Binomial
        distributions. Its components should be equal to integer values.
      logits: Floating point tensor representing the log-odds of a
        positive event with shape broadcastable to `[N1,..., Nm]` `m >= 0`, and
        the same dtype as `n`. Each entry represents logits for the probability
        of success for independent Binomial distributions.
      p: Positive floating point tensor with shape broadcastable to
        `[N1,..., Nm]` `m >= 0`, `p in [0, 1]`. Each entry represents the
        probability of success for independent Binomial distributions.
      validate_args: Whether to assert valid values for parameters `n` and `p`,
        and `x` in `prob` and `log_prob`. If `False`, correct behavior is not
        guaranteed.
      allow_nan_stats: Boolean, default `False`. If `False`, raise an
        exception if a statistic (e.g. mean/mode/etc...) is undefined for any
        batch member. If `True`, batch members with valid parameters leading
        to undefined statistics will return NaN for this statistic.
      name: The name to prefix Ops created by this distribution class.

    Examples:

    ```python
    # Define 1-batch of a binomial distribution.
    dist = Binomial(n=2., p=.9)

    # Define a 2-batch.
    dist = Binomial(n=[4., 5], p=[.1, .3])
    ```

    """

    self._logits, self._p = distribution_util.get_logits_and_prob(
        name=name, logits=logits, p=p, validate_args=validate_args)

    with ops.name_scope(name, values=[n]):
      if validate_args:
        count_checks = [
            check_ops.assert_non_negative(
                n, message="n has negative components."),
            distribution_util.assert_integer_form(
                n, message="n has non-integer components."),
        ]
      else:
        count_checks = []
      with ops.control_dependencies(count_checks):
        self._n = array_ops.identity(n, name="convert_n")

        self._name = name
        self._validate_args = validate_args
        self._allow_nan_stats = allow_nan_stats

        n_shape = self._n.get_shape()
        p_shape = self._p.get_shape()
        self._get_batch_shape = common_shapes.broadcast_shape(n_shape, p_shape)
        self._get_event_shape = tensor_shape.TensorShape([])
开发者ID:alephman,项目名称:Tensorflow,代码行数:61,代码来源:binomial.py



注:本文中的tensorflow.python.framework.common_shapes.broadcast_shape函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python constant_op.constant函数代码示例发布时间:2022-05-27
下一篇:
Python c_api_util.tf_buffer函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap