• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python array_ops.squeeze函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了 Python 中 tensorflow.python.ops.array_ops.squeeze 函数的典型用法代码示例。如果您正苦于以下问题：Python squeeze 函数的具体用法？Python squeeze 怎么用？Python squeeze 使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了 squeeze 函数的 20 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。

示例1: GetParams

 def GetParams(self):
   """Build a graph whose add/mul chains are split by an incompatible op.

   The graph applies add -> mul -> add to the placeholder, passes the result
   through `self.trt_incompatible_op`, applies add -> mul -> add again, and
   finishes with a squeeze named `self.output_name`.

   Returns:
     A `trt_test.TfTrtIntegrationTestParams` carrying the graph def, the
     single input, the expected engine contents and the tolerances.
   """
   input_name = "input"
   input_dims = [2, 32, 32, 3]
   graph = ops.Graph()
   with graph.as_default():
     node = array_ops.placeholder(
         dtype=dtypes.float32, shape=input_dims, name=input_name)
     with graph.device("/GPU:0"):
       one = constant_op.constant(1.0, name="c")
       # First chain, upstream of the incompatible op.
       node = math_ops.add(node, one, name="add")
       node = math_ops.mul(node, node, name="mul")
       node = math_ops.add(node, node, name="add1")
       node = self.trt_incompatible_op(node, name="incompatible1")
       # Second chain, downstream of the incompatible op.
       node = math_ops.add(node, one, name="add2")
       node = math_ops.mul(node, node, name="mul1")
       node = math_ops.add(node, node, name="add3")
     array_ops.squeeze(node, name=self.output_name)
   engines = {
       "my_trt_op_0": ["add2", "add3", "mul1"],
       # Segment ["add", "add1", "mul"] gets id 1 rather than 0 because its
       # parent node is really the const node 'c'; that node is removed
       # later, since a const output of a segment is not allowed.
       "my_trt_op_1": ["add", "add1", "mul"],
   }
   return trt_test.TfTrtIntegrationTestParams(
       gdef=graph.as_graph_def(),
       input_names=[input_name],
       input_dims=[input_dims],
       expected_engines=engines,
       expected_output_dims=tuple(input_dims),
       allclose_atol=1.e-06,
       allclose_rtol=1.e-06)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:34,代码来源:base_test.py


示例2: GetParams

 def GetParams(self):
   """Build a multi-segment graph gated by control dependencies.

   Two add -> mul -> add chains are separated by `self.trt_incompatible_op`
   and both are created under a control dependency on another incompatible
   op, which itself depends on the shared constant.

   Returns:
     A `trt_test.TfTrtIntegrationTestParams` carrying the graph def together
     with its input/output names and dimensions.
   """
   input_name = "input"
   input_dims = [2, 32, 32, 3]
   output_name = "output"
   graph = ops.Graph()
   with graph.as_default():
     node = array_ops.placeholder(
         dtype=dtypes.float32, shape=input_dims, name=input_name)
     with graph.device("/GPU:0"):
       const = constant_op.constant(1.0, name="c")
       # The constant feeds a control dependency into a trt incompatible op,
       # and that op feeds a control dependency into every other op. This
       # makes sure the constant cannot be contracted with any trt segment
       # that depends on it.
       with graph.control_dependencies([const]):
         barrier = self.trt_incompatible_op(node, name="incompatible")
       with graph.control_dependencies([barrier]):
         node = math_ops.add(node, const, name="add")
         node = math_ops.mul(node, node, name="mul")
         node = math_ops.add(node, node, name="add1")
       node = self.trt_incompatible_op(node, name="incompatible1")
       with graph.control_dependencies([barrier]):
         node = math_ops.add(node, const, name="add2")
         node = math_ops.mul(node, node, name="mul1")
         node = math_ops.add(node, node, name="add3")
     array_ops.squeeze(node, name=output_name)
   graph_def = graph.as_graph_def()
   return trt_test.TfTrtIntegrationTestParams(
       gdef=graph_def,
       input_names=[input_name],
       input_dims=[input_dims],
       output_names=[output_name],
       expected_output_dims=[tuple(input_dims)])
开发者ID:aeverall,项目名称:tensorflow,代码行数:34,代码来源:base_test.py


示例3: _test_squeeze

def _test_squeeze(data, squeeze_dims=None):
    """Run one squeeze case and compare the TFLite result with TVM.

    Args:
        data: numpy array fed to the placeholder; its rank must be 1 to 4.
        squeeze_dims: optional list of axes passed to `array_ops.squeeze`.
            When it is `None` or empty, `array_ops.squeeze` is called without
            explicit axes.

    Raises:
        NotImplementedError: if `data` has a rank outside 1 to 4.
    """
    rank = len(data.shape)

    # see relay/frontend/tflite.py convert_squeeze more detail of channel first rule
    if rank in (1, 2):
        tvm_data = data
    elif rank == 3:
        tvm_data = np.transpose(data, axes=(0, 2, 1))
    elif rank == 4:
        tvm_data = np.transpose(data, axes=(0, 3, 1, 2))
    else:
        raise NotImplementedError("Not support input shape {} of reshape : ".
                                  format(str(rank)))

    # NOTE: the layout chosen above is used as-is. An unconditional 4-D
    # transpose here would discard it and fail for rank 1-3 inputs, because
    # the axes tuple would no longer match the array rank.

    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)

        # `None` and `[]` are both falsy, so either one takes the no-axes path.
        if squeeze_dims:
            out = array_ops.squeeze(in_data, squeeze_dims)
        else:
            out = array_ops.squeeze(in_data)

        compare_tflite_with_tvm(data, tvm_data, 'Placeholder:0', [in_data], [out])
开发者ID:bddppq,项目名称:tvm,代码行数:28,代码来源:test_forward.py


示例4: call

  def call(self, inputs):
    """Apply `self.pool_function` to 3D inputs via a temporary 4D view.

    Args:
      inputs: tensor laid out per `self.data_format` (NWC for
        'channels_last', NCW otherwise, going by the layouts used below).

    Returns:
      The pooled tensor with the temporary axis squeezed out again.
    """
    # There is no TF op for 1D pooling, hence we make the inputs 4D. The
    # dummy H axis goes in at position 1 for NWC (-> NHWC) and at position 2
    # for NCW (-> NCHW); pooling then happens on the W dim only.
    if self.data_format == 'channels_last':
      dummy_axis = 1
      layout = 'NHWC'
      window = (1, 1) + self.pool_size + (1,)
      step = (1, 1) + self.strides + (1,)
    else:
      dummy_axis = 2
      layout = 'NCHW'
      window = (1, 1, 1) + self.pool_size
      step = (1, 1, 1) + self.strides

    expanded = array_ops.expand_dims(inputs, dummy_axis)
    pooled = self.pool_function(
        expanded,
        ksize=window,
        strides=step,
        padding=self.padding.upper(),
        data_format=layout)
    return array_ops.squeeze(pooled, dummy_axis)
开发者ID:ChengYuXiang,项目名称:tensorflow,代码行数:28,代码来源:pooling.py


示例5: _statistics

def _statistics(x, axes):
  """Compute the mean and the mean of squares of `x` over `axes`.

  Adapted from the implementation of `tf.nn.moments`.

  Args:
    x: A `Tensor`.
    axes: Array of ints.  Axes to reduce over.

  Returns:
    A pair of `Tensor` objects `(mean, mean_squared)`, with the reduced axes
    squeezed out. For float16 input both are cast back to float16.
  """
  is_half = x.dtype == dtypes.float16
  # fp16 has too little dynamic range to accumulate these statistics, so the
  # arithmetic is carried out in float32 and converted back at the end.
  y = math_ops.cast(x, dtypes.float32) if is_half else x

  # Centre on the (gradient-stopped) true mean, keeping dims so that the
  # subtraction broadcasts.
  shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True))
  mean = math_ops.reduce_mean(y - shift, axes, keepdims=True) + shift
  mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True)

  mean = array_ops.squeeze(mean, axes)
  mean_squared = array_ops.squeeze(mean_squared, axes)
  if not is_half:
    return (mean, mean_squared)
  return (math_ops.cast(mean, dtypes.float16),
          math_ops.cast(mean_squared, dtypes.float16))
开发者ID:AnishShah,项目名称:tensorflow,代码行数:32,代码来源:virtual_batchnorm_impl.py


示例6: GetParams

 def GetParams(self):
   """Build a graph that mixes a rank 2 input with a rank 4 input.

   Each input runs through the same add/abs chain and a reciprocal. The rank
   2 path is expanded by two trailing dims first, so that the two paths can
   be added together before the final squeeze.

   Returns:
     A `trt_test.TfTrtIntegrationTestParams` carrying the graph def together
     with its input/output names and dimensions.
   """
   input_names = ["input", "input2"]
   # Two paths: first with rank 2 input, second with rank 4 input.
   input_dims = [[12, 5], [12, 5, 2, 2]]
   output_name = "output"
   graph = ops.Graph()
   with graph.as_default():
     branches = []
     for i, (name, dims) in enumerate(zip(input_names, input_dims)):
       inp = array_ops.placeholder(
           dtype=dtypes.float32, shape=dims, name=name)
       const = constant_op.constant(1.0, name="c%d_1" % i)
       node = math_ops.add(inp, const, name="add%d_1" % i)
       node = math_ops.abs(node, name="abs%d_1" % i)
       const = constant_op.constant(2.2, name="c%d_2" % i)
       node = math_ops.add(node, const, name="add%d_2" % i)
       node = math_ops.abs(node, name="abs%d_2" % i)
       const = constant_op.constant(3.0, name="c%d_3" % i)
       node = math_ops.add(node, const, name="add%d_3" % i)
       if i == 0:
         # Only the rank 2 path gets the two extra trailing dims.
         for j in range(2):
           node = array_ops.expand_dims(node, -1, name="expand%d_%d" % (i, j))
       node = gen_math_ops.reciprocal(node, name="reciprocal%d" % i)
       branches.append(node)
     # Combine both paths
     combined = math_ops.add(branches[0], branches[1], name="add")
     array_ops.squeeze(combined, name=output_name)
   return trt_test.TfTrtIntegrationTestParams(
       gdef=graph.as_graph_def(),
       input_names=input_names,
       input_dims=input_dims,
       output_names=[output_name],
       expected_output_dims=[tuple(input_dims[1])])
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:34,代码来源:rank_two_test.py


示例7: get_simple_graph_def

 def get_simple_graph_def(self):
   """Build a conv -> bias -> relu -> identity -> max_pool -> squeeze graph.

   Returns:
     The graph def of the constructed graph.
   """
   graph = ops.Graph()
   with graph.as_default():
     inp = aops.placeholder(
         dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input")
     weights = cop.constant(
         [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
         name="weights",
         dtype=dtypes.float32)
     conv = nn.conv2d(
         input=inp,
         filter=weights,
         strides=[1, 2, 2, 1],
         padding="SAME",
         name="conv")
     bias = cop.constant(
         [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32)
     biased = nn.bias_add(conv, bias, name="biasAdd")
     activated = nn.relu(biased, "relu")
     passthrough = aops.identity(activated, "ID")
     pooled = nn_ops.max_pool(
         passthrough, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
     aops.squeeze(pooled, name="output")
   return graph.as_graph_def()
开发者ID:ebrevdo,项目名称:tensorflow,代码行数:25,代码来源:tf_trt_integration_test.py


示例8: testSqueezeMatrix

  def testSqueezeMatrix(self):
    """Squeezing axis 0 of a [1, 3] matrix works; squeezing axis 1 raises."""
    matrix = [[1, 2, 3]]
    matrix_squeezed = array_ops.squeeze(matrix, [0])
    # `(3)` is just the int 3; the expected rank-1 shape is the tuple `(3,)`.
    self.assertEqual(matrix_squeezed.get_shape(), (3,))

    # Axis 1 has size 3, so it cannot be squeezed.
    with self.assertRaises(ValueError):
      array_ops.squeeze(matrix, [1])
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:7,代码来源:array_ops_test.py


示例9: GetParams

 def GetParams(self):
   """Build the graph for the neighboring node wiring test.

   A NCHW conv2d output is multiplied by a bias constant and also fed to
   `self.trt_incompatible_op`; the two results are subtracted and squeezed.

   Returns:
     A `trt_test.TfTrtIntegrationTestParams` carrying the graph def together
     with its input/output names and dimensions.
   """
   dtype = dtypes.float32
   input_name = "input"
   input_dims = [2, 3, 7, 5]
   output_name = "output"
   graph = ops.Graph()
   with graph.as_default():
     inp = array_ops.placeholder(
         dtype=dtype, shape=input_dims, name=input_name)
     kernel = constant_op.constant(
         np.random.normal(.3, 0.05, [3, 2, 3, 4]), name="weights", dtype=dtype)
     conv = nn.conv2d(
         input=inp,
         filter=kernel,
         data_format="NCHW",
         strides=[1, 1, 1, 1],
         padding="VALID",
         name="conv")
     bias = constant_op.constant(
         np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype)
     scaled = math_ops.mul(conv, bias, name="mul")
     blocked = self.trt_incompatible_op(conv, name="incompatible")
     diff = math_ops.sub(scaled, blocked, name="sub")
     array_ops.squeeze(diff, name=output_name)
   return trt_test.TfTrtIntegrationTestParams(
       gdef=graph.as_graph_def(),
       input_names=[input_name],
       input_dims=[input_dims],
       output_names=[output_name],
       expected_output_dims=[(2, 4, 5, 4)])
开发者ID:Ajaycs99,项目名称:tensorflow,代码行数:30,代码来源:neighboring_engine_test.py


示例10: test_virtual_statistics

  def test_virtual_statistics(self):
    """Verify `_virtual_statistics` agrees with `nn.moments` on a full batch."""
    random_seed.set_random_seed(1234)

    batch_axis = 0
    partial_batch = random_ops.random_normal([4, 5, 7, 3])
    single_example = random_ops.random_normal([1, 5, 7, 3])
    full_batch = array_ops.concat([partial_batch, single_example], axis=0)

    for reduction_axis in range(1, 4):
      # Reference: `nn.moments` over every axis except the kept one.
      reduction_axes = [ax for ax in range(4) if ax != reduction_axis]
      mom_mean, mom_variance = nn.moments(full_batch, reduction_axes)

      # Virtual batch statistics additionally leave out the batch axis.
      vb_reduction_axes = [
          ax for ax in range(4) if ax not in (reduction_axis, batch_axis)]
      vbn = virtual_batchnorm.VBN(partial_batch, reduction_axis)
      vb_mean, mean_sq = vbn._virtual_statistics(
          single_example, vb_reduction_axes)
      vb_variance = mean_sq - math_ops.square(vb_mean)
      # Drop the singleton batch dim so the shapes line up for comparison.
      vb_mean = array_ops.squeeze(vb_mean, batch_axis)
      vb_variance = array_ops.squeeze(vb_variance, batch_axis)

      fetches = [vb_mean, vb_variance, mom_mean, mom_variance]
      with self.cached_session(use_gpu=True) as sess:
        vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run(fetches)

      self.assertAllClose(mom_mean_np, vb_mean_np)
      self.assertAllClose(mom_var_np, vb_var_np)
开发者ID:Albert-Z-Guo,项目名称:tensorflow,代码行数:33,代码来源:virtual_batchnorm_test.py


示例11: GetParams

 def GetParams(self):
   """Build the graph for the neighboring node wiring test.

   A NCHW conv2d output is multiplied by a bias constant, the tangent of the
   same conv output is subtracted, and the result is squeezed into
   `self.output_name`.

   Returns:
     A `trt_test.TfTrtIntegrationTestParams` carrying the graph def, the
     single input, the expected engine count and the tolerances.
   """
   dtype = dtypes.float32
   input_name = "input"
   input_dims = [2, 3, 7, 5]
   graph = ops.Graph()
   with graph.as_default():
     inp = array_ops.placeholder(
         dtype=dtype, shape=input_dims, name=input_name)
     kernel = constant_op.constant(
         np.random.normal(.3, 0.05, [3, 2, 3, 4]), name="weights", dtype=dtype)
     conv = nn.conv2d(
         input=inp,
         filter=kernel,
         data_format="NCHW",
         strides=[1, 1, 1, 1],
         padding="VALID",
         name="conv")
     bias = constant_op.constant(
         np.random.normal(1.0, 1.0, [1, 4, 1, 1]), name="bias", dtype=dtype)
     scaled = conv * bias
     tangent = gen_math_ops.tan(conv)
     diff = scaled - tangent
     array_ops.squeeze(diff, name=self.output_name)
   return trt_test.TfTrtIntegrationTestParams(
       gdef=graph.as_graph_def(),
       input_names=[input_name],
       input_dims=[input_dims],
       num_expected_engines=2,
       expected_output_dims=(2, 4, 5, 4),
       allclose_atol=1.e-03,
       allclose_rtol=1.e-03)
开发者ID:StephenOman,项目名称:tensorflow,代码行数:31,代码来源:neighboring_engine_test.py


示例12: GetParams

 def GetParams(self):
   """Build the graph for the single vgg layer test.

   The chain is fused_batch_norm -> conv2d -> bias_add -> relu -> identity ->
   max_pool -> squeeze.

   Returns:
     A `trt_test.TfTrtIntegrationTestParams` carrying the graph def together
     with its input/output names and dimensions.
   """
   dtype = dtypes.float32
   input_name = "input"
   input_dims = [5, 8, 8, 2]
   output_name = "output"
   graph = ops.Graph()
   with graph.as_default():
     inp = array_ops.placeholder(
         dtype=dtype, shape=input_dims, name=input_name)
     # Only the first of the three returned values is used.
     normalized, _, _ = nn_impl.fused_batch_norm(
         inp, [1.0, 1.0], [0.0, 0.0],
         mean=[0.5, 0.5],
         variance=[1.0, 1.0],
         is_training=False)
     kernel = constant_op.constant(
         np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype)
     conv = nn.conv2d(
         input=normalized,
         filter=kernel,
         strides=[1, 2, 2, 1],
         padding="SAME",
         name="conv")
     bias = constant_op.constant(np.random.randn(6), name="bias", dtype=dtype)
     biased = nn.bias_add(conv, bias, name="biasAdd")
     activated = nn.relu(biased, "relu")
     passthrough = array_ops.identity(activated, "ID")
     pooled = nn_ops.max_pool(
         passthrough, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
     array_ops.squeeze(pooled, name=output_name)
   return trt_test.TfTrtIntegrationTestParams(
       gdef=graph.as_graph_def(),
       input_names=[input_name],
       input_dims=[input_dims],
       output_names=[output_name],
       expected_output_dims=[(5, 2, 2, 6)])
开发者ID:aeverall,项目名称:tensorflow,代码行数:31,代码来源:vgg_block_test.py


示例13: GetMultiEngineGraphDef

def GetMultiEngineGraphDef(dtype=dtypes.float32):
  """Build a graph made of several segments and return its graph def.

  Args:
    dtype: dtype used for the placeholder and for every constant.

  Returns:
    The graph def of the constructed graph.
  """
  graph = ops.Graph()
  with graph.as_default():
    inp = array_ops.placeholder(
        dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME)
    with graph.device("/GPU:0"):
      weights = constant_op.constant(
          [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
          name="weights",
          dtype=dtype)
      conv = nn.conv2d(
          input=inp,
          filter=weights,
          strides=[1, 2, 2, 1],
          padding="SAME",
          name="conv")
      multiplier = constant_op.constant(
          np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype)
      p = conv * multiplier
      divisor = constant_op.constant(
          np.random.randn(INPUT_DIMS[0], 12, 12, 6), dtype=dtype)
      q = conv / divisor

      edge = math_ops.sin(q)
      edge = edge / edge
      r = edge + edge

      p = p - edge
      q = q * edge
      s = p + q
      s = s - r
    array_ops.squeeze(s, name=OUTPUT_NAME)
  return graph.as_graph_def()
开发者ID:Eagle732,项目名称:tensorflow,代码行数:34,代码来源:tf_trt_integration_test.py


示例14: GetParams

 def GetParams(self):
   """Build a graph of two chained 1x1 conv2d ops followed by a squeeze.

   Returns:
     A `trt_test.TfTrtIntegrationTestParams` carrying the graph def, the
     single input, the expected engine name and the tolerances.
   """
   dtype = dtypes.float32
   input_name = "input"
   input_dims = [2, 15, 15, 3]
   graph = ops.Graph()
   with graph.as_default():
     # The batch dimension of the placeholder is left unspecified.
     inp = array_ops.placeholder(
         dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
     with graph.device("/GPU:0"):
       first_kernel = constant_op.constant(
           np.random.randn(1, 1, 3, 5), name="kernel_1", dtype=dtype)
       second_kernel = constant_op.constant(
           np.random.randn(1, 1, 5, 10), name="kernel_2", dtype=dtype)
       hidden = nn.conv2d(
           input=inp,
           filter=first_kernel,
           strides=[1, 1, 1, 1],
           padding="VALID",
           name="conv")
       result = nn.conv2d(
           input=hidden,
           filter=second_kernel,
           strides=[1, 1, 1, 1],
           padding="VALID",
           name="conv_2")
     array_ops.squeeze(result, name=self.output_name)
   return trt_test.TfTrtIntegrationTestParams(
       gdef=graph.as_graph_def(),
       input_names=[input_name],
       input_dims=[input_dims],
       expected_engines=["my_trt_op_0"],
       expected_output_dims=(2, 15, 15, 10),
       allclose_atol=1.e-02,
       allclose_rtol=1.e-02)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:35,代码来源:memory_alignment_test.py


示例15: GetSingleEngineGraphDef

def GetSingleEngineGraphDef(dtype=dtypes.float32):
  """Build a single-segment graph and return its graph def.

  The chain is conv2d -> bias_add -> relu -> identity -> max_pool -> squeeze.

  Args:
    dtype: dtype used for the placeholder and for every constant.

  Returns:
    The graph def of the constructed graph.
  """
  graph = ops.Graph()
  with graph.as_default():
    inp = array_ops.placeholder(
        dtype=dtype, shape=[None] + INPUT_DIMS[1:], name=INPUT_NAME)
    with graph.device("/GPU:0"):
      weights = constant_op.constant(
          [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
          name="weights",
          dtype=dtype)
      conv = nn.conv2d(
          input=inp,
          filter=weights,
          strides=[1, 2, 2, 1],
          padding="SAME",
          name="conv")
      bias_values = constant_op.constant(
          [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtype)
      biased = nn.bias_add(conv, bias_values, name="bias_add")
      activated = nn.relu(biased, "relu")
      passthrough = array_ops.identity(activated, "identity")
      pooled = nn_ops.max_pool(
          passthrough, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
    array_ops.squeeze(pooled, name=OUTPUT_NAME)
  return graph.as_graph_def()
开发者ID:Eagle732,项目名称:tensorflow,代码行数:26,代码来源:tf_trt_integration_test.py


示例16: average_impurity

  def average_impurity(self):
    """Constructs a TF graph for evaluating the average leaf impurity of a tree.

    In regression mode this is the leaf variance; in classification mode it
    is the gini impurity.

    Returns:
      The last op in the graph.
    """
    first_column = array_ops.slice(self.variables.tree, [0, 0], [-1, 1])
    children = array_ops.squeeze(first_column, squeeze_dims=[1])
    is_leaf = math_ops.equal(constants.LEAF_NODE, children)
    leaf_positions = array_ops.squeeze(
        array_ops.where(is_leaf), squeeze_dims=[1])
    leaves = math_ops.to_int32(leaf_positions)
    counts = array_ops.gather(self.variables.node_sums, leaves)
    gini = self._weighted_gini(counts)
    # Guard against step 1, when there often are no leaves yet.
    has_leaves = math_ops.greater(array_ops.shape(leaves)[0], 0)
    # Average impurity can be used for loss, so with no data a big number is
    # returned instead; that way the loss always decreases.
    return control_flow_ops.cond(
        has_leaves,
        lambda: gini,
        lambda: array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.)
开发者ID:AutumnQYN,项目名称:tensorflow,代码行数:25,代码来源:tensor_forest.py


示例17: testSqueezeMatrix

  def testSqueezeMatrix(self):
    """Squeezing axis 0 of a [1, 3] matrix works; squeezing axis 1 raises."""
    matrix = [[1, 2, 3]]
    matrix_squeezed = array_ops.squeeze(matrix, [0])
    # `(3)` is just the int 3; the expected rank-1 shape is the tuple `(3,)`.
    self.assertEqual(matrix_squeezed.get_shape(), (3,))

    # Axis 1 has size 3, so squeezing it must fail with this message.
    with self.assertRaisesRegexp(
        Exception, "Can not squeeze dim.1., expected a dimension of 1, got 3"):
      array_ops.squeeze(matrix, [1])
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:8,代码来源:array_ops_test.py


示例18: _recall_at_threshold

def _recall_at_threshold(labels, predictions, weights, threshold, name=None):
  """Build the recall metric for a single threshold.

  Args:
    labels: forwarded to `metrics_lib.recall_at_thresholds`.
    predictions: forwarded to `metrics_lib.recall_at_thresholds`.
    weights: forwarded to `metrics_lib.recall_at_thresholds`.
    threshold: the one threshold to evaluate; also used in the default name.
    name: optional name for the enclosing name scope.

  Returns:
    A `(recall, update_op)` pair, each passed through `array_ops.squeeze`.
  """
  default_name = 'recall_at_%s' % threshold
  scope_values = (predictions, labels, weights, threshold)
  with ops.name_scope(name, default_name, scope_values) as scope:
    recall_tensor, update_op = metrics_lib.recall_at_thresholds(
        labels=labels,
        predictions=predictions,
        thresholds=(threshold,),
        weights=weights,
        name=scope)
    return array_ops.squeeze(recall_tensor), array_ops.squeeze(update_op)
开发者ID:vaccine,项目名称:tensorflow,代码行数:8,代码来源:head.py


示例19: remove_squeezable_dimensions

def remove_squeezable_dimensions(
    labels, predictions, expected_rank_diff=0, name=None):
  """Squeeze the last dim when the rank difference is off by exactly 1.

  With the default `expected_rank_diff` of 0 the shapes are expected to match,
  and the last dimension of whichever tensor has the larger rank is squeezed
  if the ranks differ by 1.

  If instead `labels` holds class IDs and `predictions` holds 1 probability
  per class, `predictions` is expected to have 1 more dimension than `labels`,
  so `expected_rank_diff` would be 1. Then `labels` is squeezed if
  `rank(predictions) - rank(labels) == 0`, and `predictions` is squeezed if
  `rank(predictions) - rank(labels) == 2`.

  Static shapes are used when available. Otherwise graph operations are
  added, which could result in a performance hit.

  Args:
    labels: Label values, a `Tensor` whose dimensions match `predictions`.
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    expected_rank_diff: Expected result of `rank(predictions) - rank(labels)`.
    name: Name of the op.

  Returns:
    Tuple of `labels` and `predictions`, possibly with last dim squeezed.
  """
  with ops.name_scope(name, 'remove_squeezable_dimensions',
                      [labels, predictions]):
    predictions = ops.convert_to_tensor(predictions)
    labels = ops.convert_to_tensor(labels)
    pred_shape = predictions.get_shape()
    pred_rank = pred_shape.ndims
    label_shape = labels.get_shape()
    label_rank = label_shape.ndims

    if label_rank is None or pred_rank is None:
      # At least one rank is unknown: decide at run time.
      dynamic_diff = array_ops.rank(predictions) - array_ops.rank(labels)
      # A conditional squeeze is only added when the last dim may be 1.
      if pred_rank is None or pred_shape.dims[-1].is_compatible_with(1):
        predictions = control_flow_ops.cond(
            math_ops.equal(expected_rank_diff + 1, dynamic_diff),
            lambda: array_ops.squeeze(predictions, [-1]),
            lambda: predictions)
      if label_rank is None or label_shape.dims[-1].is_compatible_with(1):
        labels = control_flow_ops.cond(
            math_ops.equal(expected_rank_diff - 1, dynamic_diff),
            lambda: array_ops.squeeze(labels, [-1]),
            lambda: labels)
      return labels, predictions

    # Both ranks are known statically.
    static_diff = pred_rank - label_rank
    if static_diff == expected_rank_diff + 1:
      predictions = array_ops.squeeze(predictions, [-1])
    elif static_diff == expected_rank_diff - 1:
      labels = array_ops.squeeze(labels, [-1])
    return labels, predictions
开发者ID:aritratony,项目名称:tensorflow,代码行数:58,代码来源:confusion_matrix.py


示例20: crf_decode

def crf_decode(potentials, transition_params, sequence_length):
  """Decode the highest scoring sequence of tags in TensorFlow.

  This variant operates on tensors.

  Args:
    potentials: A [batch_size, max_seq_len, num_tags] tensor of
              unary potentials.
    transition_params: A [num_tags, num_tags] matrix of
              binary potentials.
    sequence_length: A [batch_size] vector of true sequence lengths.

  Returns:
    decode_tags: A [batch_size, max_seq_len] matrix, with dtype `tf.int32`.
                Contains the highest scoring tag indices.
    best_score: A [batch_size] vector, containing the score of `decode_tags`.
  """
  # Shape comments below abbreviate 'batch_size' as 'B', 'max_seq_len' as 'T'
  # and 'num_tags' as 'O' (output).
  num_tags = potentials.get_shape()[2].value

  # Forward pass: yields the last score and the backpointers.
  forward_cell = CrfDecodeForwardRnnCell(transition_params)
  first_step = array_ops.slice(potentials, [0, 0, 0], [-1, 1, -1])
  forward_state = array_ops.squeeze(first_step, axis=[1])           # [B, O]
  later_steps = array_ops.slice(
      potentials, [0, 1, 0], [-1, -1, -1])                          # [B, T-1, O]
  backpointers, last_score = rnn.dynamic_rnn(
      forward_cell,
      inputs=later_steps,
      sequence_length=sequence_length - 1,
      initial_state=forward_state,
      time_major=False,
      dtype=dtypes.int32)                       # [B, T - 1, O], [B, O]
  backpointers = gen_array_ops.reverse_sequence(
      backpointers, sequence_length - 1, seq_dim=1)                 # [B, T-1, O]

  # Backward pass: turns the backpointers into tag indices.
  backward_cell = CrfDecodeBackwardRnnCell(num_tags)
  last_tag = math_ops.cast(math_ops.argmax(last_score, axis=1),
                           dtype=dtypes.int32)                      # [B]
  backward_state = array_ops.expand_dims(last_tag, axis=-1)         # [B, 1]
  decode_tags, _ = rnn.dynamic_rnn(
      backward_cell,
      inputs=backpointers,
      sequence_length=sequence_length - 1,
      initial_state=backward_state,
      time_major=False,
      dtype=dtypes.int32)                                           # [B, T - 1, 1]
  decode_tags = array_ops.squeeze(decode_tags, axis=[2])            # [B, T - 1]
  decode_tags = array_ops.concat(
      [backward_state, decode_tags], axis=1)                        # [B, T]
  decode_tags = gen_array_ops.reverse_sequence(
      decode_tags, sequence_length, seq_dim=1)                      # [B, T]

  best_score = math_ops.reduce_max(last_score, axis=1)              # [B]
  return decode_tags, best_score
开发者ID:SylChan,项目名称:tensorflow,代码行数:55,代码来源:crf.py



注：本文中的 tensorflow.python.ops.array_ops.squeeze 函数示例由纯净天空整理自 Github/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python array_ops.stack函数代码示例发布时间:2022-05-27
下一篇:
Python array_ops.split函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap