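本文整理汇总了 Python 中 tensorflow.assert_rank 函数的典型用法代码示例。如果您想了解 assert_rank 函数的具体用法、调用方式或使用示例，这里精选的代码示例或许可以为您提供帮助。
在下文中一共展示了 assert_rank 函数的 20 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。注意：在图模式下，如果张量的静态秩已知且与期望值不符，tf.assert_rank 会在构建图时直接抛出 ValueError（见示例 8、12）；否则它只返回一个断言操作，必须通过 tf.control_dependencies 等方式实际执行才会进行检查（见示例 7、11）。因此，丢弃其返回值或将其放在 Python 的 assert 语句中，都不会产生运行时检查。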
示例1: test_model_inputs
def test_model_inputs(model_inputs):
    """Check the tensors returned by the student's ``model_inputs()``.

    ``model_inputs`` is called inside a fresh default graph and must return
    ``(input_data, targets, lr, keep_prob)``.  Each must be a placeholder,
    ``input_data`` and ``keep_prob`` must carry the expected names, and the
    four tensors must have ranks 2, 2, 0 and 0 respectively.

    Raises:
        AssertionError: on a wrong op type or tensor name.
        ValueError: from ``tf.assert_rank`` when a tensor's static rank is
            known and differs from the expected rank.
    """
    with tf.Graph().as_default():
        input_data, targets, lr, keep_prob = model_inputs()

        # Check type
        assert input_data.op.type == 'Placeholder',\
            'Input is not a Placeholder.'
        assert targets.op.type == 'Placeholder',\
            'Targets is not a Placeholder.'
        assert lr.op.type == 'Placeholder',\
            'Learning Rate is not a Placeholder.'
        assert keep_prob.op.type == 'Placeholder', \
            'Keep Probability is not a Placeholder.'

        # Check name
        assert input_data.name == 'input:0',\
            'Input has bad name. Found name {}'.format(input_data.name)
        assert keep_prob.name == 'keep_prob:0', \
            'Keep Probability has bad name. Found name {}'.format(keep_prob.name)

        # Check rank.  tf.assert_rank raises ValueError itself when the static
        # rank is known and wrong.  Otherwise, in graph mode, it returns an op,
        # which is always truthy.  Wrapping it in `assert` therefore checked
        # nothing extra, and under `python -O` it removed the call entirely.
        # NOTE: if a static rank is unknown, the returned op is never run
        # here, so that tensor's rank is effectively not checked.
        tf.assert_rank(input_data, 2, message='Input data has wrong rank')
        tf.assert_rank(targets, 2, message='Targets has wrong rank')
        tf.assert_rank(lr, 0, message='Learning Rate has wrong rank')
        tf.assert_rank(keep_prob, 0, message='Keep Probability has wrong rank')

    _print_success_message()
开发者ID:DavidWhois,项目名称:English_to_French_translation,代码行数:26,代码来源:problem_unittests.py
示例2: calculate_reshape
def calculate_reshape(original_shape, new_shape, validate=False, name=None):
  """Calculates the reshaped dimensions (replacing up to one -1 in reshape).

  Args:
    original_shape: shape of the tensor being reshaped, as an integer
      vector (asserted to be rank 1 only when `validate=True`).
    new_shape: requested shape; may contain a single `-1` whose size is
      inferred from `original_shape`.
    validate: if True, also build assertion ops checking the arguments.
    name: optional name for the `tf.name_scope`.

  Returns:
    A tuple `(expanded_new_shape, batch_shape_static, validations)`:
      expanded_new_shape: `new_shape` with the `-1` filled in.  When the
        statically inferred shape is fully defined this is an `np.int32`
        array and no ops are built; otherwise it is a tensor.
      batch_shape_static: the `TensorShape` statically inferred from
        `new_shape`.
      validations: list of assertion ops.  It is empty unless `validate` is
        True and the shape was not fully defined statically.  The ops are
        returned, not attached, so the caller must wire them in.
  """
  batch_shape_static = tensor_util.constant_value_as_shape(new_shape)
  if batch_shape_static.is_fully_defined():
    # Fast path: everything is known statically, so no graph ops are needed.
    return np.int32(batch_shape_static.as_list()), batch_shape_static, []
  with tf.name_scope(name, "calculate_reshape", [original_shape, new_shape]):
    original_size = tf.reduce_prod(original_shape)
    implicit_dim = tf.equal(new_shape, -1)
    # With exactly one -1 in `new_shape`, -prod(new_shape) is the product of
    # the explicit dims, so this floor-division yields the inferred size.
    # With no -1 the negated product is <= 0 and `maximum` clamps the divisor
    # to 1.  The result is then unused because `implicit_dim` is all False.
    size_implicit_dim = (
        original_size // tf.maximum(1, -tf.reduce_prod(new_shape)))
    new_ndims = tf.shape(new_shape)
    # Replace the -1 entry (if any) with the inferred size.
    expanded_new_shape = tf.where(  # Assumes exactly one `-1`.
        implicit_dim, tf.fill(new_ndims, size_implicit_dim), new_shape)
    # Runtime checks are only built on request.
    validations = [] if not validate else [
        tf.assert_rank(
            original_shape, 1, message="Original shape must be a vector."),
        tf.assert_rank(new_shape, 1, message="New shape must be a vector."),
        tf.assert_less_equal(
            tf.count_nonzero(implicit_dim, dtype=tf.int32),
            1,
            message="At most one dimension can be unknown."),
        # After substitution every element must be > 0.  A remaining -1
        # (e.g. from two -1s) or any other negative or zero entry fails here.
        tf.assert_positive(
            expanded_new_shape, message="Shape elements must be >=-1."),
        tf.assert_equal(
            tf.reduce_prod(expanded_new_shape),
            original_size,
            message="Shape sizes do not match."),
    ]
    return expanded_new_shape, batch_shape_static, validations
开发者ID:lewisKit,项目名称:probability,代码行数:29,代码来源:batch_reshape.py
示例3: infer
def infer(encoder_cell, decoder_cell, sentences):
    """Greedily decode a single source sentence.

    The sentence is encoded, bridged into the decoder's space, and decoded
    step by step.  Each step's arg-max output is turned into a one-hot
    vector and fed back as the next input.  Decoding stops after twice the
    source length, or once the previous output equals the one-hot
    ``<eos>`` vector (index 0).

    All variables are looked up with ``reuse=True``, so they must already
    exist (presumably created when the training graph was built).

    Args:
        encoder_cell: encoder RNN cell.
        decoder_cell: decoder RNN cell.
        sentences: rank-3 tensor whose static shape is
            ``[1, source length, FEATURE_SIZE]``.

    Returns:
        An int8 tensor stacking the one-hot output of every decoding step.
    """
    tf.assert_rank(sentences, 3)
    assert sentences.get_shape()[0].value == 1  # batch size
    assert sentences.get_shape()[2].value == FEATURE_SIZE

    # stops generating output if the length reaches the double of the source
    output_len_threshold = sentences.get_shape()[1].value * 2

    final_state_tuple = encode(sentences, encoder_cell, reuse=True)
    context = bridge(final_state_tuple.c, decoder_cell.output_size, reuse=True)

    with tf.variable_scope('decoder', reuse=True):
        def cond(loop_cnt, prev_out, _, __):
            less = tf.less(loop_cnt, output_len_threshold)
            # True unless the previous output is exactly the <eos> one-hot.
            is_regular_word = tf.reduce_any(
                tf.not_equal(
                    prev_out,
                    tf.one_hot([0], FEATURE_SIZE)  # <eos>
                )
            )
            return tf.logical_and(less, is_regular_word)

        def body(loop_cnt, prev_out, prev_state, result):
            cell_output, state = decoder_cell(prev_out, prev_state)
            num_outputs = decoder_cell.output_size
            output = decoder_projection(
                cell_output,
                num_outputs=num_outputs,
                reuse=True
            )
            # `tf.arg_max` is a deprecated alias of `tf.argmax`, and `axis`
            # replaces the deprecated `dimension` keyword.
            arg_max = tf.argmax(output, axis=1)
            one_hot_output = tf.one_hot(
                indices=arg_max,
                depth=num_outputs
            )
            return (
                tf.add(loop_cnt, 1),
                one_hot_output,
                state,
                result.write(result.size(), tf.cast(one_hot_output, dtype=tf.int8))
            )

        _, __, ___, inferred = tf.while_loop(
            cond,
            body,
            loop_vars=(
                tf.constant(0),
                context,
                decoder_cell.zero_state(batch_size=1, dtype=tf.float32),
                tf.TensorArray(tf.int8, size=0, dynamic_size=True)
            )
        )
    return inferred.stack()
开发者ID:ninotoshi,项目名称:playground,代码行数:56,代码来源:main.py
示例4: decode_for_training
def decode_for_training(cell, final_enc_state, labels):
    """Build the teacher-forced decoder loss for a batch of label sequences.

    The first decoder input is ``bridge(final_enc_state, cell.output_size)``.
    At every later step the ground-truth label of the current step, not the
    decoder's own prediction, is passed on as the next input.

    Args:
        cell: decoder RNN cell.  Labels are sliced to ``cell.output_size``
            features, so presumably the label feature size equals it.
        final_enc_state: final encoder state, passed to ``bridge``.
        labels: rank-3 tensor ``[batch, max seq len, features]``.  The
            time dimension must be statically known.

    Returns:
        Per-example loss: the step-wise softmax cross-entropies summed over
        time, divided by ``time_steps`` (the sum of all label entries of
        that example).  If labels are one-hot rows with all-zero padding
        (assumption; confirm against the data pipeline), that sum is the
        number of real steps.  An example with no non-zero label entries
        would divide by zero.
    """
    # [actual batch size, max seq len, decoder cell size]
    tf.assert_rank(labels, 3)
    cell_size = cell.output_size
    context = bridge(final_enc_state, cell_size)
    # [actual batch size, decoder cell size]
    assert context.get_shape().as_list() == [None, cell_size]
    # tf.shape(labels) is a 1-D tensor holding the dynamic shape of labels;
    # element 0 is the run-time batch size.
    batch_size = tf.shape(labels)[0]  # type: tf.Tensor of rank 0
    max_time_step = labels.get_shape()[1].value
    with tf.variable_scope('decoder'):
        def cond(loop_cnt, _, __, ___):
            # Run exactly max_time_step iterations.
            return tf.less(loop_cnt, max_time_step)

        def body(loop_cnt, prev_label, prev_state, losses):
            cell_output, state = cell(prev_label, prev_state)
            output = decoder_projection(cell_output, cell_size)
            # cut out the `loop_cnt`-th label
            label = tf.reshape(
                tf.slice(labels, begin=[0, loop_cnt, 0], size=[batch_size, 1, cell_size]),
                shape=[batch_size, cell_size]
            )
            # loss for output past the last time step is calculated to be 0
            # (presumably because padded label rows are all zeros, which makes
            # the cross-entropy term vanish)
            loss = tf.nn.softmax_cross_entropy_with_logits(
                logits=output,
                labels=label
            )
            return (
                tf.add(loop_cnt, 1),
                # pass the label as the output of the current step
                label,
                state,
                losses.write(loop_cnt, loss)
            )

        _, _, _, result_loss = tf.while_loop(
            cond,
            body,
            loop_vars=(
                tf.constant(0),
                context,
                cell.zero_state(batch_size=batch_size, dtype=tf.float32),
                tf.TensorArray(tf.float32, size=0, dynamic_size=True)
            ),
        )
    # The stacked losses are [time, batch]; summing over axis 0 gives each
    # example's total loss.
    losses = tf.reduce_sum(result_loss.stack(), axis=0)
    time_steps = tf.reduce_sum(tf.reduce_sum(labels, axis=2), axis=1)
    return tf.div(losses, time_steps)
开发者ID:ninotoshi,项目名称:playground,代码行数:56,代码来源:main.py
示例5: bridge
def bridge(final_enc_state, decoder_cell_size, reuse=False):
    """Map the encoder's final state into the decoder's cell-size space.

    A single fully connected layer with ``tanh`` activation, living in the
    ``'bridge'`` variable scope.

    Args:
        final_enc_state: rank-2 tensor (presumably ``[batch, encoder size]``).
        decoder_cell_size: number of output units.
        reuse: reuse the existing ``'bridge'`` variables instead of creating them.

    Returns:
        The projected tensor.
    """
    # Only a static check: the returned op is not wired into the graph.
    tf.assert_rank(final_enc_state, 2)
    with tf.variable_scope('bridge', reuse=reuse):
        return tf.contrib.layers.fully_connected(
            final_enc_state, decoder_cell_size, activation_fn=tf.tanh)
开发者ID:ninotoshi,项目名称:playground,代码行数:11,代码来源:main.py
示例6: _check_valid_event_ndims
def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
  """Validate that `event_ndims` is an integer scalar >= `min_event_ndims`.

  Problems that are visible statically raise immediately.  Problems that
  can only be detected at run time are returned as assertion ops, and only
  when `self.validate_args` is set.

  Args:
    min_event_ndims: Python int lower bound (equality is allowed).
    event_ndims: integer scalar, as a Python value or a tensor.

  Returns:
    A (possibly empty) list of assertion ops.

  Raises:
    ValueError: if `event_ndims` has a non-integer dtype, or is statically
      known to be non-scalar or smaller than `min_event_ndims`.
  """
  event_ndims = tf.convert_to_tensor(event_ndims, name="event_ndims")
  static_value = tf.contrib.util.constant_value(event_ndims)
  if not event_ndims.dtype.is_integer:
    raise ValueError("Expected integer dtype, got dtype {}".format(
        event_ndims.dtype))

  runtime_checks = []

  # Lower-bound check: static when the value is known, else a runtime op.
  if static_value is None:
    if self.validate_args:
      runtime_checks.append(
          tf.assert_greater_equal(event_ndims, min_event_ndims))
  else:
    if event_ndims.shape.ndims != 0:
      raise ValueError("Expected scalar event_ndims, got shape {}".format(
          event_ndims.shape))
    if min_event_ndims > static_value:
      raise ValueError("event_ndims ({}) must be larger than "
                       "min_event_ndims ({})".format(static_value,
                                                     min_event_ndims))

  # Scalar check: static when the shape is known, else a runtime op.
  if not event_ndims.shape.is_fully_defined():
    if self.validate_args:
      runtime_checks.append(
          tf.assert_rank(event_ndims, 0, message="Expected scalar."))
  elif event_ndims.shape.ndims != 0:
    raise ValueError("Expected scalar shape, got ndims {}".format(
        event_ndims.shape.ndims))

  return runtime_checks
开发者ID:asudomoeva,项目名称:probability,代码行数:29,代码来源:bijector.py
示例7: test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank
def test_rank_one_tensor_raises_if_rank_too_small_dynamic_rank(self):
  """Feeding a vector to an unknown-rank placeholder fails assert_rank(2) at run time."""
  with self.test_session():
    placeholder = tf.placeholder(tf.float32, name="my_tensor")
    # Rank is unknown while building the graph, so this is a runtime op.
    rank_check = tf.assert_rank(placeholder, 2)
    with tf.control_dependencies([rank_check]):
      with self.assertRaisesOpError("my_tensor.*rank"):
        guarded = tf.identity(placeholder)
        guarded.eval(feed_dict={placeholder: [1, 2]})
开发者ID:3kwa,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py
示例8: test_rank_one_tensor_raises_if_rank_too_small_static_rank
def test_rank_one_tensor_raises_if_rank_too_small_static_rank(self):
  """A constant of static rank 1 fails assert_rank(2) while the graph is built."""
  with self.test_session():
    vector = tf.constant([1, 2], name="my_tensor")
    # The ValueError is raised when the assertion is created, so its
    # construction must sit inside the assertRaisesRegexp context.
    with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
      rank_check = tf.assert_rank(vector, 2)
      with tf.control_dependencies([rank_check]):
        tf.identity(vector).eval()
开发者ID:3kwa,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py
示例9: _maybe_validate_perm
def _maybe_validate_perm(perm, validate_args, name=None):
  """Checks that `perm` is valid.

  `perm` must be an integer vector holding a permutation of
  `0, ..., size(perm) - 1`.  Whatever can be decided statically raises
  immediately.  Otherwise, when `validate_args` is True, runtime assertion
  ops are returned.

  Args:
    perm: integer `Tensor`.
    validate_args: whether to build runtime assertions for properties that
      cannot be checked statically.
    name: optional name scope.

  Returns:
    List of assertion ops (possibly empty).

  Raises:
    TypeError: if `perm` does not have an integer dtype.
    ValueError: if `perm` is statically known not to be a vector, or not
      to be a permutation.
  """
  with tf.name_scope(name, 'maybe_validate_perm', [perm]):
    if not perm.dtype.is_integer:
      raise TypeError('`perm` must be integer type')
    checks = []

    rank_msg = '`perm` must be a vector.'
    static_rank = perm.shape.ndims
    if static_rank is None:
      if validate_args:
        checks.append(tf.assert_rank(perm, 1, message=rank_msg))
    elif static_rank != 1:
      raise ValueError(
          rank_msg[:-1] + ', saw rank: {}.'.format(static_rank))

    perm_msg = '`perm` must be a valid permutation vector.'
    static_perm = tf.contrib.util.constant_value(perm)
    if static_perm is None:
      if validate_args:
        # A permutation sorts to exactly 0..n-1.
        checks.append(tf.assert_equal(
            tf.contrib.framework.sort(perm),
            tf.range(tf.size(perm)),
            message=perm_msg))
    elif not np.all(np.arange(np.size(static_perm)) == np.sort(static_perm)):
      raise ValueError(perm_msg[:-1] + ', saw: {}.'.format(static_perm))

    return checks
开发者ID:asudomoeva,项目名称:probability,代码行数:27,代码来源:transpose.py
示例10: _maybe_validate_rightmost_transposed_ndims
def _maybe_validate_rightmost_transposed_ndims(
    rightmost_transposed_ndims, validate_args, name=None):
  """Checks that `rightmost_transposed_ndims` is valid.

  It must be a non-negative integer scalar.  Statically decidable problems
  raise immediately.  Otherwise, when `validate_args` is True, runtime
  assertion ops are returned.

  Args:
    rightmost_transposed_ndims: integer `Tensor`.
    validate_args: whether to build runtime assertions.
    name: optional name scope.

  Returns:
    List of assertion ops (possibly empty).

  Raises:
    TypeError: if the dtype is not integer.
    ValueError: if it is statically known to be non-scalar or negative.
  """
  with tf.name_scope(name, 'maybe_validate_rightmost_transposed_ndims',
                     [rightmost_transposed_ndims]):
    ndims = rightmost_transposed_ndims
    if not ndims.dtype.is_integer:
      raise TypeError('`rightmost_transposed_ndims` must be integer type.')
    checks = []

    static_rank = ndims.shape.ndims
    if static_rank is None:
      if validate_args:
        checks.append(tf.assert_rank(ndims, 0))
    elif static_rank != 0:
      raise ValueError('`rightmost_transposed_ndims` must be a scalar, '
                       'saw rank: {}.'.format(static_rank))

    msg = '`rightmost_transposed_ndims` must be non-negative.'
    static_value = tf.contrib.util.constant_value(ndims)
    if static_value is None:
      if validate_args:
        checks.append(tf.assert_non_negative(ndims, message=msg))
    elif static_value < 0:
      raise ValueError(msg[:-1] + ', saw: {}.'.format(static_value))

    return checks
开发者ID:asudomoeva,项目名称:probability,代码行数:29,代码来源:transpose.py
示例11: test_raises_if_rank_is_not_scalar_dynamic
def test_raises_if_rank_is_not_scalar_dynamic(self):
  """Feeding a non-scalar value as the desired rank fails at run time."""
  with self.test_session():
    values = tf.constant([1, 2], dtype=tf.float32, name="my_tensor")
    wanted_rank = tf.placeholder(tf.int32, name="rank_tensor")
    with self.assertRaisesOpError("Rank must be a scalar"):
      rank_check = tf.assert_rank(values, wanted_rank)
      with tf.control_dependencies([rank_check]):
        tf.identity(values).eval(feed_dict={wanted_rank: [1, 2]})
开发者ID:3kwa,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py
示例12: test_rank_zero_tensor_raises_if_rank_too_small_static_rank
def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
  """A scalar constant fails assert_rank(1) statically; the error starts with the custom message."""
  with self.test_session():
    scalar = tf.constant(1, name="my_tensor")
    with self.assertRaisesRegexp(ValueError, "fail.*my_tensor.*must have rank 1"):
      rank_check = tf.assert_rank(scalar, 1, message="fail")
      with tf.control_dependencies([rank_check]):
        tf.identity(scalar).eval()
开发者ID:BloodD,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py
示例13: test_raises_if_rank_is_not_integer_dynamic
def test_raises_if_rank_is_not_integer_dynamic(self):
  """A float placeholder used as the desired rank is rejected with TypeError."""
  with self.test_session():
    values = tf.constant([1, 2], dtype=tf.float32, name="my_tensor")
    float_rank = tf.placeholder(tf.float32, name="rank_tensor")
    with self.assertRaisesRegexp(TypeError, "must be of type <dtype: 'int32'>"):
      rank_check = tf.assert_rank(values, float_rank)
      with tf.control_dependencies([rank_check]):
        tf.identity(values).eval(feed_dict={float_rank: 0.5})
开发者ID:BloodD,项目名称:tensorflow,代码行数:7,代码来源:check_ops_test.py
示例14: _assert_tensor_shape
def _assert_tensor_shape(tensor, shape, display_name):
    """Check ``tensor`` against an expected ``shape`` (``None`` = wildcard).

    Args:
        tensor: the tensor under test.
        shape: expected shape as a sequence; ``None`` entries match any size.
        display_name: human-readable name used in failure messages.

    Raises:
        ValueError: from ``tf.assert_rank`` when the static rank is known and
            differs from ``len(shape)``.
        AssertionError: if a dimension differs from a non-``None`` entry of
            ``shape``.
    """
    # tf.assert_rank raises on a known, wrong static rank.  In graph mode its
    # return value is an op, which is always truthy, so an `assert` around it
    # checked nothing extra.  Under `python -O` it also removed the call.
    tf.assert_rank(tensor, len(shape), message='{} has wrong rank'.format(display_name))

    tensor_shape = tensor.get_shape().as_list() if len(shape) else []

    wrong_dimension = [ten_dim for ten_dim, cor_dim in zip(tensor_shape, shape)
                       if cor_dim is not None and ten_dim != cor_dim]
    assert not wrong_dimension, \
        '{} has wrong shape. Found {}'.format(display_name, tensor_shape)
开发者ID:HarshSharma12,项目名称:CarND-Semantic-Segmentation,代码行数:9,代码来源:project_tests.py
示例15: encode
def encode(inputs, cell, reuse=False):
    """Run the encoder RNN over ``inputs`` and return its final state.

    Inputs are first projected to ``cell.output_size`` by a sigmoid fully
    connected layer and then fed to ``tf.nn.dynamic_rnn``.  Both run in the
    ``'encoder'`` variable scope.  Each example's sequence length is the
    sum of all its input entries.  That equals the number of real steps
    when each step is a one-hot vector and padding is all-zero (assumption;
    confirm against the data pipeline).

    Args:
        inputs: rank-3 tensor, presumably ``[batch, time, features]``.
        cell: the encoder RNN cell.
        reuse: reuse the existing ``'encoder'`` variables.

    Returns:
        The final state returned by ``tf.nn.dynamic_rnn``.
    """
    tf.assert_rank(inputs, 3)
    lengths = tf.reduce_sum(tf.reduce_sum(inputs, axis=2), axis=1)
    with tf.variable_scope('encoder', reuse=reuse):
        projected = tf.contrib.layers.fully_connected(
            inputs, cell.output_size, activation_fn=tf.sigmoid)
        tf.assert_rank(projected, 3)
        _, final_state = tf.nn.dynamic_rnn(
            cell, projected, sequence_length=lengths, dtype=tf.float32)
    return final_state
开发者ID:ninotoshi,项目名称:playground,代码行数:22,代码来源:main.py
示例16: op
def op(name,
       images,
       max_outputs=3,
       display_name=None,
       description=None,
       collections=None):
  """Create an image summary op for use in a TensorFlow graph.

  Arguments:
    name: A unique name for the generated summary node.
    images: A `Tensor` representing pixel data with shape `[k, h, w, c]`,
      where `k` is the number of images, `h` and `w` are the height and
      width of the images, and `c` is the number of channels, which
      should be 1, 3, or 4. Any of the dimensions may be statically
      unknown (i.e., `None`).
    max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this
      many images will be emitted at each step. When more than
      `max_outputs` many images are provided, the first `max_outputs` many
      images will be used and the rest silently discarded.
    display_name: Optional name for this summary in TensorBoard, as a
      constant `str`. Defaults to `name`.
    description: Optional long-form description for this summary, as a
      constant `str`. Markdown is supported. Defaults to empty.
    collections: Optional list of graph collections keys. The new
      summary op is added to these collections. Defaults to
      `[Graph Keys.SUMMARIES]`.

  Returns:
    A TensorFlow summary op.  Its tensor is a string vector
    `[width, height, png_1, ..., png_n]` with `n <= max_outputs`.
  """
  if display_name is None:
    display_name = name
  summary_metadata = metadata.create_summary_metadata(
      display_name=display_name, description=description)
  with tf.name_scope(name), \
       tf.control_dependencies([tf.assert_rank(images, 4),
                                tf.assert_type(images, tf.uint8),
                                tf.assert_non_negative(max_outputs)]):
    limited_images = images[:max_outputs]
    encoded_images = tf.map_fn(tf.image.encode_png, limited_images,
                               dtype=tf.string,
                               name='encode_each_image')
    image_shape = tf.shape(images)
    # Each image is encoded by tf.image.encode_png, which takes
    # [height, width, channels].  So axis 1 is the height and axis 2 is the
    # width; they were previously stored swapped.
    dimensions = tf.stack([tf.as_string(image_shape[2], name='width'),
                           tf.as_string(image_shape[1], name='height')],
                          name='dimensions')
    tensor = tf.concat([dimensions, encoded_images], axis=0)
    return tf.summary.tensor_summary(name='image_summary',
                                     tensor=tensor,
                                     collections=collections,
                                     summary_metadata=summary_metadata)
开发者ID:jlewi,项目名称:tensorboard,代码行数:51,代码来源:summary.py
示例17: _maybe_validate_shape_override
def _maybe_validate_shape_override(self, override_shape, base_is_scalar,
                                   validate_args, name):
  """Helper to __init__ which ensures override batch/event_shape are valid.

  Args:
    override_shape: shape override (anything convertible to an int32
      tensor).  `None` is treated as `[]`.
    base_is_scalar: static or tensor boolean saying whether the base
      distribution is scalar.
    validate_args: build runtime assertions for properties that cannot be
      checked statically.
    name: name used when converting `override_shape` to a tensor.

  Returns:
    `self._empty` when the override is statically known to be scalar.
    Otherwise the int32 override tensor, with control dependencies on any
    runtime assertions that were built.

  Raises:
    TypeError: if the converted override is not integer-typed (presumably
      unreachable, since conversion requests int32).
    ValueError: if the override is statically known not to be a vector, to
      have non-positive elements, or to be combined with a non-scalar base
      distribution.
  """
  shape = tf.convert_to_tensor(
      [] if override_shape is None else override_shape,
      dtype=tf.int32, name=name)
  if not shape.dtype.is_integer:
    raise TypeError("shape override must be an integer")

  override_is_scalar = _is_scalar_from_shape(shape)
  if tf.contrib.util.constant_value(override_is_scalar):
    return self._empty

  checks = []

  # The override must be a vector.
  if shape.shape.ndims is None:
    if validate_args:
      checks.append(tf.assert_rank(
          shape, 1,
          message="shape override must be a vector"))
  elif shape.shape.ndims != 1:
    raise ValueError("shape override must be a vector")

  # Every entry must be positive.
  static_shape = tf.contrib.util.constant_value(shape)
  if static_shape is None:
    if validate_args:
      checks.append(tf.assert_positive(
          shape,
          message="shape override must have positive elements"))
  elif any(s <= 0 for s in static_shape):
    raise ValueError("shape override must have positive elements")

  # A non-scalar override requires a scalar base distribution.
  is_both_nonscalar = _logical_and(_logical_not(base_is_scalar),
                                   _logical_not(override_is_scalar))
  static_both = tf.contrib.util.constant_value(is_both_nonscalar)
  if static_both is None:
    if validate_args:
      checks.append(tf.assert_equal(
          is_both_nonscalar, False,
          message="base distribution not scalar"))
  elif static_both:
    raise ValueError("base distribution not scalar")

  if checks:
    return control_flow_ops.with_dependencies(checks, shape)
  return shape
开发者ID:asudomoeva,项目名称:probability,代码行数:48,代码来源:transformed_distribution.py
示例18: _assert_non_negative_int32_scalar
def _assert_non_negative_int32_scalar(self, x):
  """Helper which ensures that input is a non-negative, int32, scalar.

  Static information is checked right away:
    * a non-int32 dtype raises `TypeError`;
    * a statically known rank other than 0 raises `ValueError`;
    * a statically known negative value raises `ValueError`.
  If the value is not known statically and `self.validate_args` is True,
  `x` is returned with runtime rank and non-negativity assertions attached.

  Args:
    x: value convertible to an int32 tensor.

  Returns:
    `x` as a tensor, possibly wrapped in control dependencies.
  """
  x = tf.convert_to_tensor(x, name="x")
  if x.dtype.base_dtype != tf.int32.base_dtype:
    raise TypeError("%s.dtype=%s is not %s" % (x.name, x.dtype, tf.int32))
  ndims_static = x.get_shape().ndims
  # The rank check needs only the static rank, not the static value.  It
  # used to run only when both were known.  So a known non-scalar with an
  # unknown value was rejected only via validate_args, or silently accepted.
  if ndims_static is not None and ndims_static != 0:
    raise ValueError("%s.ndims=%d is not 0 (scalar)" % (x.name, ndims_static))
  x_value_static = tensor_util.constant_value(x)
  if ndims_static is not None and x_value_static is not None:
    if x_value_static < 0:
      raise ValueError("%s.value=%d cannot be negative" %
                       (x.name, x_value_static))
    return x
  if self.validate_args:
    x = control_flow_ops.with_dependencies(
        [tf.assert_rank(x, 0),
         tf.assert_non_negative(x)], x)
  return x
开发者ID:lewisKit,项目名称:probability,代码行数:19,代码来源:shape.py
示例19: test_model_inputs
def test_model_inputs(model_inputs):
    """Check the tensors returned by the student's ``model_inputs()``.

    ``model_inputs`` is called inside a fresh default graph and must return
    ``(input_data, targets, lr, keep_prob, target_sequence_length,
    max_target_sequence_length, source_sequence_length)``.  The test checks
    op types, the names of four of the tensors, and all ranks.

    Raises:
        AssertionError: on a wrong op type or tensor name.
        ValueError: from ``tf.assert_rank`` when a tensor's static rank is
            known and differs from the expected rank.
    """
    with tf.Graph().as_default():
        (input_data, targets, lr, keep_prob, target_sequence_length,
         max_target_sequence_length, source_sequence_length) = model_inputs()

        # Check type
        assert input_data.op.type == 'Placeholder',\
            'Input is not a Placeholder.'
        assert targets.op.type == 'Placeholder',\
            'Targets is not a Placeholder.'
        assert lr.op.type == 'Placeholder',\
            'Learning Rate is not a Placeholder.'
        assert keep_prob.op.type == 'Placeholder', \
            'Keep Probability is not a Placeholder.'
        assert target_sequence_length.op.type == 'Placeholder', \
            'Target Sequence Length is not a Placeholder.'
        # Expected to come from a reduction (op type 'Max', presumably
        # tf.reduce_max), not from a placeholder.  The message used to say
        # "is not a Placeholder".
        assert max_target_sequence_length.op.type == 'Max', \
            'Max Target Sequence Length is not a Max op.'
        assert source_sequence_length.op.type == 'Placeholder', \
            'Source Sequence Length is not a Placeholder.'

        # Check name
        assert input_data.name == 'input:0',\
            'Input has bad name. Found name {}'.format(input_data.name)
        assert target_sequence_length.name == 'target_sequence_length:0',\
            'Target Sequence Length has bad name. Found name {}'.format(target_sequence_length.name)
        assert source_sequence_length.name == 'source_sequence_length:0',\
            'Source Sequence Length has bad name. Found name {}'.format(source_sequence_length.name)
        assert keep_prob.name == 'keep_prob:0', \
            'Keep Probability has bad name. Found name {}'.format(keep_prob.name)

        # Check rank.  tf.assert_rank raises ValueError itself when the static
        # rank is known and wrong.  Otherwise, in graph mode, it returns an op,
        # which is always truthy.  Wrapping it in `assert` therefore checked
        # nothing extra, and under `python -O` it removed the call entirely.
        # NOTE: if a static rank is unknown, the returned op is never run
        # here, so that tensor's rank is effectively not checked.
        tf.assert_rank(input_data, 2, message='Input data has wrong rank')
        tf.assert_rank(targets, 2, message='Targets has wrong rank')
        tf.assert_rank(lr, 0, message='Learning Rate has wrong rank')
        tf.assert_rank(keep_prob, 0, message='Keep Probability has wrong rank')
        tf.assert_rank(target_sequence_length, 1, message='Target Sequence Length has wrong rank')
        tf.assert_rank(max_target_sequence_length, 0, message='Max Target Sequence Length has wrong rank')
        tf.assert_rank(source_sequence_length, 1, message='Source Sequence Length has wrong rank')

    _print_success_message()
开发者ID:3man1992,项目名称:Deep_Learning_RNN_Language_Translator,代码行数:39,代码来源:problem_unittests.py
示例20: op
def op(name,
       audio,
       sample_rate,
       labels=None,
       max_outputs=3,
       encoding=None,
       display_name=None,
       description=None,
       collections=None):
  """Create an audio summary op for use in a TensorFlow graph.

  Arguments:
    name: A unique name for the generated summary node.
    audio: A `Tensor` representing audio data with shape `[k, t, c]`,
      where `k` is the number of audio clips, `t` is the number of
      frames, and `c` is the number of channels. Elements should be
      floating-point values in `[-1.0, 1.0]`. Any of the dimensions may
      be statically unknown (i.e., `None`).
    sample_rate: An `int` or rank-0 `int32` `Tensor` that represents the
      sample rate, in Hz. Must be positive.
    labels: Optional `string` `Tensor`, a vector whose length is the
      first dimension of `audio`, where `labels[i]` contains arbitrary
      textual information about `audio[i]`. (For instance, this could be
      some text that a TTS system was supposed to produce.) Markdown is
      supported. Contents should be UTF-8.
    max_outputs: Optional `int` or rank-0 integer `Tensor`. At most this
      many audio clips will be emitted at each step. When more than
      `max_outputs` many clips are provided, the first `max_outputs`
      many clips will be used and the rest silently discarded.
    encoding: A constant `str` (not string tensor) indicating the
      desired encoding. You can choose any format you like, as long as
      it's "wav". Please see the "API compatibility note" below.
    display_name: Optional name for this summary in TensorBoard, as a
      constant `str`. Defaults to `name`.
    description: Optional long-form description for this summary, as a
      constant `str`. Markdown is supported. Defaults to empty.
    collections: Optional list of graph collections keys. The new
      summary op is added to these collections. Defaults to
      `[Graph Keys.SUMMARIES]`.

  Returns:
    A TensorFlow summary op.  Its tensor is a `[n, 2]` string matrix
    whose row `i` holds (encoded clip `i`, label `i`), with
    `n <= max_outputs`.

  Raises:
    ValueError: if `encoding` is anything other than `None` or `'wav'`.
      The check happens before any ops are added to the graph.

  API compatibility note: The default value of the `encoding`
  argument is _not_ guaranteed to remain unchanged across TensorBoard
  versions. In the future, we will by default encode as FLAC instead of
  as WAV. If the specific format is important to you, please provide a
  file format explicitly.
  """
  if display_name is None:
    display_name = name
  if encoding is None:
    encoding = 'wav'
  # Resolve the encoding up front so an unsupported value fails before the
  # graph is touched.
  if encoding == 'wav':
    encoding = metadata.Encoding.Value('WAV')
    # Per-clip encoder, applied to each element of `audio` by map_fn below.
    encoder = functools.partial(tf.contrib.ffmpeg.encode_audio,
                                samples_per_second=sample_rate,
                                file_format='wav')
  else:
    raise ValueError('Unknown encoding: %r' % encoding)
  with tf.name_scope(name), \
       tf.control_dependencies([tf.assert_rank(audio, 3)]):
    limited_audio = audio[:max_outputs]
    encoded_audio = tf.map_fn(encoder, limited_audio,
                              dtype=tf.string,
                              name='encode_each_audio')
    if labels is None:
      # One empty label per emitted clip, so the stack below lines up.
      limited_labels = tf.tile([''], tf.shape(limited_audio)[:1])
    else:
      # Assumes `labels` has one entry per clip; TODO confirm callers do.
      limited_labels = labels[:max_outputs]
    # Stack to [2, n] and transpose to [n, 2]: (encoded clip, label) rows.
    tensor = tf.transpose(tf.stack([encoded_audio, limited_labels]))
    summary_metadata = metadata.create_summary_metadata(
        display_name=display_name,
        description=description,
        encoding=encoding)
    return tf.summary.tensor_summary(name='audio_summary',
                                     tensor=tensor,
                                     collections=collections,
                                     summary_metadata=summary_metadata)
开发者ID:jlewi,项目名称:tensorboard,代码行数:82,代码来源:summary.py
注：本文中的 tensorflow.assert_rank 函数示例由纯净天空整理自 GitHub/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。
请发表评论