• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Python tensorflow.assert_rank函数代码示例

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

本文整理汇总了Python中tensorflow.assert_rank函数的典型用法代码示例。如果您正苦于以下问题:Python assert_rank函数的具体用法?Python assert_rank怎么用?Python assert_rank使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。



在下文中一共展示了assert_rank函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。

示例1: test_model_inputs

def test_model_inputs(model_inputs):
    """Check the tensors returned by ``model_inputs``.

    ``model_inputs`` is called inside a fresh ``tf.Graph`` and must return
    ``(input_data, targets, lr, keep_prob)``. All four must be Placeholder
    ops, ``input_data`` and ``keep_prob`` must carry the expected names, and
    each tensor is passed to ``tf.assert_rank`` with its expected rank.
    A success message is printed once every check passes.
    """
    with tf.Graph().as_default():
        input_data, targets, lr, keep_prob = model_inputs()

        # Check type
        assert input_data.op.type == 'Placeholder',\
            'Input is not a Placeholder.'
        assert targets.op.type == 'Placeholder',\
            'Targets is not a Placeholder.'
        assert lr.op.type == 'Placeholder',\
            'Learning Rate is not a Placeholder.'
        assert keep_prob.op.type == 'Placeholder', \
            'Keep Probability is not a Placeholder.'

        # Check name
        assert input_data.name == 'input:0',\
            'Input has bad name.  Found name {}'.format(input_data.name)
        assert keep_prob.name == 'keep_prob:0', \
            'Keep Probability has bad name.  Found name {}'.format(keep_prob.name)

        # Check rank. tf.assert_rank itself raises ValueError when the static
        # rank is known and wrong; otherwise it only returns an Op (which is
        # truthy). Wrapping the call in `assert` therefore checked nothing
        # extra, and under `python -O` it removed the call, and with it the
        # rank check, entirely. Call it as a plain statement.
        tf.assert_rank(input_data, 2, message='Input data has wrong rank')
        tf.assert_rank(targets, 2, message='Targets has wrong rank')
        tf.assert_rank(lr, 0, message='Learning Rate has wrong rank')
        tf.assert_rank(keep_prob, 0, message='Keep Probability has wrong rank')

    _print_success_message()
开发者ID:DavidWhois,项目名称:English_to_French_translation,代码行数:26,代码来源:problem_unittests.py


示例2: calculate_reshape

def calculate_reshape(original_shape, new_shape, validate=False, name=None):
  """Calculates the reshaped dimensions (replacing up to one -1 in reshape).

  Args:
    original_shape: 1-D integer shape of the tensor being reshaped
      (presumably a `Tensor` or Python list -- TODO confirm).
    new_shape: 1-D integer target shape. An entry of `-1` marks the
      dimension whose size should be inferred.
    validate: Python `bool`. If `True`, and `new_shape` is not statically
      fully defined, assertion ops checking the inputs are built and
      returned.
    name: Optional name for the `tf.name_scope`.

  Returns:
    A tuple `(expanded_new_shape, batch_shape_static, validations)`:
      expanded_new_shape: `new_shape` with any `-1` replaced by the inferred
        size. It is an `np.int32` array when `new_shape` is statically fully
        defined, and a `Tensor` otherwise.
      batch_shape_static: the static shape derived from `new_shape` by
        `tensor_util.constant_value_as_shape`.
      validations: list of assertion ops. It is empty unless `validate` is
        `True` and the static fast path below was not taken.
  """
  # Fast path: every entry of `new_shape` is known statically, so the result
  # is returned as a NumPy value and no graph ops are created.
  batch_shape_static = tensor_util.constant_value_as_shape(new_shape)
  if batch_shape_static.is_fully_defined():
    return np.int32(batch_shape_static.as_list()), batch_shape_static, []
  with tf.name_scope(name, "calculate_reshape", [original_shape, new_shape]):
    original_size = tf.reduce_prod(original_shape)
    # Boolean mask marking the `-1` entry (if any) of `new_shape`.
    implicit_dim = tf.equal(new_shape, -1)
    # With exactly one `-1`, `-reduce_prod(new_shape)` equals the product of
    # the known dimensions. `maximum(1, ...)` keeps the divisor positive when
    # that value is not, e.g. when `new_shape` contains no `-1`.
    size_implicit_dim = (
        original_size // tf.maximum(1, -tf.reduce_prod(new_shape)))
    # For a 1-D `new_shape` this is a 1-element vector holding its length;
    # it sizes the `tf.fill` below so that `tf.where` can select per entry.
    new_ndims = tf.shape(new_shape)
    expanded_new_shape = tf.where(  # Assumes exactly one `-1`.
        implicit_dim, tf.fill(new_ndims, size_implicit_dim), new_shape)
    # Runtime input checks, built only on request. The caller is responsible
    # for running them (e.g. via control dependencies).
    validations = [] if not validate else [
        tf.assert_rank(
            original_shape, 1, message="Original shape must be a vector."),
        tf.assert_rank(new_shape, 1, message="New shape must be a vector."),
        tf.assert_less_equal(
            tf.count_nonzero(implicit_dim, dtype=tf.int32),
            1,
            message="At most one dimension can be unknown."),
        tf.assert_positive(
            expanded_new_shape, message="Shape elements must be >=-1."),
        tf.assert_equal(
            tf.reduce_prod(expanded_new_shape),
            original_size,
            message="Shape sizes do not match."),
    ]
    return expanded_new_shape, batch_shape_static, validations
开发者ID:lewisKit,项目名称:probability,代码行数:29,代码来源:batch_reshape.py


示例3: infer

def infer(encoder_cell, decoder_cell, sentences):
    """Greedily decode a single source sentence.

    The sentence is encoded with `encode` and the `c` part of the final
    encoder state is projected with `bridge`; both reuse existing variables.
    The sentence is then decoded step by step in a `tf.while_loop`: each
    step's arg-max output is turned into a one-hot vector and fed back as
    the next input.

    Decoding stops when the previous output equals the `<eos>` one-hot
    vector (index 0), or after twice the source length.

    Args:
        encoder_cell: RNN cell passed to `encode`.
        decoder_cell: RNN cell used for decoding. Its `output_size` is the
            one-hot depth of each generated step.
        sentences: rank-3 tensor with static shape `[1, T, FEATURE_SIZE]`.
            The batch size of 1 and the feature size are asserted below,
            and `T` must be static too because it is multiplied by 2.

    Returns:
        An int8 tensor stacking the one-hot output of every decoding step
        (presumably `[steps, 1, decoder_cell.output_size]` -- TODO confirm).
    """
    # The returned Op is not used, so no runtime check is wired in. A
    # statically known wrong rank still raises ValueError here.
    tf.assert_rank(sentences, 3)
    assert sentences.get_shape()[0].value == 1  # batch size
    assert sentences.get_shape()[2].value == FEATURE_SIZE

    # Stop generating output once its length reaches twice the source length.
    output_len_threshold = sentences.get_shape()[1].value * 2

    final_state_tuple = encode(sentences, encoder_cell, reuse=True)
    context = bridge(final_state_tuple.c, decoder_cell.output_size, reuse=True)

    with tf.variable_scope('decoder', reuse=True):
        # Keep looping while under the length limit and the previous output
        # differs from the <eos> one-hot vector in at least one element.
        def cond(loop_cnt, prev_out, _, __):
            less = tf.less(loop_cnt, output_len_threshold)
            is_regular_word = tf.reduce_any(
                tf.not_equal(
                    prev_out,
                    tf.one_hot([0], FEATURE_SIZE)  # <eos>
                )
            )

            return tf.logical_and(less, is_regular_word)

        # One decoding step: run the cell, project to logits, and turn the
        # arg-max into a one-hot vector that becomes the next input. The
        # output is also appended (as int8) to the result TensorArray.
        def body(loop_cnt, prev_out, prev_state, result):
            cell_output, state = decoder_cell(prev_out, prev_state)
            num_outputs = decoder_cell.output_size
            output = decoder_projection(
                cell_output,
                num_outputs=num_outputs,
                reuse=True
            )
            arg_max = tf.arg_max(output, dimension=1)
            one_hot_output = tf.one_hot(
                indices=arg_max,
                depth=num_outputs
            )

            return (
                tf.add(loop_cnt, 1),
                one_hot_output,
                state,
                result.write(result.size(), tf.cast(one_hot_output, dtype=tf.int8))
            )

        # The first input is the bridged context. The decoder starts from a
        # zero state for a batch of 1.
        _, __, ___, inferred = tf.while_loop(
            cond,
            body,
            loop_vars=(
                tf.constant(0),
                context,
                decoder_cell.zero_state(batch_size=1, dtype=tf.float32),
                tf.TensorArray(tf.int8, size=0, dynamic_size=True)
            )
        )

        return inferred.stack()
开发者ID:ninotoshi,项目名称:playground,代码行数:56,代码来源:main.py


示例4: decode_for_training

def decode_for_training(cell, final_enc_state, labels):
    """Build the teacher-forced decoder loss used for training.

    The decoder's first input is the `bridge` projection of the final encoder
    state. At every later step, the ground-truth label of the previous step
    is fed instead of the model's own output. Per-step softmax cross-entropy
    losses are summed over time and divided by `time_steps`, the sum of all
    label entries of each example.

    Args:
        cell: decoder RNN cell. Its `output_size` is used as the label depth
            when slicing `labels`.
        final_enc_state: final encoder state passed to `bridge`.
        labels: rank-3 tensor `[batch, max seq len, cell size]`. The time
            dimension must be static because it bounds the loop. Presumably
            each real step is one-hot and padded steps are all zeros (TODO
            confirm); that is what makes `time_steps` the sequence length.

    Returns:
        A tensor of per-example losses (summed loss / `time_steps`).
    """
    # [actual batch size, max seq len, decoder cell size]
    # The returned Op is not used: this only raises when the static rank is
    # known and differs from 3.
    tf.assert_rank(labels, 3)

    cell_size = cell.output_size
    context = bridge(final_enc_state, cell_size)

    # [actual batch size, decoder cell size]
    assert context.get_shape().as_list() == [None, cell_size]

    # tf.shape(labels) is a 1-D tensor with one entry per dimension of
    # `labels`; entry 0 is the dynamic batch size.
    batch_size = tf.shape(labels)[0]  # scalar tf.Tensor
    # Static maximum sequence length; used as the loop bound.
    max_time_step = labels.get_shape()[1].value

    with tf.variable_scope('decoder'):
        def cond(loop_cnt, _, __, ___):
            return tf.less(loop_cnt, max_time_step)

        def body(loop_cnt, prev_label, prev_state, losses):
            cell_output, state = cell(prev_label, prev_state)
            output = decoder_projection(cell_output, cell_size)

            # cut out the `loop_cnt`-th label: [batch_size, cell_size]
            label = tf.reshape(
                tf.slice(labels, begin=[0, loop_cnt, 0], size=[batch_size, 1, cell_size]),
                shape=[batch_size, cell_size]
            )

            # loss for output past the last time step is calculated to be 0
            # (assuming padded label rows are all zeros)
            loss = tf.nn.softmax_cross_entropy_with_logits(
                logits=output,
                labels=label
            )

            return (
                tf.add(loop_cnt, 1),
                # pass the label as the output of the current step
                # (teacher forcing: it becomes the next step's input)
                label,
                state,
                losses.write(loop_cnt, loss)
            )

        _, _, _, result_loss = tf.while_loop(
            cond,
            body,
            loop_vars=(
                tf.constant(0),
                context,
                cell.zero_state(batch_size=batch_size, dtype=tf.float32),
                tf.TensorArray(tf.float32, size=0, dynamic_size=True)
            ),
        )

        # Sum the stacked per-step losses over the time axis (axis 0).
        losses = tf.reduce_sum(result_loss.stack(), axis=0)
        # Number of non-padding steps per example (see `labels` assumption).
        # NOTE(review): an example whose labels are all zero divides by 0.
        time_steps = tf.reduce_sum(tf.reduce_sum(labels, axis=2), axis=1)
        return tf.div(losses, time_steps)
开发者ID:ninotoshi,项目名称:playground,代码行数:56,代码来源:main.py


示例5: bridge

def bridge(final_enc_state, decoder_cell_size, reuse=False):
    """Map the final encoder state to the decoder's initial context.

    Applies a fully connected layer with tanh activation. The layer is
    created in (or, with `reuse=True`, reused from) the 'bridge' variable
    scope.

    Args:
        final_enc_state: rank-2 tensor. `tf.assert_rank` checks this only
            statically, because its returned Op is not used.
        decoder_cell_size: number of output units of the layer.
        reuse: forwarded to `tf.variable_scope`.

    Returns:
        The output tensor of the fully connected layer.
    """
    tf.assert_rank(final_enc_state, 2)

    layer_kwargs = dict(
        inputs=final_enc_state,
        num_outputs=decoder_cell_size,
        activation_fn=tf.tanh,
    )
    with tf.variable_scope('bridge', reuse=reuse):
        return tf.contrib.layers.fully_connected(**layer_kwargs)
开发者ID:ninotoshi,项目名称:playground,代码行数:11,代码来源:main.py


示例6: _check_valid_event_ndims

  def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
    """Validate that `event_ndims` is an integer scalar >= `min_event_ndims`.

    Properties that can be decided statically raise immediately. The others
    are returned as assertion ops, but only when `self.validate_args` is set.

    Args:
      min_event_ndims: lower bound that `event_ndims` must reach.
      event_ndims: integer scalar (Python value or `Tensor`).

    Returns:
      List of assertion ops (possibly empty).

    Raises:
      ValueError: if the dtype is not integer, or if a statically known
        `event_ndims` is non-scalar or smaller than `min_event_ndims`.
    """
    event_ndims = tf.convert_to_tensor(event_ndims, name="event_ndims")
    static_value = tf.contrib.util.constant_value(event_ndims)
    checks = []

    if not event_ndims.dtype.is_integer:
      raise ValueError("Expected integer dtype, got dtype {}".format(
          event_ndims.dtype))

    # Value check: static when the value is known, otherwise a runtime op.
    if static_value is None:
      if self.validate_args:
        checks.append(tf.assert_greater_equal(event_ndims, min_event_ndims))
    else:
      if event_ndims.shape.ndims != 0:
        raise ValueError("Expected scalar event_ndims, got shape {}".format(
            event_ndims.shape))
      if min_event_ndims > static_value:
        raise ValueError("event_ndims ({}) must be larger than "
                         "min_event_ndims ({})".format(static_value,
                                                       min_event_ndims))

    # Rank check: static when the shape is known, otherwise a runtime op.
    if not event_ndims.shape.is_fully_defined():
      if self.validate_args:
        checks.append(
            tf.assert_rank(event_ndims, 0, message="Expected scalar."))
    elif event_ndims.shape.ndims != 0:
      raise ValueError("Expected scalar shape, got ndims {}".format(
          event_ndims.shape.ndims))
    return checks
开发者ID:asudomoeva,项目名称:probability,代码行数:29,代码来源:bijector.py


示例7: test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank

 def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
   """A rank-1 value fed to an unshaped placeholder fails a rank-2 check at run time."""
   with self.test_session():
     placeholder = tf.placeholder(tf.float32, name="my_tensor")
     rank_check = tf.assert_rank(placeholder, 2)
     with tf.control_dependencies([rank_check]):
       with self.assertRaisesOpError("my_tensor.*rank"):
         tf.identity(placeholder).eval(feed_dict={placeholder: [1, 2]})
开发者ID:3kwa,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py


示例8: test_rank_one_tensor_raises_if_rank_too_small_static_rank

 def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
   """A constant of known rank 1 fails a rank-2 check while the graph is built."""
   with self.test_session():
     vector = tf.constant([1, 2], name="my_tensor")
     with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
       rank_check = tf.assert_rank(vector, 2)
       with tf.control_dependencies([rank_check]):
         tf.identity(vector).eval()
开发者ID:3kwa,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py


示例9: _maybe_validate_perm

def _maybe_validate_perm(perm, validate_args, name=None):
  """Validate that `perm` is an integer permutation vector.

  Properties known statically are checked immediately. The rest become
  assertion ops when `validate_args` is True.

  Args:
    perm: integer `Tensor` that should hold a permutation of `0..n-1`.
    validate_args: Python `bool`; whether to emit runtime assertions for
      properties that cannot be checked statically.
    name: optional name for the `tf.name_scope`.

  Returns:
    List of assertion ops (possibly empty).

  Raises:
    TypeError: if `perm` does not have an integer dtype.
    ValueError: if `perm` is statically known to be a non-vector or not a
      permutation.
  """
  with tf.name_scope(name, 'maybe_validate_perm', [perm]):
    if not perm.dtype.is_integer:
      raise TypeError('`perm` must be integer type')
    checks = []

    # Rank: `perm` must be a vector.
    rank_msg = '`perm` must be a vector.'
    static_rank = perm.shape.ndims
    if static_rank is None:
      if validate_args:
        checks.append(tf.assert_rank(perm, 1, message=rank_msg))
    elif static_rank != 1:
      raise ValueError(
          rank_msg[:-1] + ', saw rank: {}.'.format(static_rank))

    # Contents: the sorted entries must equal 0..n-1.
    perm_msg = '`perm` must be a valid permutation vector.'
    static_perm = tf.contrib.util.constant_value(perm)
    if static_perm is None:
      if validate_args:
        checks.append(tf.assert_equal(
            tf.contrib.framework.sort(perm),
            tf.range(tf.size(perm)),
            message=perm_msg))
    elif not np.all(np.arange(np.size(static_perm)) == np.sort(static_perm)):
      raise ValueError(perm_msg[:-1] + ', saw: {}.'.format(static_perm))

    return checks
开发者ID:asudomoeva,项目名称:probability,代码行数:27,代码来源:transpose.py


示例10: _maybe_validate_rightmost_transposed_ndims

def _maybe_validate_rightmost_transposed_ndims(
    rightmost_transposed_ndims, validate_args, name=None):
  """Validate that `rightmost_transposed_ndims` is a non-negative int scalar.

  Properties known statically are checked immediately. The rest become
  assertion ops when `validate_args` is True.

  Args:
    rightmost_transposed_ndims: integer `Tensor` expected to be a scalar.
    validate_args: Python `bool`; whether to emit runtime assertions for
      properties that cannot be checked statically.
    name: optional name for the `tf.name_scope`.

  Returns:
    List of assertion ops (possibly empty).

  Raises:
    TypeError: if the dtype is not integer.
    ValueError: if the value is statically known to be non-scalar or
      negative.
  """
  with tf.name_scope(name, 'maybe_validate_rightmost_transposed_ndims',
                     [rightmost_transposed_ndims]):
    ndims = rightmost_transposed_ndims
    if not ndims.dtype.is_integer:
      raise TypeError('`rightmost_transposed_ndims` must be integer type.')
    checks = []

    # Rank: must be a scalar.
    static_rank = ndims.shape.ndims
    if static_rank is None:
      if validate_args:
        checks.append(tf.assert_rank(ndims, 0))
    elif static_rank != 0:
      raise ValueError('`rightmost_transposed_ndims` must be a scalar, '
                       'saw rank: {}.'.format(static_rank))

    # Value: must be non-negative.
    msg = '`rightmost_transposed_ndims` must be non-negative.'
    static_value = tf.contrib.util.constant_value(ndims)
    if static_value is None:
      if validate_args:
        checks.append(tf.assert_non_negative(ndims, message=msg))
    elif static_value < 0:
      raise ValueError(msg[:-1] + ', saw: {}.'.format(static_value))

    return checks
开发者ID:asudomoeva,项目名称:probability,代码行数:29,代码来源:transpose.py


示例11: test_raises_if_rank_is_not_scalar_dynamic

 def test_raises_if_rank_is_not_scalar_dynamic(self):
   """Feeding a non-scalar value as the desired rank fails at run time."""
   with self.test_session():
     values = tf.constant([1, 2], dtype=tf.float32, name="my_tensor")
     rank_input = tf.placeholder(tf.int32, name="rank_tensor")
     with self.assertRaisesOpError("Rank must be a scalar"):
       rank_check = tf.assert_rank(values, rank_input)
       with tf.control_dependencies([rank_check]):
         tf.identity(values).eval(feed_dict={rank_input: [1, 2]})
开发者ID:3kwa,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py


示例12: test_rank_zero_tensor_raises_if_rank_too_small_static_rank

 def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
     """A scalar constant fails a static rank-1 check; the custom message precedes the tensor name."""
     with self.test_session():
         scalar = tf.constant(1, name="my_tensor")
         expected = "fail.*my_tensor.*must have rank 1"
         with self.assertRaisesRegexp(ValueError, expected):
             rank_check = tf.assert_rank(scalar, 1, message="fail")
             with tf.control_dependencies([rank_check]):
                 tf.identity(scalar).eval()
开发者ID:BloodD,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py


示例13: test_raises_if_rank_is_not_integer_dynamic

 def test_raises_if_rank_is_not_integer_dynamic(self):
     """A float placeholder used as the desired rank is rejected with a TypeError."""
     with self.test_session():
         values = tf.constant([1, 2], dtype=tf.float32, name="my_tensor")
         float_rank = tf.placeholder(tf.float32, name="rank_tensor")
         with self.assertRaisesRegexp(TypeError,
                                      "must be of type <dtype: 'int32'>"):
             rank_check = tf.assert_rank(values, float_rank)
             with tf.control_dependencies([rank_check]):
                 tf.identity(values).eval(feed_dict={float_rank: 0.5})
开发者ID:BloodD,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py


示例14: _assert_tensor_shape

def _assert_tensor_shape(tensor, shape, display_name):
    """Check that ``tensor`` has rank ``len(shape)`` and matching dimensions.

    Entries of ``shape`` that are ``None`` are not checked.

    Args:
        tensor: the TensorFlow tensor to check.
        shape: expected shape as a sequence of ints / ``None``.
        display_name: human-readable name used in failure messages.

    Raises:
        ValueError: from ``tf.assert_rank`` when the static rank is known
            and wrong.
        AssertionError: when a checked dimension differs.
    """
    # tf.assert_rank performs the static check itself and otherwise only
    # returns an Op (which is truthy). Wrapping it in `assert` added nothing
    # and made the rank check disappear under `python -O`, so call it as a
    # plain statement.
    tf.assert_rank(tensor, len(shape), message='{} has wrong rank'.format(display_name))

    tensor_shape = tensor.get_shape().as_list() if len(shape) else []

    wrong_dimension = [ten_dim for ten_dim, cor_dim in zip(tensor_shape, shape)
                       if cor_dim is not None and ten_dim != cor_dim]
    assert not wrong_dimension, \
        '{} has wrong shape.  Found {}'.format(display_name, tensor_shape)
开发者ID:HarshSharma12,项目名称:CarND-Semantic-Segmentation,代码行数:9,代码来源:project_tests.py


示例15: encode

def encode(inputs, cell, reuse=False):
    """Run `cell` over `inputs` and return the RNN's final state.

    Each time step is first embedded with a fully connected layer (sigmoid
    activation, width `cell.output_size`) in the 'encoder' variable scope.
    The embedded sequence is then fed to `tf.nn.dynamic_rnn`.

    Args:
        inputs: rank-3 tensor, presumably `[batch, time, features]` with
            one-hot rows for real steps and all-zero rows for padding (TODO
            confirm). The per-example sequence length passed to
            `dynamic_rnn` is the sum of all of that example's entries.
        cell: RNN cell. Its `output_size` sets the embedding width.
        reuse: forwarded to `tf.variable_scope('encoder', ...)`.

    Returns:
        The final state returned by `tf.nn.dynamic_rnn`.
    """
    # The returned Op is not used: this only raises on a known wrong rank.
    tf.assert_rank(inputs, 3)

    per_step_sum = tf.reduce_sum(inputs, axis=2)
    time_steps = tf.reduce_sum(per_step_sum, axis=1)

    with tf.variable_scope('encoder', reuse=reuse):
        embedded = tf.contrib.layers.fully_connected(
            inputs=inputs,
            num_outputs=cell.output_size,
            activation_fn=tf.sigmoid
        )
        tf.assert_rank(embedded, 3)

        rnn_result = tf.nn.dynamic_rnn(
            cell,
            embedded,
            sequence_length=time_steps,
            dtype=tf.float32,
        )
        return rnn_result[1]
开发者ID:ninotoshi,项目名称:playground,代码行数:22,代码来源:main.py


示例16: op

def op(name,
       images,
       max_outputs=3,
       display_name=None,
       description=None,
       collections=None):
  """Create an image summary op for use in a TensorFlow graph.

  Arguments:
    name: A unique name for the generated summary node.
    images: A `uint8` `Tensor` representing pixel data with shape
      `[k, h, w, c]`, where `k` is the number of images, `h` and `w` are
      the height and width of the images, and `c` is the number of
      channels, which should be 1, 3, or 4. Any of the dimensions may be
      statically unknown (i.e., `None`).
    max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this
      many images will be emitted at each step. When more than
      `max_outputs` many images are provided, the first `max_outputs` many
      images will be used and the rest silently discarded.
    display_name: Optional name for this summary in TensorBoard, as a
      constant `str`. Defaults to `name`.
    description: Optional long-form description for this summary, as a
      constant `str`. Markdown is supported. Defaults to empty.
    collections: Optional list of graph collections keys. The new
      summary op is added to these collections. Defaults to
      `[Graph Keys.SUMMARIES]`.

  Returns:
    A TensorFlow summary op. Its tensor is a string vector holding the
    image width and height (as strings) followed by one PNG-encoded image
    per emitted image.
  """
  if display_name is None:
    display_name = name
  summary_metadata = metadata.create_summary_metadata(
      display_name=display_name, description=description)
  with tf.name_scope(name), \
       tf.control_dependencies([tf.assert_rank(images, 4),
                                tf.assert_type(images, tf.uint8),
                                tf.assert_non_negative(max_outputs)]):
    limited_images = images[:max_outputs]
    encoded_images = tf.map_fn(tf.image.encode_png, limited_images,
                               dtype=tf.string,
                               name='encode_each_image')
    # Each image is handed to tf.image.encode_png, which takes
    # `[height, width, channels]`, so axis 1 is the height and axis 2 the
    # width. The stored order is width first, then height.
    image_shape = tf.shape(images)
    height = tf.as_string(image_shape[1], name='height')
    width = tf.as_string(image_shape[2], name='width')
    dimensions = tf.stack([width, height], name='dimensions')
    tensor = tf.concat([dimensions, encoded_images], axis=0)
    return tf.summary.tensor_summary(name='image_summary',
                                     tensor=tensor,
                                     collections=collections,
                                     summary_metadata=summary_metadata)
开发者ID:jlewi,项目名称:tensorboard,代码行数:51,代码来源:summary.py


示例17: _maybe_validate_shape_override

  def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                     validate_args, name):
    """Validate an overriding batch/event shape for `__init__`.

    Statically detectable violations raise immediately. The others become
    runtime assertions when `validate_args` is set.

    Args:
      override_shape: the override shape (`None` means empty). It is
        converted to an int32 `Tensor` named `name`.
      base_is_scalar: whether the base distribution is scalar (Python value
        or `Tensor`).
      validate_args: Python `bool`; whether to add runtime assertions.
      name: name for the converted tensor.

    Returns:
      `self._empty` when the override is statically known to be scalar.
      Otherwise, the converted override shape, wrapped with control
      dependencies on any runtime assertions.

    Raises:
      TypeError: if the converted override is not integer. This is a guard
        only: the conversion forces int32.
      ValueError: if a statically known property is violated.
    """
    shape_tensor = tf.convert_to_tensor(
        [] if override_shape is None else override_shape,
        dtype=tf.int32, name=name)
    if not shape_tensor.dtype.is_integer:
      raise TypeError("shape override must be an integer")

    override_is_scalar = _is_scalar_from_shape(shape_tensor)
    if tf.contrib.util.constant_value(override_is_scalar):
      return self._empty

    checks = []

    # The override must be a vector.
    if shape_tensor.shape.ndims is None:
      if validate_args:
        checks.append(tf.assert_rank(
            shape_tensor, 1,
            message="shape override must be a vector"))
    elif shape_tensor.shape.ndims != 1:
      raise ValueError("shape override must be a vector")

    # Every entry must be positive.
    static_shape = tf.contrib.util.constant_value(shape_tensor)
    if static_shape is None:
      if validate_args:
        checks.append(tf.assert_positive(
            shape_tensor,
            message="shape override must have positive elements"))
    elif any(dim <= 0 for dim in static_shape):
      raise ValueError("shape override must have positive elements")

    # Overriding is only allowed when the base distribution is scalar.
    both_nonscalar = _logical_and(_logical_not(base_is_scalar),
                                  _logical_not(override_is_scalar))
    static_both = tf.contrib.util.constant_value(both_nonscalar)
    if static_both is None:
      if validate_args:
        checks.append(tf.assert_equal(
            both_nonscalar, False,
            message="base distribution not scalar"))
    elif static_both:
      raise ValueError("base distribution not scalar")

    if checks:
      return control_flow_ops.with_dependencies(checks, shape_tensor)
    return shape_tensor
开发者ID:asudomoeva,项目名称:probability,代码行数:48,代码来源:transformed_distribution.py


示例18: _assert_non_negative_int32_scalar

 def _assert_non_negative_int32_scalar(self, x):
   """Ensure that `x` is a non-negative int32 scalar.

   When both the rank and the value of `x` are statically known, they are
   checked immediately. Otherwise, if `self.validate_args` is set, `x` is
   returned with control dependencies on runtime rank-0 and non-negativity
   assertions.

   Raises:
     TypeError: if `x` is not int32.
     ValueError: if `x` is statically known to be non-scalar or negative.
   """
   x = tf.convert_to_tensor(x, name="x")
   if x.dtype.base_dtype != tf.int32.base_dtype:
     raise TypeError("%s.dtype=%s is not %s" % (x.name, x.dtype, tf.int32))

   static_value = tensor_util.constant_value(x)
   static_ndims = x.get_shape().ndims
   if static_ndims is None or static_value is None:
     # Not fully static: fall back to (optional) runtime checks.
     if self.validate_args:
       x = control_flow_ops.with_dependencies(
           [tf.assert_rank(x, 0), tf.assert_non_negative(x)], x)
     return x

   if static_ndims != 0:
     raise ValueError("%s.ndims=%d is not 0 (scalar)" %
                      (x.name, static_ndims))
   if static_value < 0:
     raise ValueError("%s.value=%d cannot be negative" %
                      (x.name, static_value))
   return x
开发者ID:lewisKit,项目名称:probability,代码行数:19,代码来源:shape.py


示例19: test_model_inputs

def test_model_inputs(model_inputs):
    """Check the tensors returned by ``model_inputs``.

    ``model_inputs`` is called inside a fresh ``tf.Graph`` and must return
    ``(input_data, targets, lr, keep_prob, target_sequence_length,
    max_target_sequence_length, source_sequence_length)``. Every tensor
    except ``max_target_sequence_length`` must be a Placeholder, and that
    one must be produced by a ``Max`` op. Selected tensors must carry the
    expected names, and each tensor is passed to ``tf.assert_rank`` with its
    expected rank. A success message is printed once every check passes.
    """
    with tf.Graph().as_default():
        (input_data, targets, lr, keep_prob, target_sequence_length,
         max_target_sequence_length, source_sequence_length) = model_inputs()

        # Check op type
        assert input_data.op.type == 'Placeholder',\
            'Input is not a Placeholder.'
        assert targets.op.type == 'Placeholder',\
            'Targets is not a Placeholder.'
        assert lr.op.type == 'Placeholder',\
            'Learning Rate is not a Placeholder.'
        assert keep_prob.op.type == 'Placeholder', \
            'Keep Probability is not a Placeholder.'
        assert target_sequence_length.op.type == 'Placeholder', \
            'Target Sequence Length is not a Placeholder.'
        # This tensor is derived from target_sequence_length, not fed, so
        # its expected op type is 'Max'; the message now says so.
        assert max_target_sequence_length.op.type == 'Max', \
            'Max Target Sequence Length is not a Max op.'
        assert source_sequence_length.op.type == 'Placeholder', \
            'Source Sequence Length is not a Placeholder.'

        # Check name
        assert input_data.name == 'input:0',\
            'Input has bad name.  Found name {}'.format(input_data.name)
        assert target_sequence_length.name == 'target_sequence_length:0',\
            'Target Sequence Length has bad name.  Found name {}'.format(target_sequence_length.name)
        assert source_sequence_length.name == 'source_sequence_length:0',\
            'Source Sequence Length has bad name.  Found name {}'.format(source_sequence_length.name)
        assert keep_prob.name == 'keep_prob:0', \
            'Keep Probability has bad name.  Found name {}'.format(keep_prob.name)

        # Check rank. tf.assert_rank itself raises ValueError when the static
        # rank is known and wrong; otherwise it only returns an Op (which is
        # truthy). Wrapping the call in `assert` added nothing and made the
        # rank checks disappear under `python -O`, so call it directly.
        tf.assert_rank(input_data, 2, message='Input data has wrong rank')
        tf.assert_rank(targets, 2, message='Targets has wrong rank')
        tf.assert_rank(lr, 0, message='Learning Rate has wrong rank')
        tf.assert_rank(keep_prob, 0, message='Keep Probability has wrong rank')
        tf.assert_rank(target_sequence_length, 1, message='Target Sequence Length has wrong rank')
        tf.assert_rank(max_target_sequence_length, 0, message='Max Target Sequence Length has wrong rank')
        tf.assert_rank(source_sequence_length, 1, message='Source Sequence Length has wrong rank')

    _print_success_message()
开发者ID:3man1992,项目名称:Deep_Learning_RNN_Language_Translator,代码行数:39,代码来源:problem_unittests.py


示例20: op

def op(name,
       audio,
       sample_rate,
       labels=None,
       max_outputs=3,
       encoding=None,
       display_name=None,
       description=None,
       collections=None):
  """Create an audio summary op for use in a TensorFlow graph.

  Arguments:
    name: A unique name for the generated summary node.
    audio: A `Tensor` of audio data with shape `[k, t, c]` (`k` clips,
      `t` frames, `c` channels). Elements should be floating-point values
      in `[-1.0, 1.0]`. Any dimension may be statically unknown (`None`).
    sample_rate: An `int` or rank-0 `int32` `Tensor` giving the sample
      rate in Hz. Must be positive.
    labels: Optional `string` `Tensor`. It is a vector as long as the
      first dimension of `audio`, where `labels[i]` holds arbitrary
      UTF-8 Markdown text about `audio[i]`, e.g. the text a TTS system was
      supposed to produce.
    max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this
      many clips are emitted per step. Clips beyond the first
      `max_outputs` are silently discarded.
    encoding: A constant `str` (not a string tensor) naming the desired
      encoding. The only accepted value is "wav". See the API
      compatibility note below.
    display_name: Optional constant `str` shown in TensorBoard. Defaults
      to `name`.
    description: Optional constant `str` holding a long-form Markdown
      description. Defaults to empty.
    collections: Optional list of graph collection keys to which the
      summary op is added. Defaults to `[Graph Keys.SUMMARIES]`.

  Returns:
    A TensorFlow summary op whose tensor has one `[encoded_audio, label]`
    row per emitted clip.

  Raises:
    ValueError: if `encoding` is anything other than "wav".

  API compatibility note: the default `encoding` is _not_ guaranteed to
  stay the same across TensorBoard versions; a future version may encode
  as FLAC instead of WAV by default. If the file format matters to you,
  pass it explicitly.
  """
  if display_name is None:
    display_name = name
  if encoding is None:
    encoding = 'wav'
  if encoding != 'wav':
    raise ValueError('Unknown encoding: %r' % encoding)

  # Only WAV is supported: record it as the metadata enum value and encode
  # every clip with ffmpeg at the requested sample rate.
  encoding = metadata.Encoding.Value('WAV')
  encoder = functools.partial(tf.contrib.ffmpeg.encode_audio,
                              samples_per_second=sample_rate,
                              file_format='wav')

  with tf.name_scope(name), \
       tf.control_dependencies([tf.assert_rank(audio, 3)]):
    clips = audio[:max_outputs]
    encoded_clips = tf.map_fn(encoder, clips,
                              dtype=tf.string,
                              name='encode_each_audio')
    if labels is not None:
      clip_labels = labels[:max_outputs]
    else:
      # One empty label per emitted clip.
      clip_labels = tf.tile([''], tf.shape(clips)[:1])
    # Stack to [2, k] and transpose to one [encoded_audio, label] row per clip.
    tensor = tf.transpose(tf.stack([encoded_clips, clip_labels]))
    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name,
        description=description,
        encoding=encoding)
    return tf.summary.tensor_summary(name='audio_summary',
                                     tensor=tensor,
                                     collections=collections,
                                     summary_metadata=summary_metadata)
开发者ID:jlewi,项目名称:tensorboard,代码行数:82,代码来源:summary.py



注:本文中的tensorflow.assert_rank函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。


鲜花

握手

雷人

路过

鸡蛋
该文章已有0人参与评论

请发表评论

全部评论

专题导读
上一篇:
Python tensorflow.assert_rank_at_least函数代码示例发布时间:2022-05-27
下一篇:
Python tensorflow.assert_positive函数代码示例发布时间:2022-05-27
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap