• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python check_ops.assert_equal函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.ops.check_ops.assert_equal函数的典型用法代码示例。如果您正苦于以下问题:Python assert_equal函数的具体用法?Python assert_equal怎么用?Python assert_equal使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了assert_equal函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: call

  def call(self, labels, predictions, weights=None):
    """Accumulates accuracy statistics for one batch.

    Each element of `labels` is compared with the matching element of
    `predictions`; the 0/1 match results, cast to `self.dtype`, are handed
    (together with `weights`) to the parent class's `call`.

    Example: with labels [1, 2, 3, 4] and predictions [0, 2, 3, 4] the
    accuracy is 3/4 = .75; with weights [1, 1, 0, 0] it becomes 1/2 = .5.

    `labels` and `predictions` should have the same shape and type.

    Args:
      labels: Tensor of true labels, one example per element.
      predictions: Tensor of predicted labels, one example per element.
      weights: Optional per-example weights. Defaults to 1.

    Returns:
      `(labels, predictions)` when `weights` is None, otherwise
      `(labels, predictions, weights)`, for easy chaining.
    """
    # NOTE(review): the assert op returned here is not attached as a control
    # dependency, so outside eager execution it may never run -- confirm.
    check_ops.assert_equal(
        array_ops.shape(labels), array_ops.shape(predictions),
        message="Shapes of labels and predictions are unequal")
    is_correct = math_ops.cast(math_ops.equal(labels, predictions), self.dtype)
    super(Accuracy, self).call(is_correct, weights=weights)
    if weights is not None:
      return labels, predictions, weights
    return labels, predictions
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:27,代码来源:metrics_impl.py


示例2: _check_shapes_dynamic

  def _check_shapes_dynamic(self, operator, v, diag):
    """Attaches runtime shape assertions to `v` and (optionally) `diag`.

    The assertions check that:
      * `rank(v) == operator.rank()` and, if `diag` is given,
        `rank(diag) == operator.rank() - 1`;
      * `shape(v)[:-2]` (and `shape(diag)[:-1]`) equals
        `operator.batch_shape()`;
      * `shape(v)[-2] == operator.vector_space_dimension()` and, if `diag` is
        given, `shape(v)[-1] == shape(diag)[-1]`.

    Args:
      operator: The operator `v` and `diag` are checked against.
      v: `Tensor` to validate.
      diag: `Tensor` to validate, or `None` to skip every `diag` check.

    Returns:
      `(v, diag)`, each (when not `None`) depending on all the assertions.
    """
    assertions = []
    with ops.op_scope([operator, v, diag], 'check_shapes'):
      v_shape = array_ops.shape(v)
      op_rank = operator.rank()
      v_rank = array_ops.rank(v)
      if diag is not None:
        diag_shape = array_ops.shape(diag)
        diag_rank = array_ops.rank(diag)

      # Ranks: `v` matches the operator, `diag` is one lower.
      assertions.append(check_ops.assert_rank(v, op_rank))
      if diag is not None:
        assertions.append(check_ops.assert_rank(diag, op_rank - 1))

      # Batch shape: all but the last two dims of `v`, last one of `diag`.
      assertions.append(check_ops.assert_equal(
          operator.batch_shape(),
          array_ops.slice(v_shape, [0], [v_rank - 2])))
      if diag is not None:
        assertions.append(check_ops.assert_equal(
            operator.batch_shape(),
            array_ops.slice(diag_shape, [0], [diag_rank - 1])))

      # Event shape: second-to-last dim of `v`, and matching trailing dims.
      assertions.append(check_ops.assert_equal(
          operator.vector_space_dimension(),
          array_ops.gather(v_shape, v_rank - 2)))
      if diag is not None:
        assertions.append(check_ops.assert_equal(
            array_ops.gather(v_shape, v_rank - 1),
            array_ops.gather(diag_shape, diag_rank - 1)))

      v = control_flow_ops.with_dependencies(assertions, v)
      if diag is not None:
        diag = control_flow_ops.with_dependencies(assertions, diag)
      return v, diag
开发者ID:10imaging,项目名称:tensorflow,代码行数:34,代码来源:operator_pd_vdvt_update.py


示例3: _check_mu

  def _check_mu(self, mu):
    """Validates `mu` against the covariance operator `self._cov`.

    `mu` is converted to a `Tensor` and its dtype must match the covariance's.
    When `self.strict` is set, runtime assertions are attached checking that
    `rank(mu) + 1 == cov.rank()` and `shape(mu) == cov.vector_shape()`.

    Args:
      mu: The mean; anything accepted by `ops.convert_to_tensor`.

    Returns:
      The converted `mu`, depending on the assertions when `self.strict` is
      true.

    Raises:
      TypeError: if `mu.dtype` differs from `cov.dtype`.
    """
    mu = ops.convert_to_tensor(mu)
    cov = self._cov

    if mu.dtype != cov.dtype:
      raise TypeError(
          "mu and cov must have the same dtype.  Found mu.dtype = %s, "
          "cov.dtype = %s"
          % (mu.dtype, cov.dtype))
    if not self.strict:
      return mu

    rank_check = check_ops.assert_equal(
        array_ops.rank(mu) + 1,
        cov.rank(),
        data=["mu should have rank 1 less than cov.  Found: rank(mu) = ",
              array_ops.rank(mu), " rank(cov) = ", cov.rank()],
    )
    shape_check = check_ops.assert_equal(
        array_ops.shape(mu),
        cov.vector_shape(),
        data=["mu.shape and cov.shape[:-1] should match.  "
              "Found: shape(mu) = ",
              array_ops.shape(mu), " shape(cov) = ", cov.shape()],
    )
    assert_compatible_shapes = control_flow_ops.group(rank_check, shape_check)
    return control_flow_ops.with_dependencies([assert_compatible_shapes], mu)
开发者ID:363158858,项目名称:tensorflow,代码行数:29,代码来源:mvn.py


示例4: _get_sparse_tensors

 def _get_sparse_tensors(self, inputs, weight_collections=None,
                         trainable=None):
   """Returns the categorical column's id/weight tensors with a unit last dim.

   The id tensor and, when present, the weight tensor are asserted at run
   time to have rank 2 and are then reshaped to append a trailing dimension
   of size 1, so that embeddings are not combined during embedding lookup.

   Args:
     inputs: Forwarded to `self.categorical_column._get_sparse_tensors`.
     weight_collections: Unused.
     trainable: Unused.

   Returns:
     A `fc._CategoricalColumn.IdWeightPair` of the reshaped tensors; the
     weight entry stays `None` when the column provides no weights.
   """

   def _append_unit_dim(sp_tensor, rank_msg, shape_msg):
     # Assert rank 2, then reshape dense_shape [d0, d1] -> [d0, d1, 1].
     rank_check = check_ops.assert_equal(
         array_ops.rank(sp_tensor), 2,
         data=[rank_msg, shape_msg, array_ops.shape(sp_tensor)])
     with ops.control_dependencies([rank_check]):
       return sparse_ops.sparse_reshape(
           sp_tensor,
           shape=array_ops.concat([sp_tensor.dense_shape, [1]], axis=0))

   pair = self.categorical_column._get_sparse_tensors(inputs)
   id_tensor = _append_unit_dim(
       pair.id_tensor,
       'Column {} expected ID tensor of rank 2. '.format(self.name),
       'id_tensor shape: ')
   weight_tensor = pair.weight_tensor
   if weight_tensor is not None:
     weight_tensor = _append_unit_dim(
         weight_tensor,
         'Column {} expected weight tensor of rank 2.'.format(self.name),
         'weight_tensor shape:')
   return fc._CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
开发者ID:DILASSS,项目名称:tensorflow,代码行数:27,代码来源:sequential_feature_column.py


示例5: _kl_independent

def _kl_independent(a, b, name="kl_independent"):
  """Batched KL divergence `KL(a || b)` for Independent distributions.

  We can leverage the fact that
  ```
  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
  ```
  where the sum is over the `reinterpreted_batch_ndims`.

  Args:
    a: Instance of `Independent`.
    b: Instance of `Independent`.
    name: (optional) name to use for created ops. Default "kl_independent".

  Returns:
    Batchwise `KL(a || b)`.

  Raises:
    ValueError: If the static event shapes of `a` and `b` don't match.
    NotImplementedError: If the static event shapes of the underlying
      distributions of `a` and `b` don't match.
  """
  p = a.distribution
  q = b.distribution

  # The KL between any two (non)-batched distributions is a scalar.
  # Given that the KL between two factored distributions is the sum, i.e.
  # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
  # KL(p || q) and do a `reduce_sum` on the reinterpreted batch dimensions.
  if a.event_shape.is_fully_defined() and b.event_shape.is_fully_defined():
    if a.event_shape == b.event_shape:
      if p.event_shape == q.event_shape:
        num_reduce_dims = a.event_shape.ndims - p.event_shape.ndims
        # Trailing axes [-1, -2, ..., -num_reduce_dims].
        reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]

        return math_ops.reduce_sum(
            kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
      else:
        raise NotImplementedError("KL between Independents with different "
                                  "event shapes not supported.")
    else:
      raise ValueError("Event shapes do not match.")
  else:
    with ops.control_dependencies([
        check_ops.assert_equal(a.event_shape_tensor(), b.event_shape_tensor()),
        check_ops.assert_equal(p.event_shape_tensor(), q.event_shape_tensor())
    ]):
      # The number of reinterpreted dims is the difference of the event ranks,
      # i.e. of the lengths of the event-shape vectors. Take element 0 of each
      # vector's shape (its length), not the shape of its first element.
      num_reduce_dims = (
          array_ops.shape(a.event_shape_tensor())[0] -
          array_ops.shape(p.event_shape_tensor())[0])
      # Trailing axes [-num_reduce_dims, ..., -1]; matches the static branch.
      reduce_dims = math_ops.range(-num_reduce_dims, 0, 1)
      return math_ops.reduce_sum(
          kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
开发者ID:didukhle,项目名称:tensorflow,代码行数:52,代码来源:independent.py


示例6: test_raises_when_less

 def test_raises_when_less(self):
   """`assert_equal` must reject unequal operands, statically and at run time."""
   with self.test_session():
     # Constant operands: the mismatch surfaces as a ValueError while the
     # assertion is being built.
     small_const = constant_op.constant([3, 1], name="small")
     big_const = constant_op.constant([4, 2], name="big")
     with self.assertRaisesRegexp(ValueError, "fail"):
       check_ops.assert_equal(big_const, small_const, message="fail")
     # Placeholder operands: the mismatch only surfaces on evaluation.
     small_ph = array_ops.placeholder(dtypes.int32, name="small")
     big_ph = array_ops.placeholder(dtypes.int32, name="big")
     mismatch_check = check_ops.assert_equal(small_ph, big_ph)
     with ops.control_dependencies([mismatch_check]):
       guarded = array_ops.identity(small_ph)
     with self.assertRaisesOpError("small.*big"):
       guarded.eval(feed_dict={small_ph: [3, 1], big_ph: [4, 2]})
开发者ID:1000sprites,项目名称:tensorflow,代码行数:14,代码来源:check_ops_test.py


示例7: _check_mu

    def _check_mu(self, mu):
        """Validates `mu` against the covariance operator `self._cov`.

        `mu` is converted to a `Tensor` and its dtype must match the
        covariance's. When both static shapes are fully defined,
        `mu.shape == cov.shape[:-1]` is checked immediately. Otherwise, if
        `self.validate_args` is set, runtime assertions check that
        `rank(mu) + 1 == cov.rank()` and then that
        `shape(mu) == cov.vector_shape()`.

        Args:
            mu: The mean; anything accepted by `ops.convert_to_tensor`.

        Returns:
            The converted `mu`, depending on the runtime assertions when
            they were created.

        Raises:
            TypeError: if `mu.dtype` differs from the covariance dtype.
            ValueError: if fully defined static shapes are incompatible.
        """
        mu = ops.convert_to_tensor(mu)
        covariance = self._cov

        if mu.dtype != covariance.dtype:
            raise TypeError(
                "mu and cov must have the same dtype.  Found mu.dtype = %s, "
                "cov.dtype = %s" % (mu.dtype, covariance.dtype))

        # Static validation is possible only when both shapes are known.
        static_mu_shape = mu.get_shape()
        static_cov_shape = covariance.get_shape()
        if static_mu_shape.is_fully_defined() and static_cov_shape.is_fully_defined():
            if static_mu_shape != static_cov_shape[:-1]:
                raise ValueError(
                    "mu.shape and cov.shape[:-1] should match.  Found: mu.shape=%s, "
                    "cov.shape=%s" % (static_mu_shape, static_cov_shape))
            return mu

        # Shapes are only partially known: validate at run time if asked to.
        if not self.validate_args:
            return mu

        rank_guard = check_ops.assert_equal(
            array_ops.rank(mu) + 1,
            covariance.rank(),
            data=[
                "mu should have rank 1 less than cov.  Found: rank(mu) = ",
                array_ops.rank(mu),
                " rank(cov) = ",
                covariance.rank(),
            ],
        )
        # The shape comparison is only meaningful once the ranks agree.
        with ops.control_dependencies([rank_guard]):
            shape_guard = check_ops.assert_equal(
                array_ops.shape(mu),
                covariance.vector_shape(),
                data=[
                    "mu.shape and cov.shape[:-1] should match.  " "Found: shape(mu) = ",
                    array_ops.shape(mu),
                    " shape(cov) = ",
                    covariance.shape(),
                ],
            )
            return control_flow_ops.with_dependencies([shape_guard], mu)
开发者ID:damienmg,项目名称:tensorflow,代码行数:48,代码来源:mvn.py


示例8: _maybe_check_matching_sizes

  def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out,
                                  validate_args=False):
    """Check that prod(event_shape_in)==prod(event_shape_out).

    Args:
      event_shape_in: Input event shape; presumably a 1-D integer `Tensor`
        (TODO confirm against callers).
      event_shape_out: Output event shape; same expectations as
        `event_shape_in`.
      validate_args: Python `bool`. If `True`, a runtime assertion is added
        whenever the two sizes cannot both be computed statically.

    Returns:
      A (possibly empty) list of assertion ops.

    Raises:
      ValueError: if both sizes are statically known and differ.
    """

    def _get_size_from_shape(shape):
      """Computes size from a shape `Tensor`, statically if possible.

      Returns:
        A pair `(static_size, size)`. When `shape` has a constant value,
        both entries are the same NumPy `int32` product; otherwise
        `static_size` is `None` and `size` is a runtime `reduce_prod`.
      """
      s = tensor_util.constant_value(shape)
      if s is not None:
        static_size = np.int32(np.prod(s))
        return static_size, static_size
      return None, math_ops.reduce_prod(shape, name="size")

    # Ensure `event_shape_in` is compatible with `event_shape_out`.
    event_size_in_, event_size_in = _get_size_from_shape(event_shape_in)
    event_size_out_, event_size_out = _get_size_from_shape(event_shape_out)

    assertions = []
    if event_size_in_ is not None and event_size_out_ is not None:
      if event_size_in_ != event_size_out_:
        # Report the static values on both sides.
        raise ValueError(
            "Input `event_size` ({}) does not match output `event_size` ({}).".
            format(event_size_in_, event_size_out_))
    elif validate_args:
      assertions.append(check_ops.assert_equal(
          event_size_in, event_size_out,
          message="Input/output `event_size`s do not match."))

    return assertions
开发者ID:SylChan,项目名称:tensorflow,代码行数:29,代码来源:reshape_impl.py


示例9: _check_labels

def _check_labels(labels, expected_labels_dimension):
  """Check labels type and shape.

  Args:
    labels: Dense labels; a `SparseTensor` is rejected. Expected shape is
      `[batch_size, expected_labels_dimension]`.
    expected_labels_dimension: Required size of the second labels dimension.

  Returns:
    `labels` as a `Tensor` (an `identity` named after the `labels` scope)
    that depends on runtime rank-2 and second-dimension assertions.

  Raises:
    ValueError: if `labels` is a `SparseTensor`, or its statically known
      second dimension differs from `expected_labels_dimension`.
  """
  with ops.name_scope(None, 'labels', (labels,)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
    if isinstance(labels, sparse_tensor.SparseTensor):
      raise ValueError('SparseTensor labels are not supported.')
    dynamic_shape = array_ops.shape(labels)
    err_msg = 'labels shape must be [batch_size, {}]'.format(
        expected_labels_dimension)
    rank_guard = check_ops.assert_rank(labels, 2, message=err_msg)
    with ops.control_dependencies([rank_guard]):
      static_shape = labels.shape
      static_dim1 = None if static_shape is None else static_shape[1]
      # Fail early when the mismatch is already visible at graph build time.
      if static_dim1 is not None and static_dim1 != expected_labels_dimension:
        raise ValueError(
            'Mismatched label shape. '
            'Classifier configured with n_classes=%s.  Received %s. '
            'Suggested Fix: check your n_classes argument to the estimator '
            'and/or the shape of your label.' %
            (expected_labels_dimension, static_dim1))
      dim_guard = check_ops.assert_equal(
          expected_labels_dimension, dynamic_shape[1], message=err_msg)
      with ops.control_dependencies([dim_guard]):
        return array_ops.identity(labels, name=scope)
开发者ID:cneeruko,项目名称:tensorflow,代码行数:25,代码来源:head.py


示例10: assert_splits_match

def assert_splits_match(nested_splits_lists):
  """Checks that the given splits lists are identical.

  Statically, every splits list must have the same length as the first one.
  Dynamically, one `assert_equal` op is created for each pair of
  corresponding `splits` tensors between the first list and every later list.

  Args:
    nested_splits_lists: A list of nested_splits_lists, where each split_list is
      a list of `splits` tensors from a `RaggedTensor`, ordered from outermost
      ragged dimension to innermost ragged dimension.

  Returns:
    A list of control dependency op tensors (empty when there is nothing to
    compare).

  Raises:
    ValueError: If the splits lists differ in length.
  """
  error_msg = "Inputs must have identical ragged splits"
  if any(len(splits) != len(nested_splits_lists[0])
         for splits in nested_splits_lists):
    raise ValueError(error_msg)

  assertions = []
  if nested_splits_lists:
    reference = nested_splits_lists[0]
    for other in nested_splits_lists[1:]:
      for ref_split, other_split in zip(reference, other):
        assertions.append(
            check_ops.assert_equal(ref_split, other_split, message=error_msg))
  return assertions
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:26,代码来源:ragged_util.py


示例11: assert_close

def assert_close(
    x, y, data=None, summarize=None, message=None, name="assert_close"):
  """Assert that x and y are element-wise close.

  Integer tensors are compared for exact equality via
  `check_ops.assert_equal`. Otherwise the condition is
  `reduce_all(|x - y| <= tol)` with the absolute tolerance
  `tol = np.finfo(x.dtype).resolution`.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`.
    data: The tensors to print out if the condition is `False`. Defaults to
      `message` followed by the names and values of `x` and `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if some `|x - y|` exceeds `tol`.
  """
  message = message or ""
  x = ops.convert_to_tensor(x, name="x")
  y = ops.convert_to_tensor(y, name="y")

  if x.dtype.is_integer:
    # Integers carry no rounding error, so closeness is plain equality.
    return check_ops.assert_equal(
        x, y, data=data, summarize=summarize, message=message, name=name)

  with ops.name_scope(name, "assert_close", [x, y, data]):
    tol = np.finfo(x.dtype.as_numpy_dtype).resolution
    if data is None:
      data = [
          message,
          "Condition x ~= y did not hold element-wise: x = ", x.name, x,
          "y = ", y.name, y,
      ]
    abs_diff = math_ops.abs(x - y)
    all_close = math_ops.reduce_all(math_ops.less_equal(abs_diff, tol))
    return control_flow_ops.Assert(all_close, data, summarize=summarize)
开发者ID:Nishant23,项目名称:tensorflow,代码行数:35,代码来源:distribution_util.py


示例12: __init__

  def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"):
    """Instantiates the `AbsoluteValue` bijector.

    Args:
      event_ndims: Python scalar (or value convertible to a `Tensor`) giving
        the number of event dimensions. Only zero is currently supported.
      validate_args: Python `bool`. When `True`, `event_ndims` additionally
        gets a runtime assertion that it equals 0.
      name: Python `str` name given to ops managed by this object.

    Raises:
      ValueError: If `event_ndims` is statically known and is not zero.
    """
    self._graph_parents = []
    self._name = name

    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
    static_ndims = tensor_util.constant_value(event_ndims)
    if static_ndims is not None and static_ndims not in (0,):
      raise ValueError("event_ndims(%s) was not 0" % static_ndims)
    if validate_args:
      # Attached whenever validation is requested, static value or not.
      ndims_is_zero = check_ops.assert_equal(
          event_ndims, 0, message="event_ndims was not 0")
      event_ndims = control_flow_ops.with_dependencies(
          [ndims_is_zero], event_ndims)

    with self._name_scope("init"):
      super(AbsoluteValue, self).__init__(
          event_ndims=event_ndims,
          validate_args=validate_args,
          name=name)
开发者ID:Mazecreator,项目名称:tensorflow,代码行数:33,代码来源:absolute_value_impl.py


示例13: test_doesnt_raise_when_both_empty

 def test_doesnt_raise_when_both_empty(self):
   """`assert_equal` on two empty tensors must pass."""
   with self.test_session():
     lhs = constant_op.constant([])
     rhs = constant_op.constant([])
     empty_check = check_ops.assert_equal(lhs, rhs)
     with ops.control_dependencies([empty_check]):
       result = array_ops.identity(lhs)
     result.eval()
开发者ID:1000sprites,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py


示例14: test_doesnt_raise_when_equal_and_broadcastable_shapes

 def test_doesnt_raise_when_equal_and_broadcastable_shapes(self):
   """Equal tensors of matching shape must pass `assert_equal`."""
   with self.test_session():
     first = constant_op.constant([1, 2], name="small")
     second = constant_op.constant([1, 2], name="small_2")
     equality_check = check_ops.assert_equal(first, second)
     with ops.control_dependencies([equality_check]):
       result = array_ops.identity(first)
     result.eval()
开发者ID:1000sprites,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py


示例15: maybe_check_quadrature_param

def maybe_check_quadrature_param(param, name, validate_args):
  """Helper which checks validity of `loc` and `scale` init args.

  Args:
    param: `Tensor` of mixing parameters. Must be (a batch of) vectors, i.e.
      have rank >= 1, and its last dimension must be 1.
    name: Python `str` name of the parameter; used for the name scope
      (`"check_" + name`) and in error messages.
    validate_args: Python `bool`. When `True`, conditions that cannot be
      decided from the static shape are checked at graph run time.

  Returns:
    `param`, carrying the runtime assertions as control dependencies when any
    were created; otherwise `param` unchanged.

  Raises:
    ValueError: if `param` is statically known to be a scalar.
    NotImplementedError: if the last dimension of `param` is statically known
      and is not 1.
  """
  with ops.name_scope(name="check_" + name, values=[param]):
    assertions = []
    if param.shape.ndims is not None:
      if param.shape.ndims == 0:
        raise ValueError("Mixing params must be a (batch of) vector; "
                         "{}.rank={} is not at least one.".format(
                             name, param.shape.ndims))
    elif validate_args:
      assertions.append(check_ops.assert_rank_at_least(
          param, 1,
          message=("Mixing params must be a (batch of) vector; "
                   "{}.rank is not at least one.".format(
                       name))))

    # TODO(jvdillon): Remove once we support k-mixtures.
    # Indexing a `TensorShape` yields a `Dimension` object, which is never
    # `None` even when its size is unknown. Compare its `.value` instead so
    # an unknown last dimension defers to the runtime check below rather
    # than raising `NotImplementedError`.
    last_dim = param.shape.with_rank_at_least(1)[-1].value
    if last_dim is not None:
      if last_dim != 1:
        raise NotImplementedError("Currently only bimixtures are supported; "
                                  "{}.shape[-1]={} is not 1.".format(
                                      name, last_dim))
    elif validate_args:
      assertions.append(check_ops.assert_equal(
          array_ops.shape(param)[-1], 1,
          message=("Currently only bimixtures are supported; "
                   "{}.shape[-1] is not 1.".format(name))))

    if assertions:
      return control_flow_ops.with_dependencies(assertions, param)
    return param
开发者ID:bikong2,项目名称:tensorflow,代码行数:31,代码来源:vector_diffeomixture.py


示例16: assert_integer_form

def assert_integer_form(
    x, data=None, summarize=None, message=None,
    int_dtype=None, name="assert_integer_form"):
  """Assert that x has integer components (or floats equal to integers).

  `x` is round-tripped through `int_dtype` and the result is asserted equal
  to `x`. Integer-typed `x` passes trivially and yields a `no_op`.

  Args:
    x: Floating-point `Tensor` (integer tensors are accepted and pass).
    data: The tensors to print out if the condition is `False`; forwarded to
      `check_ops.assert_equal` (`None` leaves the default to that function).
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message. Defaults to
      "<op name> has non-integer components".
    int_dtype: A `tf.dtype` used to cast the float to. The default (`None`)
      picks the signed integer of the same width as `x`
      (float16 -> int16, float32 -> int32, float64 -> int64).
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if `cast(x, int_dtype) != x`.

  Raises:
    TypeError: if `int_dtype` is `None` and `x` is neither an integer type nor
      float16/float32/float64.
  """
  with ops.name_scope(name, values=[x, data]):
    x = ops.convert_to_tensor(x, name="x")
    if x.dtype.is_integer:
      return control_flow_ops.no_op()
    message = message or "{} has non-integer components".format(x.op.name)
    if int_dtype is None:
      same_width_int = {
          dtypes.float16: dtypes.int16,
          dtypes.float32: dtypes.int32,
          dtypes.float64: dtypes.int64,
      }
      int_dtype = same_width_int.get(x.dtype.base_dtype)
      if int_dtype is None:
        raise TypeError("Unrecognized type {}".format(x.dtype.name))
    round_tripped = math_ops.cast(math_ops.cast(x, int_dtype), x.dtype)
    return check_ops.assert_equal(
        x, round_tripped,
        data=data, summarize=summarize, message=message, name=name)
开发者ID:1000sprites,项目名称:tensorflow,代码行数:35,代码来源:util.py


示例17: zero_state

 def zero_state(self, batch_size, dtype):
   """Returns the initial `AttentionWrapperState` for a batch.

   The cell state is `self._initial_cell_state` when one was provided,
   otherwise `self._cell.zero_state(batch_size, dtype)`. Every tensor in it
   is re-emitted through an `identity` that depends on an assertion that
   `batch_size` equals `self._attention_mechanism.batch_size`.

   Args:
     batch_size: Requested batch size.
     dtype: dtype passed to the zero-state helpers and to the
       alignment-history `TensorArray`.

   Returns:
     An `AttentionWrapperState` whose `time` is an int32 scalar zero and
     whose `alignment_history` is a dynamically sized `TensorArray` when
     `self._alignment_history` is set, otherwise `()`.
   """
   with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
     if self._initial_cell_state is None:
       cell_state = self._cell.zero_state(batch_size, dtype)
     else:
       cell_state = self._initial_cell_state
     error_message = (
         "When calling zero_state of AttentionWrapper %s: " % self._base_name +
         "Non-matching batch sizes between the memory "
         "(encoder output) and the requested batch size.  Are you using "
         "the BeamSearchDecoder?  If so, make sure your encoder output has "
         "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
         "the batch_size= argument passed to zero_state is "
         "batch_size * beam_width.")
     batch_size_check = check_ops.assert_equal(
         batch_size, self._attention_mechanism.batch_size,
         message=error_message)
     with ops.control_dependencies([batch_size_check]):
       cell_state = nest.map_structure(
           lambda s: array_ops.identity(s, name="checked_cell_state"),
           cell_state)
     alignment_history = (
         tensor_array_ops.TensorArray(dtype=dtype, size=0, dynamic_size=True)
         if self._alignment_history else ())
     return AttentionWrapperState(
         cell_state=cell_state,
         time=array_ops.zeros([], dtype=dtypes.int32),
         attention=_zero_state_tensors(self._attention_layer_size, batch_size,
                                       dtype),
         alignments=self._attention_mechanism.initial_alignments(
             batch_size, dtype),
         alignment_history=alignment_history)
开发者ID:ajaybhat,项目名称:tensorflow,代码行数:34,代码来源:attention_wrapper.py


示例18: _model_fn_ops

def _model_fn_ops(
    expected_features, expected_labels, actual_features, actual_labels, mode):
  """Builds a trivial `ModelFnOps` guarded by input-equality assertions.

  One `assert_equal` op is created per key of `expected_features` (comparing
  it with the same key of `actual_features`, named `assert_<key>`), plus one
  named `assert_labels` for the labels. The constant-zero predictions, loss
  and train_op are created under those control dependencies.

  Args:
    expected_features: Mapping of feature name to expected value.
    expected_labels: Expected labels.
    actual_features: Mapping of feature name to actual value; must contain
      every key of `expected_features`.
    actual_labels: Actual labels.
    mode: Passed through to `ModelFnOps`.

  Returns:
    A `model_fn.ModelFnOps`.
  """
  feature_checks = [
      check_ops.assert_equal(
          expected_features[key], actual_features[key],
          name='assert_%s' % key)
      for key in expected_features
  ]
  label_check = check_ops.assert_equal(
      expected_labels, actual_labels, name='assert_labels')
  with ops.control_dependencies(tuple(feature_checks + [label_check])):
    return model_fn.ModelFnOps(
        mode=mode,
        predictions=constant_op.constant(0.),
        loss=constant_op.constant(0.),
        train_op=constant_op.constant(0.))
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:16,代码来源:estimator_test.py


示例19: calculate_reshape

def calculate_reshape(original_shape, new_shape, validate=False, name=None):
  """Calculates the reshaped dimensions (replacing up to one -1 in reshape).

  Args:
    original_shape: Shape being reshaped from (asserted to be a vector only
      when `validate` is set).
    new_shape: Target shape, possibly containing a single `-1`.
    validate: Python `bool`; if `True`, runtime assertions are returned.
    name: Optional name scope.

  Returns:
    A tuple `(expanded_new_shape, batch_shape_static, validations)`:
    `batch_shape_static` is `tensor_util.constant_value_as_shape(new_shape)`.
    When it is fully defined, `expanded_new_shape` is its NumPy `int32` form
    and `validations` is `[]`. Otherwise `expanded_new_shape` is a `Tensor`
    with any `-1` replaced by the implied size, and `validations` holds the
    assertion ops (empty unless `validate`).
  """
  batch_shape_static = tensor_util.constant_value_as_shape(new_shape)
  if batch_shape_static.is_fully_defined():
    return np.int32(batch_shape_static.as_list()), batch_shape_static, []
  with ops.name_scope(name, "calculate_reshape", [original_shape, new_shape]):
    original_size = math_ops.reduce_prod(original_shape)
    implicit_dim = math_ops.equal(new_shape, -1)
    # With exactly one -1, -prod(new_shape) is the product of the explicit
    # dims; clamping at 1 keeps the divisor positive.
    explicit_size = math_ops.maximum(1, -math_ops.reduce_prod(new_shape))
    size_implicit_dim = original_size // explicit_size
    new_ndims = array_ops.shape(new_shape)
    expanded_new_shape = array_ops.where(  # Assumes exactly one `-1`.
        implicit_dim, array_ops.fill(new_ndims, size_implicit_dim), new_shape)
    if validate:
      validations = [
          check_ops.assert_rank(
              original_shape, 1, message="Original shape must be a vector."),
          check_ops.assert_rank(
              new_shape, 1, message="New shape must be a vector."),
          check_ops.assert_less_equal(
              math_ops.count_nonzero(implicit_dim, dtype=dtypes.int32),
              1,
              message="At most one dimension can be unknown."),
          check_ops.assert_positive(
              expanded_new_shape, message="Shape elements must be >=-1."),
          check_ops.assert_equal(
              math_ops.reduce_prod(expanded_new_shape),
              original_size,
              message="Shape sizes do not match."),
      ]
    else:
      validations = []
    return expanded_new_shape, batch_shape_static, validations
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:30,代码来源:batch_reshape.py


示例20: _verify_input

def _verify_input(tensor_list, labels, probs_list):
  """Verify that batched inputs are well-formed.

  Performs static shape checks on every input and attaches runtime
  assertions as control dependencies.

  Args:
    tensor_list: Data tensors; each must have rank >= 1, and its first
      (batch) dimension must be compatible with, and at run time equal to,
      that of `labels`.
    labels: Rank-1 tensor of class labels. At run time its batch size must
      be positive, and its values must be integers in
      `[0, num_classes)`, where `num_classes` is the length of the
      probability vectors.
    probs_list: Non-empty list of rank-1 probability tensors with fully
      defined, equal shapes. Each must be non-negative and sum to 1 within
      a tolerance of 1e-6.

  Returns:
    `(tensor_list, labels, checked_probs_list)`, each tensor wrapped so that
    it depends on its runtime checks.

  Raises:
    ValueError: if the probability tensors differ in length. The static
      `TensorShape` assertions used below may also raise.
  """
  checked_probs_list = []
  for probs in probs_list:
    # Since number of classes shouldn't change at runtime, probabilities shape
    # should be fully defined.
    probs.get_shape().assert_is_fully_defined()

    # Probabilities must be 1D.
    probs.get_shape().assert_has_rank(1)

    # Probabilities must be nonnegative and sum to one.
    # "Sum to one" is checked as the open interval (1 - tol, 1 + tol).
    tol = 1e-6
    prob_sum = math_ops.reduce_sum(probs)
    checked_probs = control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(probs),
        check_ops.assert_less(prob_sum, 1.0 + tol),
        check_ops.assert_less(1.0 - tol, prob_sum)
    ], probs)
    checked_probs_list.append(checked_probs)

  # All probabilities should be the same length.
  # NOTE(review): indexing [0] raises IndexError if probs_list is empty.
  prob_length = checked_probs_list[0].get_shape().num_elements()
  for checked_prob in checked_probs_list:
    if checked_prob.get_shape().num_elements() != prob_length:
      raise ValueError('Probability parameters must have the same length.')

  # Labels tensor should only have batch dimension.
  labels.get_shape().assert_has_rank(1)

  for tensor in tensor_list:
    # Data tensor should have a batch dimension.
    shape = tensor.get_shape().with_rank_at_least(1)

    # Data and label batch dimensions must be compatible.
    tensor_shape.dimension_at_index(shape, 0).assert_is_compatible_with(
        labels.get_shape()[0])

  # Data and labels must have the same, strictly positive batch size. Since we
  # can't assume we know the batch size at graph creation, add runtime checks.
  labels_batch_size = array_ops.shape(labels)[0]
  lbl_assert = check_ops.assert_positive(labels_batch_size)

  # Make each tensor depend on its own checks.
  labels = control_flow_ops.with_dependencies([lbl_assert], labels)
  tensor_list = [
      control_flow_ops.with_dependencies([
          lbl_assert,
          check_ops.assert_equal(array_ops.shape(x)[0], labels_batch_size)
      ], x) for x in tensor_list
  ]

  # Label's classes must be integers 0 <= x < num_classes.
  # The number of classes is the static length of the probability vectors.
  labels = control_flow_ops.with_dependencies([
      check_ops.assert_integer(labels), check_ops.assert_non_negative(labels),
      check_ops.assert_less(labels, math_ops.cast(prob_length, labels.dtype))
  ], labels)

  return tensor_list, labels, checked_probs_list
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:59,代码来源:sampling_ops.py



注:本文中的tensorflow.python.ops.check_ops.assert_equal函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python check_ops.assert_greater_equal函数代码示例发布时间:2022-05-27
下一篇:
Python boosted_trees_ops.training_predict函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap