Python confusion_matrix.remove_squeezable_dimensions Function Code Examples

This article collects typical usage examples of the Python function tensorflow.python.ops.confusion_matrix.remove_squeezable_dimensions. If you are wondering how remove_squeezable_dimensions is used in practice, or are looking for working examples of it, the selected code examples below may help.



Five code examples of the remove_squeezable_dimensions function are shown below, sorted by popularity by default.

Example 1: testSqueezablePredictionsExpectedRankDiffMinus1

  def testSqueezablePredictionsExpectedRankDiffMinus1(self):
    label_values = np.ones(shape=(2, 3, 5))
    prediction_values = np.zeros(shape=(2, 3, 1))
    static_labels, static_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            label_values, prediction_values, expected_rank_diff=-1))

    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_placeholder, predictions_placeholder, expected_rank_diff=-1))

    expected_prediction_values = np.reshape(prediction_values, newshape=(2, 3))
    with self.cached_session():
      self.assertAllEqual(label_values, static_labels.eval())
      self.assertAllEqual(expected_prediction_values, static_predictions.eval())
      feed_dict = {
          labels_placeholder: label_values,
          predictions_placeholder: prediction_values
      }
      self.assertAllEqual(
          label_values, dynamic_labels.eval(feed_dict=feed_dict))
      self.assertAllEqual(
          expected_prediction_values,
          dynamic_predictions.eval(feed_dict=feed_dict))
Developer ID: HughKu, Project: tensorflow, Lines of code: 26, Source: confusion_matrix_test.py
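
The shape behaviour checked by the test above can be modelled without TensorFlow. The following is a minimal sketch that works on plain shape tuples, assuming the static-rank rule is "squeeze the trailing size-1 dimension of whichever tensor has one rank more than `expected_rank_diff` allows". The names `Shape` and `squeezed_shapes` are hypothetical and exist only for illustration; this is not the library's implementation and it ignores the dynamic-rank (placeholder) path.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def squeezed_shapes(label_shape: Shape, prediction_shape: Shape,
                    expected_rank_diff: int = 0) -> Tuple[Shape, Shape]:
    """Hypothetical shape-only model of the static-rank squeeze rule."""
    rank_diff = len(prediction_shape) - len(label_shape)
    # Predictions have one rank too many and end in a size-1 dim: drop it.
    if (rank_diff == expected_rank_diff + 1 and prediction_shape
            and prediction_shape[-1] == 1):
        return label_shape, prediction_shape[:-1]
    # Labels have one rank too many and end in a size-1 dim: drop it.
    if (rank_diff == expected_rank_diff - 1 and label_shape
            and label_shape[-1] == 1):
        return label_shape[:-1], prediction_shape
    # Ranks already differ by the expected amount: leave both untouched.
    return label_shape, prediction_shape
```

With the shapes from the test, `squeezed_shapes((2, 3, 5), (2, 3, 1), expected_rank_diff=-1)` returns `((2, 3, 5), (2, 3))`: the ranks are equal, which is one more than the expected difference of -1, so only the predictions lose their last dimension.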


Example 2: testSameShape

  def testSameShape(self):
    label_values = np.ones(shape=(2, 3, 1))
    prediction_values = np.zeros_like(label_values)
    static_labels, static_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            label_values, prediction_values))

    labels_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    predictions_placeholder = array_ops.placeholder(dtype=dtypes.int32)
    dynamic_labels, dynamic_predictions = (
        confusion_matrix.remove_squeezable_dimensions(
            labels_placeholder, predictions_placeholder))

    with self.cached_session():
      self.assertAllEqual(label_values, self.evaluate(static_labels))
      self.assertAllEqual(prediction_values, self.evaluate(static_predictions))
      feed_dict = {
          labels_placeholder: label_values,
          predictions_placeholder: prediction_values
      }
      self.assertAllEqual(
          label_values, dynamic_labels.eval(feed_dict=feed_dict))
      self.assertAllEqual(
          prediction_values, dynamic_predictions.eval(feed_dict=feed_dict))
Developer ID: Wajih-O, Project: tensorflow, Lines of code: 24, Source: confusion_matrix_test.py


Example 3: _remove_squeezable_dimensions

def _remove_squeezable_dimensions(
    labels, predictions, weights=None, expected_rank_diff=0):
  """Internal version of `remove_squeezable_dimensions` which handles weights.

  Squeezes `predictions` and `labels` if their ranks differ from expected by
  exactly 1.
  Squeezes `weights` if its rank is 1 more than the new rank of `predictions`.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    weights: Optional weight `Tensor`. It will be squeezed if it's not scalar,
      and its rank is 1 more than the new rank of `labels`.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.

  Returns:
    Tuple of `labels`, `predictions` and `weights`, possibly with the last
    dimension squeezed.
  """
  labels, predictions = confusion_matrix.remove_squeezable_dimensions(
      labels, predictions, expected_rank_diff=expected_rank_diff)

  if weights is not None:
    weights = ops.convert_to_tensor(weights)
    labels_rank = labels.get_shape().ndims
    weights_shape = weights.get_shape()
    weights_rank = weights_shape.ndims

    if (labels_rank is not None) and (weights_rank is not None):
      # Use static rank.
      rank_diff = weights_rank - labels_rank
      if rank_diff == 1:
        weights = array_ops.squeeze(weights, [-1])
      return labels, predictions, weights

    # Use dynamic rank.
    rank_diff = array_ops.rank(weights) - array_ops.rank(labels)
    if (weights_rank is None) or (
        weights_rank > 0 and weights_shape.dims[-1].is_compatible_with(1)):
      weights = control_flow_ops.cond(
          math_ops.equal(1, rank_diff),
          lambda: array_ops.squeeze(weights, [-1]),
          lambda: weights)

  return labels, predictions, weights
Developer ID: ChengYuXiang, Project: tensorflow, Lines of code: 48, Source: losses_impl.py
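
The static-rank weights branch of the function above reduces to a single comparison of ranks. The sketch below models it on shape tuples only. `squeeze_weights_shape` is a hypothetical helper, `labels_shape` is assumed to be the shape of `labels` after the label/prediction squeeze, and the `ValueError` is an assumption that mirrors how a squeeze on axis -1 only accepts a dimension of size 1.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def squeeze_weights_shape(labels_shape: Shape,
                          weights_shape: Optional[Shape]) -> Optional[Shape]:
    """Hypothetical shape-only model of the static-rank weights branch."""
    if weights_shape is None:
        # No weights were passed, so there is nothing to adjust.
        return None
    if len(weights_shape) - len(labels_shape) == 1:
        # Assumption: squeezing axis -1 is only valid for a size-1 dim.
        if weights_shape[-1] != 1:
            raise ValueError(
                "cannot squeeze a last dimension of size %d" % weights_shape[-1])
        return weights_shape[:-1]
    # Any other rank difference (including scalar weights) is left alone.
    return weights_shape
```

For example, labels of shape `(2, 3)` with weights of shape `(2, 3, 1)` give `(2, 3)`, while weights that already have shape `(2, 3)` or `()` come back unchanged.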


Example 4: squeeze_or_expand_dimensions

def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
  """Squeeze or expand last dimension if needed.

  1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
  (using `confusion_matrix.remove_squeezable_dimensions`).
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
  from the new rank of `y_pred`.
  If `sample_weight` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
    y_true: Optional label `Tensor` whose dimensions match `y_pred`.
    sample_weight: Optional weight scalar or `Tensor` whose dimensions match
      `y_pred`.

  Returns:
    Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
    the last dimension squeezed,
    `sample_weight` could be extended by one dimension.
  """
  if y_true is not None:
    # squeeze last dim of `y_pred` or `y_true` if their rank differs by 1
    y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
        y_true, y_pred)

  if sample_weight is None:
    return y_pred, y_true, None

  sample_weight = ops.convert_to_tensor(sample_weight)
  weights_shape = sample_weight.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:  # If weights is scalar, do nothing.
    return y_pred, y_true, sample_weight

  y_pred_shape = y_pred.get_shape()
  y_pred_rank = y_pred_shape.ndims
  if (y_pred_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - y_pred_rank == 1:
      sample_weight = array_ops.squeeze(sample_weight, [-1])
    elif y_pred_rank - weights_rank == 1:
      sample_weight = array_ops.expand_dims(sample_weight, [-1])
    return y_pred, y_true, sample_weight

  # Use dynamic rank.
  weights_rank_tensor = array_ops.rank(sample_weight)
  rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
  maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])

  def _maybe_expand_weights():
    return control_flow_ops.cond(
        math_ops.equal(rank_diff,
                       -1), lambda: array_ops.expand_dims(sample_weight, [-1]),
        lambda: sample_weight)

  def _maybe_adjust_weights():
    return control_flow_ops.cond(
        math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
        _maybe_expand_weights)

  # squeeze or expand last dim of `sample_weight` if its rank differs by 1
  # from the new rank of `y_pred`.
  sample_weight = control_flow_ops.cond(
      math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
      _maybe_adjust_weights)
  return y_pred, y_true, sample_weight
Developer ID: ThunderQi, Project: tensorflow, Lines of code: 69, Source: metrics.py
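
The static-rank handling of `sample_weight` in the function above has three outcomes: keep a scalar as it is, squeeze when the weights have one rank more than `y_pred`, and expand when they have one rank fewer. A minimal sketch on shape tuples follows. `adjust_sample_weight_shape` is a hypothetical name, and the sketch assumes the squeezed dimension has size 1 instead of validating it.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def adjust_sample_weight_shape(y_pred_shape: Shape,
                               weights_shape: Shape) -> Shape:
    """Hypothetical shape-only model of the static-rank sample_weight branch."""
    if len(weights_shape) == 0:
        # Scalar weights are kept scalar.
        return weights_shape
    rank_diff = len(weights_shape) - len(y_pred_shape)
    if rank_diff == 1:
        # One rank too many: drop the last dim (assumed to be size 1).
        return weights_shape[:-1]
    if rank_diff == -1:
        # One rank too few: append a size-1 dim for later broadcasting.
        return weights_shape + (1,)
    return weights_shape
```

With `y_pred` of shape `(2, 3)`, weights of shape `(2, 3, 1)` become `(2, 3)`, weights of shape `(2,)` become `(2, 1)`, and scalar weights stay `()`. The scalar check runs first, which is why a scalar paired with a rank-1 `y_pred` is not expanded.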


Example 5: squeeze_or_expand_dimensions

def squeeze_or_expand_dimensions(y_pred, y_true, sample_weight):
  """Squeeze or expand last dimension if needed.

  1. Squeezes last dim of `y_pred` or `y_true` if their rank differs by 1
  (using `confusion_matrix.remove_squeezable_dimensions`).
  2. Squeezes or expands last dim of `sample_weight` if its rank differs by 1
  from the new rank of `y_pred`.
  If `sample_weight` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    y_pred: Predicted values, a `Tensor` of arbitrary dimensions.
    y_true: Optional label `Tensor` whose dimensions match `y_pred`.
    sample_weight: Optional weight scalar or `Tensor` whose dimensions match
      `y_pred`.

  Returns:
    Tuple of `y_pred`, `y_true` and `sample_weight`. Each of them possibly has
    the last dimension squeezed,
    `sample_weight` could be extended by one dimension.
  """
  y_pred_shape = y_pred.shape
  y_pred_rank = y_pred_shape.ndims
  if y_true is not None:

    # If sparse matrix is provided as `y_true`, the last dimension in `y_pred`
    # may be > 1. Eg: y_true = [0, 1, 2] (shape=(3,)),
    # y_pred = [[.9, .05, .05], [.5, .89, .6], [.05, .01, .94]] (shape=(3, 3))
    # In this case, we should not try to remove squeezable dimension.
    y_true_shape = y_true.shape
    y_true_rank = y_true_shape.ndims
    if (y_true_rank is not None) and (y_pred_rank is not None):
      # Use static rank for `y_true` and `y_pred`.
      if (y_pred_rank - y_true_rank != 1) or y_pred_shape[-1] == 1:
        y_true, y_pred = confusion_matrix.remove_squeezable_dimensions(
            y_true, y_pred)
    else:
      # Use dynamic rank.
      rank_diff = array_ops.rank(y_pred) - array_ops.rank(y_true)
      squeeze_dims = lambda: confusion_matrix.remove_squeezable_dimensions(  # pylint: disable=g-long-lambda
          y_true, y_pred)
      is_last_dim_1 = math_ops.equal(1, array_ops.shape(y_pred)[-1])
      maybe_squeeze_dims = lambda: control_flow_ops.cond(  # pylint: disable=g-long-lambda
          is_last_dim_1, squeeze_dims, lambda: (y_true, y_pred))
      y_true, y_pred = control_flow_ops.cond(
          math_ops.equal(1, rank_diff), maybe_squeeze_dims, squeeze_dims)

  if sample_weight is None:
    return y_pred, y_true, None

  sample_weight = ops.convert_to_tensor(sample_weight)
  weights_shape = sample_weight.shape
  weights_rank = weights_shape.ndims
  if weights_rank == 0:  # If weights is scalar, do nothing.
    return y_pred, y_true, sample_weight

  if (y_pred_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - y_pred_rank == 1:
      sample_weight = array_ops.squeeze(sample_weight, [-1])
    elif y_pred_rank - weights_rank == 1:
      sample_weight = array_ops.expand_dims(sample_weight, [-1])
    return y_pred, y_true, sample_weight

  # Use dynamic rank.
  weights_rank_tensor = array_ops.rank(sample_weight)
  rank_diff = weights_rank_tensor - array_ops.rank(y_pred)
  maybe_squeeze_weights = lambda: array_ops.squeeze(sample_weight, [-1])

  def _maybe_expand_weights():
    return control_flow_ops.cond(
        math_ops.equal(rank_diff,
                       -1), lambda: array_ops.expand_dims(sample_weight, [-1]),
        lambda: sample_weight)

  def _maybe_adjust_weights():
    return control_flow_ops.cond(
        math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
        _maybe_expand_weights)

  # squeeze or expand last dim of `sample_weight` if its rank differs by 1
  # from the new rank of `y_pred`.
  sample_weight = control_flow_ops.cond(
      math_ops.equal(weights_rank_tensor, 0), lambda: sample_weight,
      _maybe_adjust_weights)
  return y_pred, y_true, sample_weight
Developer ID: aritratony, Project: tensorflow, Lines of code: 88, Source: losses_utils.py
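
What this version adds over the previous example is the guard for sparse labels: when `y_pred` has exactly one rank more than `y_true` and its last dimension is larger than 1, the squeeze is skipped. The static-rank condition can be sketched as a predicate on shape tuples. `should_try_squeeze` is a hypothetical name used only for this illustration.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def should_try_squeeze(y_true_shape: Shape, y_pred_shape: Shape) -> bool:
    """Hypothetical model of the static-rank guard before squeezing."""
    rank_diff = len(y_pred_shape) - len(y_true_shape)
    # Skip only when y_pred has one extra rank AND that extra dim holds
    # real data (size > 1), e.g. per-class scores for sparse labels.
    return rank_diff != 1 or y_pred_shape[-1] == 1
```

For the case described in the code comment, `should_try_squeeze((3,), (3, 3))` is `False`, so the per-class scores are left intact. `should_try_squeeze((3,), (3, 1))` is `True`, because the trailing dimension of size 1 can be removed safely.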



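Note: The tensorflow.python.ops.confusion_matrix.remove_squeezable_dimensions examples in this article were compiled by 纯净天空 from GitHub, MSDocs, and other source-code and documentation platforms. The snippets are selected from open-source projects, and copyright in the source code belongs to the original authors. For distribution and use, refer to the license of the corresponding project. Do not reproduce without permission.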

