• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python array_ops.scatter_nd函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.python.ops.array_ops.scatter_nd函数的典型用法代码示例。如果您正苦于以下问题：Python scatter_nd函数的具体用法？Python scatter_nd怎么用？Python scatter_nd使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了scatter_nd函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: __call__

  def __call__(self, shape, dtype=None, partition_info=None):
    """Scatter a scaled QR factor at the centre index of a `shape` tensor.

    Args:
      shape: Shape of the tensor to build. It must have 3, 4 or 5 entries and
        `shape[-2]` must not exceed `shape[-1]`.
      dtype: Element type of the result; `self.dtype` is used when `None`.
      partition_info: Accepted but not read by this method.

    Returns:
      The result of `array_ops.scatter_nd` for `shape`, with the scaled factor
      written at index `(shape[i] - 1) // 2` along every leading axis.

    Raises:
      ValueError: If `shape` has fewer than 3 or more than 5 entries, or if
        `shape[-2] > shape[-1]`.
    """
    dtype = self.dtype if dtype is None else dtype

    rank = len(shape)
    if not 3 <= rank <= 5:
      raise ValueError("The tensor to initialize must be at least "
                       "three-dimensional and at most five-dimensional")

    in_filters, out_filters = shape[-2], shape[-1]
    if in_filters > out_filters:
      raise ValueError("In_filters cannot be greater than out_filters.")

    # QR-factorize a random square matrix of side `out_filters`.
    noise = random_ops.random_normal(
        [out_filters, out_filters], dtype=dtype, seed=self.seed)
    q, r = linalg_ops.qr(noise, full_matrices=False)
    # Multiply Q by the sign of R's diagonal, then keep the first
    # `in_filters` rows and apply the sqrt(gain) scale.
    q *= math_ops.sign(array_ops.diag_part(r))
    q = q[:in_filters, :]
    q *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype))

    # One centre coordinate per leading axis (1, 2 or 3 of them).
    centre = [(shape[axis] - 1) // 2 for axis in range(rank - 2)]
    return array_ops.scatter_nd([centre], array_ops.expand_dims(q, 0), shape)
开发者ID:moses-sun,项目名称:tensorflow,代码行数:32,代码来源:init_ops.py


示例2: testInvalidShape

 def testInvalidShape(self):
   """Check that scatter_nd rejects a rank-1 `indices` argument."""
   # TODO(apassos) figure out how to unify these errors
   if context.executing_eagerly():
     expected_error = errors.InvalidArgumentError
   else:
     expected_error = ValueError

   with self.assertRaises(expected_error):
     # `indices` is deliberately malformed; the valid form is [[0]].
     array_ops.scatter_nd(indices=[0], updates=[0.0], shape=[1])
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:7,代码来源:scatter_nd_ops_test.py


示例3: testEmptyOutputShape1

  def testEmptyOutputShape1(self):
    """Check that scattering into an output shape with a 0 dim raises."""
    zero_indices = array_ops.zeros([2, 2, 2], dtypes.int32)
    zero_updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    empty_shape = constant_op.constant([0, 3, 2], dtypes.int32)
    expected_message = "Indices and updates specified for empty output shape"

    with self.assertRaisesWithPredicateMatch(ValueError, expected_message):
      array_ops.scatter_nd(zero_indices, zero_updates, empty_shape)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:8,代码来源:scatter_nd_ops_test.py


示例4: _ctc_state_trans

def _ctc_state_trans(label_seq):
  """Compute CTC alignment model transition matrix.

  The state space built here has `2 * (num_labels + 1)` states, where
  `num_labels` is the second dimension of `label_seq`: the first half are
  "label" states and the second half are "blank" states.

  Args:
    label_seq: tensor of shape [batch_size, max_seq_length]

  Returns:
    tensor of shape [batch_size, states, states] with a state transition matrix
    computed for each sequence of the batch.
  """

  with ops.name_scope("ctc_state_trans"):
    label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
    batch_size = _get_dim(label_seq, 0)
    num_labels = _get_dim(label_seq, 1)

    num_label_states = num_labels + 1
    num_states = 2 * num_label_states

    # Label states are numbered 0..num_labels; blank states follow them,
    # offset by num_label_states.
    label_states = math_ops.range(num_label_states)
    blank_states = label_states + num_label_states

    # Start state to first label.
    # NOTE(review): index pairs appear to be [to_state, from_state] -- confirm.
    start_to_label = [[1, 0]]

    # Blank to label transitions.
    blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1)

    # Label to blank transitions.
    label_to_blank = array_ops.stack([blank_states, label_states], 1)

    # Scatter transitions that don't depend on sequence.
    indices = array_ops.concat(
        [start_to_label, blank_to_label, label_to_blank], 0)
    # Every listed transition gets the value 1.
    values = array_ops.ones([_get_dim(indices, 0)])
    trans = array_ops.scatter_nd(
        indices, values, shape=[num_states, num_states])
    trans += linalg_ops.eye(num_states)  # Self-loops.

    # Label to label transitions. Disallow transitions between repeated labels
    # with no blank state in between.
    # Build one [0, label k, label k-1] triple per adjacent label pair ...
    batch_idx = array_ops.zeros_like(label_states[2:])
    indices = array_ops.stack(
        [batch_idx, label_states[2:], label_states[1:-1]], 1)
    # ... replicate the triples for every batch element ...
    indices = array_ops.tile(
        array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
    # ... and add the batch number to the first column only (the [1, 0, 0]
    # multiplier zeroes the other two columns).
    batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
    indices += array_ops.expand_dims(batch_idx, 1)
    # Value is 0.0 where two consecutive labels are equal, 1.0 otherwise.
    repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
    values = 1.0 - math_ops.cast(repeats, dtypes.float32)
    batched_shape = [batch_size, num_states, num_states]
    label_to_label = array_ops.scatter_nd(indices, values, batched_shape)

    # Broadcast the sequence-independent matrix over the batch and add the
    # per-sequence label-to-label entries.
    return array_ops.expand_dims(trans, 0) + label_to_label
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:54,代码来源:ctc_ops.py


示例5: testEmptyOutputShape2

  def testEmptyOutputShape2(self):
    """Evaluate scatter_nd with fed inputs and an output shape with a 0 dim.

    The test makes no assertion on the result; it only runs the op.
    """
    indices = array_ops.placeholder(dtypes.int32, shape=None)
    updates = array_ops.placeholder(dtypes.int32, shape=None)
    shape = constant_op.constant([0, 3, 2], dtypes.int32)

    fed_indices = np.zeros([2, 2, 2], dtype=np.int32)
    fed_updates = np.zeros([2, 2, 2], dtype=np.int32)

    with self.test_session():
      scattered = array_ops.scatter_nd(indices, updates, shape)
      scattered.eval(feed_dict={indices: fed_indices, updates: fed_updates})
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:12,代码来源:scatter_nd_ops_test.py


示例6: testRank3InvalidShape2

  def testRank3InvalidShape2(self):
    """Mismatched inner dimensions are rejected by both scatter variants."""
    idx = array_ops.zeros([2, 2, 1], dtypes.int32)
    upd = array_ops.zeros([2, 2], dtypes.int32)
    target_shape = np.array([2, 2, 2])

    # Functional scatter: the message refers to `output.shape`.
    output_error = self.assertRaisesWithPredicateMatch(
        ValueError, "The inner \\d+ dimensions of output\\.shape=")
    with output_error:
      array_ops.scatter_nd(idx, upd, target_shape)

    # Variable update: the message refers to `ref.shape`.
    var = variables.Variable(array_ops.zeros(target_shape, dtypes.int32))
    ref_error = self.assertRaisesWithPredicateMatch(
        ValueError, "The inner \\d+ dimensions of ref\\.shape=")
    with ref_error:
      state_ops.scatter_nd_update(var, idx, upd)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:12,代码来源:scatter_nd_ops_test.py


示例7: collapse_repeated

def collapse_repeated(labels, seq_length, name=None):
  """Collapse runs of identical consecutive labels into a single label.

  Args:
    labels: Tensor of shape [batch, max value in seq_length]
    seq_length: Tensor of shape [batch], sequence length of each batch element.
    name: A name for this `Op`. Defaults to "collapse_repeated_labels".

  Returns:
    A tuple `(collapsed_labels, new_seq_length)` where

    collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated
    labels collapsed and padded to max_seq_length, eg:
    `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]`

    new_seq_length: int tensor of shape [batch] with new sequence lengths,
    cast to the dtype of `seq_length`.
  """

  with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]):
    labels = ops.convert_to_tensor(labels, name="labels")
    seq_length = ops.convert_to_tensor(seq_length, name="seq_length")

    # Keep a position when it is the first column or differs from the
    # label immediately before it.
    first_column = array_ops.ones_like(labels[:, :1], dtypes.bool)
    differs_from_prev = math_ops.not_equal(labels[:, 1:], labels[:, :-1])
    keep = array_ops.concat([first_column, differs_from_prev], axis=1)

    # Discard positions that lie beyond each sequence's length.
    in_sequence = array_ops.sequence_mask(
        seq_length, maxlen=_get_dim(labels, 1))
    keep = math_ops.logical_and(keep, in_sequence)

    # Number of surviving labels per batch element.
    collapsed_len = math_ops.reduce_sum(
        math_ops.cast(keep, dtypes.int32), axis=1)

    # Output slots that will be filled, given the collapsed lengths.
    widest = math_ops.reduce_max(collapsed_len)
    slot_mask = array_ops.sequence_mask(collapsed_len, maxlen=widest)

    # Work on flattened views so one scatter moves all kept labels.
    flat_labels = array_ops.reshape(labels, [-1])
    flat_keep = array_ops.reshape(keep, [-1])
    flat_slots = array_ops.reshape(slot_mask, [-1])
    positions = math_ops.range(_get_dim(flat_slots, 0))

    # Write the kept labels into the filled slots of the flat output.
    target_idx = array_ops.expand_dims(
        array_ops.boolean_mask(positions, flat_slots), axis=1)
    kept_labels = array_ops.boolean_mask(flat_labels, flat_keep)
    flat = array_ops.scatter_nd(
        indices=target_idx,
        updates=kept_labels,
        shape=array_ops.shape(flat_slots))

    # Restore the [batch, widest] layout.
    collapsed = array_ops.reshape(flat, [_get_dim(labels, 0), widest])
    return collapsed, math_ops.cast(collapsed_len, seq_length.dtype)
开发者ID:aritratony,项目名称:tensorflow,代码行数:60,代码来源:ctc_ops.py


示例8: _state_to_olabel_unique

def _state_to_olabel_unique(labels, num_labels, states, unique):
  """Sum state log probs to olabel log probs using unique label indices.

  Args:
    labels: tensor whose second dimension gives the number of label positions
      (only that dimension is read here).
    num_labels: size of the label axis of the scattered output.
    states: tensor indexed as [frame, batch, state]; the last axis is split
      into label states and blank states.
    unique: pair `(unique_y, unique_idx)`.
      NOTE(review): presumably the per-batch result of a unique op over
      `labels` -- confirm against the caller.

  Returns:
    The concatenation, along the last axis, of the blank log prob (one entry)
    followed by the per-label log probs with label index 0 dropped.
  """

  num_label_states = _get_dim(labels, 1) + 1
  # State 0 is skipped; states 1..num_label_states-1 are label states and the
  # remainder are blank states.
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]

  unique_y, unique_idx = unique
  mul_reduce = _sum_states(unique_idx, label_states)

  num_frames = states.shape[0]
  batch_size = states.shape[1]
  num_states = num_label_states - 1
  # Rearrange to [batch, state, frame] and merge batch and state so each row
  # can be scattered independently.
  batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
  batch_state_major = array_ops.reshape(
      batch_state_major, [batch_size * num_states, num_frames])
  # Offset each batch element's label ids into its own block of num_labels
  # rows of the flattened output.
  batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
  indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
  indices = array_ops.reshape(indices, [-1, 1])
  scatter = array_ops.scatter_nd(
      indices=indices,
      updates=batch_state_major,
      shape=[batch_size * num_labels, num_frames])
  scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])
  # Entries equal to 0.0 are replaced by log(0.0); this includes entries that
  # received no update.
  scatter = array_ops.where(
      math_ops.equal(scatter, 0.0),
      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)),
      scatter)
  # Back to [frame, batch, label], then drop label index 0.
  label_olabels = array_ops.transpose(scatter, [2, 0, 1])
  label_olabels = label_olabels[:, :, 1:]

  # All blank states collapse into a single blank output via logsumexp.
  blank_olabels = math_ops.reduce_logsumexp(
      blank_states, axis=2, keepdims=True)

  return array_ops.concat([blank_olabels, label_olabels], axis=-1)
开发者ID:adit-chandra,项目名称:tensorflow,代码行数:35,代码来源:ctc_ops.py


示例9: testScatterNdRepatedIndicesAdd

 def testScatterNdRepatedIndicesAdd(self):
   """Updates that all target index 0 must add up to their total."""
   num_updates = 100000
   repeated_index = array_ops.zeros([num_updates, 1], dtypes.int32)
   samples = np.random.randn(num_updates)
   with self.test_session():
     scattered = array_ops.scatter_nd(repeated_index, samples, [1]).eval()
   expected = [np.sum(samples)]
   self.assertAllClose(expected, scattered)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:7,代码来源:scatter_nd_ops_test.py


示例10: maybe_sample

 def maybe_sample():
   """Combine embedded sampled ids with the untouched base inputs.

   Positions whose `sample_ids` entry is greater than -1 take the embedding
   of that id; all other positions keep the value from `base_next_inputs`.
   """
   # Index lists for the two disjoint groups of positions.
   sampled_at = math_ops.cast(
       array_ops.where(sample_ids > -1), dtypes.int32)
   kept_at = math_ops.cast(
       array_ops.where(sample_ids <= -1), dtypes.int32)

   picked_ids = array_ops.gather_nd(sample_ids, sampled_at)
   kept_inputs = array_ops.gather_nd(base_next_inputs, kept_at)
   embedded = self._embedding_fn(picked_ids)

   # Each scatter fills only its own positions, so the sum merges them.
   out_shape = array_ops.shape(base_next_inputs)
   from_samples = array_ops.scatter_nd(
       indices=sampled_at, updates=embedded, shape=out_shape)
   from_base = array_ops.scatter_nd(
       indices=kept_at, updates=kept_inputs, shape=out_shape)
   return from_samples + from_base
开发者ID:AnddyWang,项目名称:tensorflow,代码行数:17,代码来源:helper.py


示例11: testEmptyOutputShape3

  def testEmptyOutputShape3(self):
    """Empty indices, updates and shape give a result with no elements."""
    empty_indices = array_ops.zeros([0], dtypes.int32)
    empty_updates = array_ops.zeros([0], dtypes.int32)
    empty_shape = constant_op.constant([0], dtypes.int32)
    result = array_ops.scatter_nd(empty_indices, empty_updates, empty_shape)

    with self.test_session():
      element_count = result.eval().size
      self.assertEqual(element_count, 0)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:8,代码来源:scatter_nd_ops_test.py


示例12: _apply_sparse_shared

 def _apply_sparse_shared(self, grad_values, grad_indices, var):
   """Apply a sparse gradient to `var`, densifying it first when allowed.

   Args:
     grad_values: Values of the sparse gradient.
     grad_indices: Indices along the first axis of `var` that the values
       belong to.
     var: The variable being updated.

   Returns:
     Whatever `self._apply_gradient` returns for the chosen form.
   """
   densify = (var.get_shape()[0] <= self._max_matrix_size
              or self._gbar_decay != 0.0)
   if not densify:
     return self._apply_gradient(grad_values, var, grad_indices)

   # Expand the sparse gradient to the full shape of `var` and do a dense
   # update.
   index_column = array_ops.expand_dims(grad_indices, axis=1)
   var_shape = array_ops.shape(var, out_type=grad_indices.dtype)
   dense_grad = array_ops.scatter_nd(index_column, grad_values, var_shape)
   return self._apply_gradient(dense_grad, var)
开发者ID:ZhangXinNan,项目名称:tensorflow,代码行数:9,代码来源:shampoo.py


示例13: _runScatterNd

 def _runScatterNd(self, indices, updates, shape):
   """Evaluate scatter_nd under `self.test_scope()` with fed inputs.

   Args:
     indices: Array fed as the scatter indices; its dtype sets the
       placeholder dtype.
     updates: Array fed as the scatter updates; its dtype sets the
       placeholder dtype.
     shape: Output shape passed straight to `scatter_nd`.

   Returns:
     The evaluated scatter result.
   """
   with self.test_session():
     updates_ph = array_ops.placeholder(updates.dtype)
     indices_ph = array_ops.placeholder(indices.dtype)
     # Only the op under test is built inside the test scope.
     with self.test_scope():
       scattered = array_ops.scatter_nd(indices_ph, updates_ph, shape)
     return scattered.eval(
         feed_dict={updates_ph: updates, indices_ph: indices})
开发者ID:AndrewTwinz,项目名称:tensorflow,代码行数:9,代码来源:scatter_nd_op_test.py


示例14: _GatherNdGrad

def _GatherNdGrad(op, grad):
  """Return gradients for the two inputs of the op (`ref` and `indices`).

  Args:
    op: The op being differentiated; `op.inputs[0]` is the tensor gathered
      from and `op.inputs[1]` holds the indices.
    grad: Gradient with respect to the op's output.

  Returns:
    `[ref_grad, None]`. `ref_grad` is an `IndexedSlices` when `indices` is a
    rank-2 tensor with a single column, and a `scatter_nd` result otherwise.
    The indices input gets no gradient.
  """
  ref, indices = op.inputs[0], op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)

  selects_rows = (
      indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1)
  if selects_rows:
    # Each index picks a slice along axis 0, so use the sparse form.
    row_ids = array_ops.squeeze(indices, axis=-1)
    return [ops.IndexedSlices(grad, row_ids, ref_shape), None]

  return [array_ops.scatter_nd(indices, grad, ref_shape), None]
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:10,代码来源:array_grad.py


示例15: maybe_sample

      def maybe_sample():
        """Mix sampled outputs with the base inputs according to `sample_ids`.

        Without a next-input layer the outputs are used as-is where
        `sample_ids` is true. Otherwise the selected outputs go through
        `self._next_input_layer` before being merged back.
        """
        next_input_layer = self._next_input_layer
        if next_input_layer is None:
          return array_ops.where(sample_ids, outputs, base_next_inputs)

        # Index lists for sampled and non-sampled positions.
        sampled_at = math_ops.cast(
            array_ops.where(sample_ids), dtypes.int32)
        kept_at = math_ops.cast(
            array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)

        picked_outputs = array_ops.gather_nd(outputs, sampled_at)
        kept_inputs = array_ops.gather_nd(base_next_inputs, kept_at)
        projected = next_input_layer(picked_outputs)

        # Each scatter fills only its own positions, so the sum merges them.
        out_shape = array_ops.shape(base_next_inputs)
        from_samples = array_ops.scatter_nd(
            indices=sampled_at, updates=projected, shape=out_shape)
        from_base = array_ops.scatter_nd(
            indices=kept_at, updates=kept_inputs, shape=out_shape)
        return from_samples + from_base
开发者ID:Immexxx,项目名称:tensorflow,代码行数:20,代码来源:helper.py


示例16: testGradientsRank2SliceUpdate

  def testGradientsRank2SliceUpdate(self):
    """Gradient w.r.t. `updates` picks upstream rows in index order."""
    row_indices = constant_op.constant([[1], [0]], dtype=dtypes.int32)
    slices = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64)
    out_shape = constant_op.constant([2, 2], dtype=dtypes.int32)
    scattered = array_ops.scatter_nd(row_indices, slices, out_shape)

    upstream = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64)
    all_grads = gradients_impl.gradients([scattered], [slices], [upstream])
    slices_grad = all_grads[0]

    # Indices [[1], [0]] swap the two upstream rows.
    expected = np.array([[1, 2], [3, 4]], dtype=np.float64)
    with self.test_session():
      self.assertAllEqual(expected, slices_grad.eval())
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:11,代码来源:scatter_nd_ops_test.py


示例17: scheduled_sampling

  def scheduled_sampling(self, batch_size, sampling_probability, true, estimate):
    """Mix `true` and `estimate` per batch entry by a Bernoulli draw.

    Args:
      batch_size: Number of entries to draw a sampling decision for.
      sampling_probability: Probability passed to the Bernoulli sampler;
        a true draw selects `estimate` for that entry.
      true: Tensor supplying the entries that are not sampled. Its shape is
        the shape of the result.
      estimate: Tensor supplying the sampled entries.
        NOTE(review): presumably the same shape as `true` -- confirm.

    Returns:
      A tensor with the shape of `true`, holding `estimate` values at sampled
      positions and `true` values elsewhere.
    """
    with variable_scope.variable_scope("ScheduledEmbedding"):
      # Return -1s where we do not sample, and sample_ids elsewhere
      select_sampler = bernoulli.Bernoulli(probs=sampling_probability, dtype=tf.bool)
      select_sample = select_sampler.sample(sample_shape=batch_size)
      sample_ids = array_ops.where(
                  select_sample,
                  tf.range(batch_size),
                  gen_array_ops.fill([batch_size], -1))
      where_sampling = math_ops.cast(
          array_ops.where(sample_ids > -1), tf.int32)
      where_not_sampling = math_ops.cast(
          array_ops.where(sample_ids <= -1), tf.int32)
      _estimate = array_ops.gather_nd(estimate, where_sampling)
      _true = array_ops.gather_nd(true, where_not_sampling)

      # Each scatter fills only its own positions, so adding them merges the
      # two sources. The sum is built once; the original built it twice and
      # discarded the first copy.
      base_shape = array_ops.shape(true)
      result1 = array_ops.scatter_nd(indices=where_sampling, updates=_estimate, shape=base_shape)
      result2 = array_ops.scatter_nd(indices=where_not_sampling, updates=_true, shape=base_shape)
      return result1 + result2
开发者ID:sra4077,项目名称:RLSeq2Seq,代码行数:21,代码来源:run_summarization.py


示例18: testRank3ValidShape

  def testRank3ValidShape(self):
    """Both scatter variants report the requested static output shape."""
    idx = array_ops.zeros([2, 2, 2], dtypes.int32)
    upd = array_ops.zeros([2, 2, 2], dtypes.int32)
    target_shape = np.array([2, 2, 2])

    scattered = array_ops.scatter_nd(idx, upd, target_shape)
    self.assertAllEqual(scattered.get_shape().as_list(), target_shape)

    var = variables.Variable(array_ops.zeros(target_shape, dtypes.int32))
    updated = state_ops.scatter_nd_update(var, idx, upd)
    self.assertAllEqual(updated.get_shape().as_list(), target_shape)
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:12,代码来源:scatter_nd_ops_test.py


示例19: maybe_sample

      def maybe_sample():
        """Mix sampled outputs (plus auxiliary inputs) with the base inputs."""

        def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
          """Append the next step's auxiliary inputs to `outputs_`, if any.

          When `indices` is given, the auxiliary inputs are first gathered
          at those indices.
          """
          aux_tas = self._auxiliary_input_tas
          if aux_tas is None:
            return outputs_

          next_time = time + 1
          aux_inputs = nest.map_structure(
              lambda ta: ta.read(next_time), aux_tas)
          if indices is not None:
            aux_inputs = array_ops.gather_nd(aux_inputs, indices)
          return nest.map_structure(
              lambda out, aux: array_ops.concat((out, aux), -1),
              outputs_, aux_inputs)

        next_input_layer = self._next_input_layer
        if next_input_layer is None:
          augmented_outputs = maybe_concatenate_auxiliary_inputs(outputs)
          return array_ops.where(
              sample_ids, augmented_outputs, base_next_inputs)

        # Index lists for sampled and non-sampled positions.
        sampled_at = math_ops.cast(
            array_ops.where(sample_ids), dtypes.int32)
        kept_at = math_ops.cast(
            array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)

        picked_outputs = array_ops.gather_nd(outputs, sampled_at)
        kept_inputs = array_ops.gather_nd(base_next_inputs, kept_at)
        projected = maybe_concatenate_auxiliary_inputs(
            next_input_layer(picked_outputs), sampled_at)

        # Each scatter fills only its own positions, so the sum merges them.
        out_shape = array_ops.shape(base_next_inputs)
        from_samples = array_ops.scatter_nd(
            indices=sampled_at, updates=projected, shape=out_shape)
        from_base = array_ops.scatter_nd(
            indices=kept_at, updates=kept_inputs, shape=out_shape)
        return from_samples + from_base
开发者ID:AlbertXiebnu,项目名称:tensorflow,代码行数:39,代码来源:helper.py


示例20: testExtraIndicesDimensions

  def testExtraIndicesDimensions(self):
    """Indices with an extra leading dimension work in both scatter variants."""
    idx = array_ops.zeros([1, 1, 2], dtypes.int32)
    upd = array_ops.zeros([1, 1], dtypes.int32)
    target_shape = np.array([2, 2])

    # Functional scatter: check the static shape, then the values.
    scattered = array_ops.scatter_nd(idx, upd, target_shape)
    self.assertAllEqual(scattered.get_shape().as_list(), target_shape)
    all_zeros = np.zeros([2, 2], dtype=np.int32)
    with self.test_session():
      self.assertAllEqual(all_zeros, scattered.eval())

    # Variable update: same checks after initializing the variable.
    var = variables.Variable(array_ops.zeros(target_shape, dtypes.int32))
    updated = state_ops.scatter_nd_update(var, idx, upd)
    self.assertAllEqual(updated.get_shape().as_list(), target_shape)

    with self.test_session():
      var.initializer.run()
      self.assertAllEqual(all_zeros, updated.eval())
开发者ID:AliMiraftab,项目名称:tensorflow,代码行数:17,代码来源:scatter_nd_ops_test.py



注：本文中的tensorflow.python.ops.array_ops.scatter_nd函数示例由纯净天空整理自GitHub/MSDocs等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的License；未经允许，请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python array_ops.sequence_mask函数代码示例发布时间:2022-05-27
下一篇:
Python array_ops.reverse_v2函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap