本文整理汇总了 Python 中 tensorflow.python.ops.ragged.ragged_tensor.convert_to_tensor_or_ragged_tensor 函数的典型用法代码示例。如果您正苦于以下问题：Python convert_to_tensor_or_ragged_tensor 函数的具体用法？Python convert_to_tensor_or_ragged_tensor 怎么用？Python convert_to_tensor_or_ragged_tensor 使用的例子？那么恭喜您，这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了 convert_to_tensor_or_ragged_tensor 函数的 20 个代码示例，这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞，您的评价将有助于我们的系统推荐出更棒的 Python 代码示例。
示例1: testBinaryElementwiseOp
def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
    """Checks a binary elementwise `op` against its result on flat values.

    `x` and `y` are converted with `convert_to_tensor_or_ragged_tensor` and
    passed to `op`.  The result must have the same shape as `y`, and its
    flattened values must equal `op` applied to the inputs' `flat_values`
    (or to the inputs themselves when they are dense).

    Args:
      x: First operand (anything `convert_to_tensor_or_ragged_tensor` takes).
      y: Second operand (same).
      op: The binary op under test.
      **extra_args: Forwarded to `op`, except the optional `use_kwargs` entry,
        which is removed and names the operands (`'x'`, `'y'`) to pass by
        keyword.
    """
    keyword_operands = extra_args.pop('use_kwargs', ())
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
    y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y)

    y_by_keyword = 'y' in keyword_operands
    if y_by_keyword and 'x' in keyword_operands:
        result = op(x=x, y=y, **extra_args)
    elif y_by_keyword:
        result = op(x, y=y, **extra_args)
    else:
        result = op(x, y, **extra_args)

    def _flat(t):
        # Ragged tensors contribute their flat values; dense ones pass through.
        return t.flat_values if isinstance(t, ragged_tensor.RaggedTensor) else t

    # Reference result: the same op run directly on the dense values.
    expected_flat_values = array_ops.reshape(
        op(_flat(x), _flat(y), **extra_args), [-1])

    self.assertSameShape(y, result)
    self.assertAllEqual(expected_flat_values,
                        array_ops.reshape(_flat(result), [-1]))
开发者ID:aritratony,项目名称:tensorflow,代码行数:26,代码来源:ragged_dispatch_test.py
示例2: testConvertNumpyArrayError
def testConvertNumpyArrayError(self,
                               value,
                               message,
                               dtype=None,
                               preferred_dtype=None):
    """Checks that converting `value` raises a `ValueError` matching `message`.

    Args:
      value: The value to convert (presumably a numpy array, per the test
        name).
      message: Regular expression the `ValueError` message must match.
      dtype: Passed as the `dtype` argument of the conversion.
      preferred_dtype: Passed positionally after `dtype`.
    """
    # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex` that
    # was removed from `unittest` in Python 3.12.
    with self.assertRaisesRegex(ValueError, message):
        ragged_tensor.convert_to_tensor_or_ragged_tensor(value, dtype,
                                                         preferred_dtype)
开发者ID:aritratony,项目名称:tensorflow,代码行数:8,代码来源:convert_to_tensor_or_ragged_tensor_op_test.py
示例3: testConvertTensorError
def testConvertTensorError(self,
                           pylist,
                           message,
                           dtype=None,
                           preferred_dtype=None):
    """Checks that converting a dense tensor raises a matching `ValueError`.

    Args:
      pylist: Python value used to build the tensor via
        `constant_op.constant`.
      message: Regular expression the `ValueError` message must match.
      dtype: Passed as the `dtype` argument of the conversion.
      preferred_dtype: Passed positionally after `dtype`.
    """
    tensor = constant_op.constant(pylist)
    # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex` that
    # was removed from `unittest` in Python 3.12.
    with self.assertRaisesRegex(ValueError, message):
        ragged_tensor.convert_to_tensor_or_ragged_tensor(tensor, dtype,
                                                         preferred_dtype)
开发者ID:aritratony,项目名称:tensorflow,代码行数:9,代码来源:convert_to_tensor_or_ragged_tensor_op_test.py
示例4: testRaggedAddWithBroadcasting
def testRaggedAddWithBroadcasting(self, x, y, expected, doc):
    """Checks that `x + y` broadcasts to `expected`, including its ragged_rank.

    Both operands are converted to int32 (ragged) tensors first.  A missing
    `ragged_rank` attribute (dense value) counts as ragged_rank 0.  `doc` is
    not used by the body (presumably a descriptive label supplied by the
    test parameterization).
    """
    want_ragged_rank = getattr(expected, 'ragged_rank', 0)
    x, y = [
        ragged_tensor.convert_to_tensor_or_ragged_tensor(v, dtype=dtypes.int32)
        for v in (x, y)
    ]
    total = x + y
    self.assertEqual(want_ragged_rank, getattr(total, 'ragged_rank', 0))
    # Array-like expectations are compared as nested Python lists.
    expected = expected.tolist() if hasattr(expected, 'tolist') else expected
    self.assertRaggedEqual(total, expected)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:10,代码来源:ragged_tensor_shape_test.py
示例5: testConvertRaggedTensorError
def testConvertRaggedTensorError(self,
                                 pylist,
                                 message,
                                 dtype=None,
                                 preferred_dtype=None):
    """Checks that converting a ragged tensor raises a matching `ValueError`.

    Args:
      pylist: Nested Python list used to build the ragged tensor via
        `ragged_factory_ops.constant`.
      message: Regular expression the `ValueError` message must match.
      dtype: Passed as the `dtype` argument of the conversion.
      preferred_dtype: Passed positionally after `dtype`.
    """
    rt = ragged_factory_ops.constant(pylist)
    # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex` that
    # was removed from `unittest` in Python 3.12.
    with self.assertRaisesRegex(ValueError, message):
        ragged_tensor.convert_to_tensor_or_ragged_tensor(rt, dtype,
                                                         preferred_dtype)
开发者ID:aritratony,项目名称:tensorflow,代码行数:10,代码来源:convert_to_tensor_or_ragged_tensor_op_test.py
示例6: string_split_v2
def string_split_v2(input, sep=None, maxsplit=-1, name=None):  # pylint: disable=redefined-builtin
    """Split elements of `input` based on `sep` into a `RaggedTensor`.

    Each string element of `input` is split based on `sep`.  The result is a
    `RaggedTensor` of split tokens, except for a scalar `input`, which yields
    a 1-D `Tensor` of tokens.  A `SparseTensor` is never returned.

    Example:

    ```python
    >>> tf.strings.split('hello world')
    <Tensor ['hello', 'world']>
    >>> tf.strings.split(['hello world', 'a b c'])
    <tf.RaggedTensor [['hello', 'world'], ['a', 'b', 'c']]>
    ```

    If `sep` is given, consecutive delimiters are not grouped together and are
    deemed to delimit empty strings.  For example, `input` of `"1<>2<><>3"` and
    `sep` of `"<>"` returns `["1", "2", "", "3"]`.  If `sep` is None or an empty
    string, consecutive whitespace are regarded as a single separator, and the
    result will contain no empty strings at the start or end if the string has
    leading or trailing whitespace.

    Note that the above mentioned behavior matches python's str.split.

    Args:
      input: A string `Tensor` or `RaggedTensor` of rank `N`, the strings to
        split.  If `rank(input)` is not known statically, then it is assumed
        to be `1`.
      sep: `0-D` string `Tensor`, the delimiter string.
      maxsplit: An `int`.  If `maxsplit > 0`, limit of the split of the result.
      name: A name for the operation (optional).

    Raises:
      ValueError: If sep is not a string.  This function performs no such
        check itself; any such error would come from the underlying
        `string_ops.string_split_v2` op (not verified here).

    Returns:
      A `RaggedTensor` of rank `N+1`, the strings split according to the
      delimiter.  For a scalar (`N == 0`) `input`, a 1-D `Tensor` is
      returned instead.
    """
    with ops.name_scope(name, "StringSplit", [input]):
        input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            input, dtype=dtypes.string, name="input")
        if isinstance(input, ragged_tensor.RaggedTensor):
            # Split the innermost strings and keep the outer ragged structure.
            return input.with_flat_values(
                string_split_v2(input.flat_values, sep, maxsplit))

        rank = input.shape.ndims
        if rank == 0:
            # Promote the scalar to a length-1 vector, split it, then take the
            # single resulting row.
            return string_split_v2(array_ops.stack([input]), sep, maxsplit)[0]
        elif rank == 1 or rank is None:
            sparse_result = string_ops.string_split_v2(
                input, sep=sep, maxsplit=maxsplit)
            # Column 0 of the sparse indices gives the source row of each token.
            return ragged_tensor.RaggedTensor.from_value_rowids(
                values=sparse_result.values,
                value_rowids=sparse_result.indices[:, 0],
                nrows=sparse_result.dense_shape[0],
                validate=False)
        else:
            # Higher-rank dense input: make it ragged and recurse.
            return string_split_v2(
                ragged_tensor.RaggedTensor.from_tensor(input), sep, maxsplit)
开发者ID:aritratony,项目名称:tensorflow,代码行数:60,代码来源:ragged_string_ops.py
示例7: _replace_ragged_with_flat_values
def _replace_ragged_with_flat_values(value, nested_splits_lists):
    """Strip ragged structure from `value`, recording the removed splits.

    Walks `value` through lists, tuples and dicts.  Every ragged leaf is
    replaced by its `flat_values`, and its `nested_row_splits` is appended to
    `nested_splits_lists` in traversal order.  Any other leaf is returned
    unchanged.  Lists, tuples and dicts are rebuilt as plain `list`, `tuple`
    and `dict` objects.

    Args:
      value: The value whose `RaggedTensor`s should be replaced.
      nested_splits_lists: Output parameter (mutated in place) that receives
        the `nested_row_splits` of each replaced `RaggedTensor`.

    Returns:
      A copy of `value` with nested `RaggedTensor`s replaced by their
      `flat_values`.
    """
    if ragged_tensor.is_ragged(value):
        rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
        nested_splits_lists.append(rt.nested_row_splits)
        return rt.flat_values

    def _walk(item):
        return _replace_ragged_with_flat_values(item, nested_splits_lists)

    if isinstance(value, list):
        return list(map(_walk, value))
    if isinstance(value, tuple):
        return tuple(map(_walk, value))
    if isinstance(value, dict):
        return {key: _walk(item) for key, item in value.items()}
    return value
开发者ID:aritratony,项目名称:tensorflow,代码行数:34,代码来源:ragged_functional_ops.py
示例8: testListValuedElementwiseOp
def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
                                **extra_args):
    """Checks a list-valued elementwise `op` against its flat-values result.

    Each input is converted with `convert_to_tensor_or_ragged_tensor`.  The
    result must have the same shape as the first input, and its flattened
    values must equal `op` applied to the inputs' flat values.

    `extra_args` is forwarded to `op`, except the optional `use_kwargs` flag.
    That flag is removed, and when truthy the list is passed as `inputs=...`.
    """
    by_keyword = extra_args.pop('use_kwargs', False)
    inputs = list(
        map(ragged_tensor.convert_to_tensor_or_ragged_tensor, inputs))
    if by_keyword:
        result = op(inputs=inputs, **extra_args)
    else:
        result = op(inputs, **extra_args)

    def _flat(t):
        # Ragged tensors contribute their flat values; dense ones pass through.
        return t.flat_values if isinstance(t, ragged_tensor.RaggedTensor) else t

    # Reference result: the same op run directly on the dense values.
    expected_flat_values = array_ops.reshape(
        op([_flat(t) for t in inputs], **extra_args), [-1])

    self.assertSameShape(inputs[0], result)
    self.assertAllEqual(expected_flat_values,
                        array_ops.reshape(_flat(result), [-1]))
开发者ID:aritratony,项目名称:tensorflow,代码行数:28,代码来源:ragged_dispatch_test.py
示例9: normalize_tensors
def normalize_tensors(tensors):
    """Converts a nested structure of tensor-like objects to tensors.

    * `SparseTensor`-like inputs are converted to `SparseTensor`.
    * Ragged inputs (as detected by `ragged_tensor.is_ragged`) are converted
      with `convert_to_tensor_or_ragged_tensor`.
    * `TensorArray` inputs are passed through.
    * Everything else is converted to a dense `Tensor`.

    Args:
      tensors: A nested structure of tensor-like, list, `SparseTensor`,
        `SparseTensorValue`, ragged, or `TensorArray` objects.

    Returns:
      A nested structure with the same structure as `tensors`, whose leaves
      are `Tensor`, `SparseTensor`, `RaggedTensor`, or `TensorArray` objects.
    """
    flat_tensors = nest.flatten(tensors)
    prepared = []
    with ops.name_scope("normalize_tensors"):
        for i, t in enumerate(flat_tensors):
            # Sparse is checked first, then ragged, then TensorArray; anything
            # else falls through to a dense conversion.
            if sparse_tensor_lib.is_sparse(t):
                prepared.append(sparse_tensor_lib.SparseTensor.from_value(t))
            elif ragged_tensor.is_ragged(t):
                prepared.append(
                    ragged_tensor.convert_to_tensor_or_ragged_tensor(
                        t, name="component_%d" % i))
            elif isinstance(t, tensor_array_ops.TensorArray):
                prepared.append(t)
            else:
                prepared.append(
                    ops.convert_to_tensor(t, name="component_%d" % i))
    return nest.pack_sequence_as(tensors, prepared)
开发者ID:aritratony,项目名称:tensorflow,代码行数:29,代码来源:structure.py
示例10: broadcast_to
def broadcast_to(rt_input, shape, broadcast_inner_dimensions=True):
    """Broadcasts a potentially ragged tensor to a ragged shape.

    `rt_input` is tiled as needed to match `shape`.  The result is undefined
    when `rt_input` is not broadcast-compatible with `shape`.

    Args:
      rt_input: The potentially ragged tensor to broadcast.
      shape: The target `RaggedTensorDynamicShape`.
      broadcast_inner_dimensions: If false, inner dimensions are not tiled.

    Returns:
      A potentially ragged tensor holding values taken from `rt_input`, with
      a shape matching `shape`.

    Raises:
      TypeError: If `shape` is not a `RaggedTensorDynamicShape`.
    """
    if not isinstance(shape, RaggedTensorDynamicShape):
        raise TypeError('shape must be a RaggedTensorDynamicShape')
    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
    # A target with no partitioned dimensions is a uniform shape.
    to_uniform = shape.num_partitioned_dimensions == 0
    broadcaster = (_broadcast_to_uniform_shape if to_uniform
                   else _broadcast_to_ragged_shape)
    return broadcaster(rt_input, shape, broadcast_inner_dimensions)
开发者ID:aritratony,项目名称:tensorflow,代码行数:28,代码来源:ragged_tensor_shape.py
示例11: _unicode_decode
def _unicode_decode(input, input_encoding, errors, replacement_char,
                    replace_control_characters, with_offsets):
    """Decodes each string into a sequence of codepoints.

    Args:
      input: A string `Tensor` or `RaggedTensor` (or a value accepted by
        `convert_to_tensor_or_ragged_tensor`) with a statically known rank.
      input_encoding: Forwarded unchanged to the decode op.
      errors: Forwarded unchanged to the decode op.
      replacement_char: Forwarded unchanged to the decode op.
      replace_control_characters: Forwarded unchanged to the decode op.
      with_offsets: If true, use `unicode_decode_with_offsets` and also return
        the op's `char_to_byte_starts` output.

    Returns:
      `codepoints`, or `(codepoints, offsets)` when `with_offsets` is true.
      For scalar `input` these are the op's flat outputs.  Otherwise they are
      `RaggedTensor`s with one more dimension than `input`, split by the op's
      `row_splits`.

    Raises:
      ValueError: If the rank of `input` is not statically known.
    """
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input, name="input")
    input_ndims = input.shape.ndims
    if input_ndims is None:
        raise ValueError("Rank of `input` must be statically known.")

    if input_ndims > 1:
        # Convert to a ragged tensor with ragged_rank = input_ndims - 1, so
        # that its flat_values is a vector of strings.
        if not ragged_tensor.is_ragged(input):
            input = ragged_tensor.RaggedTensor.from_tensor(
                input, ragged_rank=input_ndims - 1)
        elif input.ragged_rank < input_ndims - 1:
            # `input` already has `input.ragged_rank` ragged dimensions, and
            # its flat_values has `input_ndims - input.ragged_rank` dimensions.
            # The flat_values must supply the remaining
            # `input_ndims - 1 - input.ragged_rank` ragged dimensions.
            input = input.with_flat_values(
                ragged_tensor.RaggedTensor.from_tensor(
                    input.flat_values,
                    ragged_rank=input_ndims - input.ragged_rank - 1))

    # Reshape the input to a flat vector, and apply the gen_string_ops op.
    if ragged_tensor.is_ragged(input):
        flat_input = array_ops.reshape(input.flat_values, [-1])
    else:
        flat_input = array_ops.reshape(input, [-1])

    if with_offsets:
        decode_op = gen_string_ops.unicode_decode_with_offsets
    else:
        decode_op = gen_string_ops.unicode_decode
    flat_result = decode_op(
        input=flat_input,
        input_encoding=input_encoding,
        errors=errors,
        replacement_char=replacement_char,
        replace_control_characters=replace_control_characters)

    if input_ndims == 0:
        codepoints = flat_result.char_values
        if with_offsets:
            offsets = flat_result.char_to_byte_starts
    else:
        codepoints = ragged_tensor.RaggedTensor.from_row_splits(
            flat_result.char_values, flat_result.row_splits, validate=False)
        if input_ndims > 1:
            # Restore the outer (ragged) structure of the input.
            codepoints = input.with_flat_values(codepoints)
        if with_offsets:
            offsets = ragged_tensor.RaggedTensor.from_row_splits(
                flat_result.char_to_byte_starts, flat_result.row_splits,
                validate=False)
            if input_ndims > 1:
                offsets = input.with_flat_values(offsets)

    if with_offsets:
        return codepoints, offsets
    return codepoints
开发者ID:aritratony,项目名称:tensorflow,代码行数:56,代码来源:ragged_string_ops.py
示例12: testConvertNumpyArray
def testConvertNumpyArray(self,
                          value,
                          dtype=None,
                          preferred_dtype=None,
                          expected_dtype=None):
    """Converts `value` and checks the resulting dtype and contents.

    `value` is presumably a numpy array (per the test name); it must expose a
    `.dtype`.  When `expected_dtype` is omitted, it defaults to `dtype` if
    that is given, else to `value.dtype`.
    """
    want_dtype = expected_dtype
    if want_dtype is None:
        want_dtype = dtype if dtype is not None else value.dtype
    tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        value, dtype, preferred_dtype)
    self.assertEqual(dtypes.as_dtype(want_dtype), tensor.dtype)
    self.assertAllEqual(value, tensor)
开发者ID:aritratony,项目名称:tensorflow,代码行数:11,代码来源:convert_to_tensor_or_ragged_tensor_op_test.py
示例13: from_tensor
def from_tensor(cls, rt_input):
    """Constructs a ragged shape for a potentially ragged tensor.

    Args:
      rt_input: A `Tensor`, a `RaggedTensor`, or a value convertible by
        `convert_to_tensor_or_ragged_tensor`.

    Returns:
      An instance of `cls`.  For dense input it is built from `[]` and
      `shape(rt_input)`.  For ragged input it is built from
      `(nrows,) + nested_row_lengths` and `shape(rt_input.flat_values)[1:]`.
    """
    with ops.name_scope(None, 'RaggedTensorDynamicShapeFromTensor', [rt_input]):
        rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt_input)
        if not ragged_tensor.is_ragged(rt_input):
            return cls([], array_ops.shape(rt_input))
        partitioned_dim_sizes = (
            (rt_input.nrows(),) + rt_input.nested_row_lengths())
        # Build via `cls`, as the dense branch does, so that a subclass gets
        # an instance of its own type from either branch.
        return cls(partitioned_dim_sizes,
                   array_ops.shape(rt_input.flat_values)[1:])
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:12,代码来源:ragged_tensor_shape.py
示例14: testConvertRaggedTensorValue
def testConvertRaggedTensorValue(self,
                                 value,
                                 dtype=None,
                                 preferred_dtype=None,
                                 expected_dtype=None):
    """Converts a ragged value and checks its ragged_rank, dtype and contents.

    When `expected_dtype` is omitted, it defaults to `dtype` if that is given,
    else to `value.dtype`.
    """
    want_dtype = expected_dtype
    if want_dtype is None:
        want_dtype = dtype if dtype is not None else value.dtype
    rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        value, dtype, preferred_dtype)
    self.assertEqual(value.ragged_rank, rt.ragged_rank)
    self.assertEqual(dtypes.as_dtype(want_dtype), rt.dtype)
    # Compare contents as nested Python lists.
    self.assertEqual(value.to_list(), self.eval_to_list(rt))
开发者ID:aritratony,项目名称:tensorflow,代码行数:12,代码来源:convert_to_tensor_or_ragged_tensor_op_test.py
示例15: testUnaryElementwiseOp
def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
    """Checks a unary elementwise `op` against its result on flat values.

    The result must have the same shape as the converted `x`, and its
    flattened values must equal `op` applied to `x`'s flat values.
    `extra_args` is forwarded to `op`.
    """
    x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x)
    result = op(x, **extra_args)

    def _flat(t):
        # Ragged tensors contribute their flat values; dense ones pass through.
        return t.flat_values if isinstance(t, ragged_tensor.RaggedTensor) else t

    # Reference result: the same op run directly on the dense values.
    expected_flat_values = array_ops.reshape(op(_flat(x), **extra_args), [-1])
    self.assertSameShape(x, result)
    self.assertAllEqual(expected_flat_values,
                        array_ops.reshape(_flat(result), [-1]))
开发者ID:aritratony,项目名称:tensorflow,代码行数:17,代码来源:ragged_dispatch_test.py
示例16: string_bytes_split
def string_bytes_split(input, name=None):  # pylint: disable=redefined-builtin
    """Split string elements of `input` into bytes.

    Examples:

    ```python
    >>> tf.strings.to_bytes('hello')
    ['h', 'e', 'l', 'l', 'o']
    >>> tf.strings.to_bytes(['hello', '123'])
    <RaggedTensor [['h', 'e', 'l', 'l', 'o'], ['1', '2', '3']]>
    ```

    Note that this op splits strings into bytes, not unicode characters.  To
    split strings into unicode characters, use `tf.strings.unicode_split`.

    See also: `tf.io.decode_raw`, `tf.strings.split`, `tf.strings.unicode_split`.

    Args:
      input: A string `Tensor` or `RaggedTensor`: the strings to split.  Must
        have a statically known rank (`N`).
      name: A name for the operation (optional).

    Returns:
      A `RaggedTensor` of rank `N+1`: the bytes that make up the source
      strings.  For a scalar (`N == 0`) `input`, a 1-D `Tensor` of bytes is
      returned instead.

    Raises:
      ValueError: If the rank of `input` is not statically known.
    """
    with ops.name_scope(name, "StringsByteSplit", [input]):
        input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input,
                                                                 name="input")
        if isinstance(input, ragged_tensor.RaggedTensor):
            # Split the innermost strings and keep the outer ragged structure.
            return input.with_flat_values(string_bytes_split(input.flat_values))

        rank = input.shape.ndims
        if rank is None:
            raise ValueError("input must have a statically-known rank.")

        if rank == 0:
            # Promote the scalar to a length-1 vector, split it, then take the
            # single resulting row.
            return string_bytes_split(array_ops.stack([input]))[0]
        elif rank == 1:
            # Empty delimiter, keeping empties: split each string into its
            # individual bytes (this function's documented contract).
            indices, values, shape = gen_string_ops.string_split(
                input, delimiter="", skip_empty=False)
            return ragged_tensor.RaggedTensor.from_value_rowids(
                values=values, value_rowids=indices[:, 0], nrows=shape[0],
                validate=False)
        else:
            # Higher-rank dense input: make it ragged and recurse.
            return string_bytes_split(
                ragged_tensor.RaggedTensor.from_tensor(input))
开发者ID:aritratony,项目名称:tensorflow,代码行数:45,代码来源:ragged_string_ops.py
示例17: tile
def tile(input, multiples, name=None):  # pylint: disable=redefined-builtin
    """Constructs a `RaggedTensor` by tiling a given `RaggedTensor`.

    The values of `input` are replicated `multiples[i]` times along the
    `i`th dimension (for each dimension `i`).  For every dimension `axis` in
    `input`, the length of each output element in that dimension is the
    length of corresponding input element multiplied by `multiples[axis]`.

    If `input` is not ragged after conversion, this delegates to the dense
    `array_ops.tile`.

    Args:
      input: A `RaggedTensor`, or a dense `Tensor` / value convertible by
        `convert_to_tensor_or_ragged_tensor`.
      multiples: A 1-D integer `Tensor`.  Length must be the same as the number
        of dimensions in `input`.
      name: A name for the operation (optional).

    Returns:
      A `RaggedTensor` with the same type, rank, and ragged_rank as `input`.
      For dense `input`, the dense result of `array_ops.tile`.

    #### Example:

    ```python
    >>> rt = tf.ragged.constant([[1, 2], [3]])
    >>> ragged.tile(rt, [3, 2])
    [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
    ```
    """
    with ops.name_scope(name, 'RaggedTile', [input, multiples]):
        input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            input, name='input')
        if not ragged_tensor.is_ragged(input):
            return array_ops.tile(input, multiples, name)
        # Match the integer dtype of `multiples` to the row_splits dtype.
        multiples = ragged_util.convert_to_int_tensor(
            multiples, name='multiples', dtype=input.row_splits.dtype)
        multiples.shape.assert_has_rank(1)

        # If the constant value of `multiples` is available, then we can use it
        # to skip tiling dimensions where `multiples=1`.
        const_multiples = tensor_util.constant_value(multiples)

        return ragged_tensor.RaggedTensor.from_nested_row_splits(
            _tile_ragged_values(input, multiples, const_multiples),
            _tile_ragged_splits(input, multiples, const_multiples),
            validate=False)
开发者ID:aritratony,项目名称:tensorflow,代码行数:41,代码来源:ragged_array_ops.py
示例18: gather_nd
def gather_nd(params, indices, batch_dims=0, name=None):
    """Gather slices from `params` using `n`-dimensional indices.

    This operation is similar to `gather`, but it uses the innermost dimension
    of `indices` to define a slice into `params`.  In particular, if:

    * `indices` has shape `[A1...AN, I]`
    * `params` has shape `[B1...BM]`

    Then:

    * `result` has shape `[A1...AN, B_{I+1}...BM]`.
    * `result[a1...aN] = params[indices[a1...aN, :]]`

    Args:
      params: A potentially ragged tensor with shape `[B1...BM]`.
      indices: A potentially ragged tensor with shape `[A1...AN, I]`.
      batch_dims: Must be zero.
      name: A name for the operation (optional).

    Returns:
      A potentially ragged tensor with shape `[A1...AN, B_{I+1}...BM]`.

    Raises:
      ValueError: If `batch_dims` is not the int 0; if the rank of `indices`
        or `indices.shape[-1]` is not statically known; if `indices` is a
        scalar; or if the innermost dimension of `indices` is ragged.

    #### Examples:

    ```python
    >>> params = tf.compat.v1.ragged.constant_value(
    ...     [ [ ['000', '001'], ['010'              ]          ],
    ...       [ ['100'       ], ['110', '111', '112'], ['120'] ],
    ...       [ [            ], ['210'              ]          ] ])

    >>> # Gather 2D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2], [0]])
    [ [ [            ], ['210'] ]
      [ ['000', '001'], ['010'] ] ]

    >>> # Gather 1D slices from a 3D tensor
    >>> ragged.gather_nd(params, [[2, 1], [0, 0]])
    [['210'], ['000', '001']]

    >>> # Gather scalars from a 3D tensor
    >>> ragged.gather_nd(params, [[0, 0, 1], [1, 1, 2]])
    ['001', '112']
    ```
    """
    if not isinstance(batch_dims, int) or batch_dims != 0:
        raise ValueError('batch_dims != 0 is not supported for ragged gather yet.')
    # Fast path: with no ragged argument, defer to the dense op.
    if not (ragged_tensor.is_ragged(params) or ragged_tensor.is_ragged(indices)):
        return array_ops.gather_nd(params, indices, name)

    with ops.name_scope(name, 'RaggedGatherNd', [params, indices]):

        params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            params, name='params')
        indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            indices, name='indices')
        # NOTE(review): presumably gives `params` and `indices` a common
        # row_splits dtype -- confirm against match_row_splits_dtypes.
        params, indices = ragged_tensor.match_row_splits_dtypes(params, indices)
        indices_shape = indices.shape
        indices_ndims = indices_shape.ndims
        if indices_ndims is None:
            # NOTE(review): this message appears to be missing the word "must".
            raise ValueError('indices.rank be statically known.')
        if indices_ndims == 0:
            raise ValueError('indices.rank must be at least 1.')
        # ragged_rank == ndims - 1 means the innermost dimension is ragged, so
        # it cannot act as a fixed-size index tuple.
        if (ragged_tensor.is_ragged(indices) and
                indices_ndims == indices.ragged_rank + 1):
            raise ValueError('The innermost dimension of indices may not be ragged')

        # `index_size` is the "n" in "gather_nd" -- i.e., the number of dimensions
        # that each index slices into.
        index_size = tensor_shape.dimension_value(indices_shape[-1])
        if index_size is None:
            raise ValueError('indices.shape[-1] must be statically known.')

        # If `indices` has more than 2 dimensions, then recurse.  If `indices` is
        # dense, then we convert it to ragged before recursing, and then convert
        # the result back to `dense` if appropriate.
        if indices_ndims > 2:
            indices_is_dense = not ragged_tensor.is_ragged(indices)
            if indices_is_dense:
                # `params` must be ragged here (at least one input is ragged),
                # so it has a row_splits dtype to match.
                indices = ragged_tensor.RaggedTensor.from_tensor(
                    indices, ragged_rank=indices_ndims - 2,
                    row_splits_dtype=params.row_splits.dtype)
            result = indices.with_flat_values(gather_nd(params, indices.flat_values))
            if (indices_is_dense and ragged_tensor.is_ragged(result) and
                    result.ragged_rank == indices_ndims - 2):
                result = ragged_tensor.RaggedTensor.to_tensor(result)
            return result

        # indices_ndims <= 2, and the innermost dimension of indices may not be
        # ragged, so `indices` must not be ragged.
        assert not ragged_tensor.is_ragged(indices)
        assert ragged_tensor.is_ragged(params)

        # Handle corner case: An empty index tuple selects the entire `params`
        # value.  So if `index_size` is zero, then tile `params`.
        if index_size == 0:
            params_ndims = params.ragged_rank + array_ops.rank(params.flat_values)
            # Add one leading dimension per outer dimension of `indices`.
            for dim in range(indices_ndims - 1):
                params = ragged_array_ops.expand_dims(params, axis=0)
            multiples = array_ops.concat([
                array_ops.shape(indices)[:-1],
                # ......... (the rest of this function is omitted in this listing) .........
开发者ID:aritratony,项目名称:tensorflow,代码行数:101,代码来源:ragged_gather_ops.py
示例19: testElementwiseOpBroadcast
def testElementwiseOpBroadcast(self, x, y, expected):
    """Checks that `x + y`, computed on int32 (ragged) tensors, equals `expected`."""
    x, y = [
        ragged_tensor.convert_to_tensor_or_ragged_tensor(v, dtype=dtypes.int32)
        for v in (x, y)
    ]
    self.assertRaggedEqual(x + y, expected)
开发者ID:aritratony,项目名称:tensorflow,代码行数:5,代码来源:ragged_dispatch_test.py
示例20: _ragged_segment_aggregate
def _ragged_segment_aggregate(unsorted_segment_op,
                              data,
                              segment_ids,
                              num_segments,
                              name=None):
    """Aggregates along segments of a RaggedTensor using `unsorted_segment_op`.

    Returns a RaggedTensor `output` with `num_segments` rows, where the row
    `output[i]` is formed by combining all rows of `data` whose corresponding
    `segment_id` is `i`.  The values in each row are combined using
    `unsorted_segment_op`.

    The length of the row `output[i]` will be the maximum of the lengths of
    all rows of `data` whose corresponding `segment_id` is `i`.  If no `data`
    rows correspond to a given segment ID, then the output row for that segment
    ID will be empty.

    If neither `data` nor `segment_ids` is ragged, the call is delegated
    directly to `unsorted_segment_op`.

    Args:
      unsorted_segment_op: The tensorflow `op` that should be used to combine
        values in each row.  Must have the same signature and basic behavior as
        `unsorted_segment_sum`, `unsorted_segment_max`, etc.
      data: A `RaggedTensor` containing the values to be combined.
      segment_ids: A `Tensor` or `RaggedTensor`.  Must have type `int64` or
        `int32`.  `segment_ids.shape` must be a prefix of `data.shape`.
        `segment_ids` is not required to be sorted.
      num_segments: An `int32` or `int64` scalar.
      name: A name prefix for the returned tensor (optional).

    Returns:
      A `RaggedTensor` containing the aggregated values.  The returned tensor
      has the same dtype as `data`, and its shape is
      `[num_segments] + data.shape[segment_ids.rank:]`.

    Raises:
      ValueError: If segment_ids.shape is not a prefix of data.shape.  This is
        raised eagerly only when `segment_ids` is ragged and `data` is not.
        Mismatched row_splits are instead checked by a runtime
        `check_ops.assert_equal`.
    """
    if not (ragged_tensor.is_ragged(data) or
            ragged_tensor.is_ragged(segment_ids)):
        return unsorted_segment_op(data, segment_ids, num_segments, name)

    with ops.name_scope(name, 'RaggedSegment',
                        [data, segment_ids, num_segments]) as name:
        data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
        segment_ids = ragged_tensor.convert_to_tensor_or_ragged_tensor(
            segment_ids, name='segment_ids')

        if ragged_tensor.is_ragged(segment_ids):
            if not ragged_tensor.is_ragged(data):
                raise ValueError('segment_ids.shape must be a prefix of data.shape, '
                                 'but segment_ids is ragged and data is not.')
            # Both are ragged: require identical outer partitioning, then
            # recurse one level down on the values.
            check_splits = check_ops.assert_equal(
                segment_ids.row_splits,
                data.row_splits,
                message='segment_ids.shape must be a prefix of data.shape')
            with ops.control_dependencies([check_splits]):
                return _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                                 segment_ids.values, num_segments, name)

        # From here on `segment_ids` is dense, so `data` must be ragged (the
        # all-dense case returned above).
        segment_ids = math_ops.cast(segment_ids, dtypes.int64)

        # Find the length of each row in data.  (dtype=int64, shape=[data_nrows])
        data_row_lengths = data.row_splits[1:] - data.row_splits[:-1]

        # Find the length that each output row will have.  The length of the row
        # corresponding to segment `id` is `max(data_row_lengths[i])` where
        # `segment_ids[i]=id`.  (dtype=int64, shape=[output_nrows])
        output_row_lengths = math_ops.maximum(
            math_ops.unsorted_segment_max(data_row_lengths, segment_ids,
                                          num_segments), 0)
        assert output_row_lengths.dtype == dtypes.int64

        # Build the splits tensor for the output RaggedTensor.
        output_splits = array_ops.concat([
            array_ops.zeros([1], dtypes.int64),
            math_ops.cumsum(output_row_lengths)
        ],
                                         axis=0)

        # For each row in `data`, find the start & limit position where that row's
        # values will be aggregated in output.values.
        data_row_to_out_row_start = array_ops.gather(output_splits, segment_ids)
        data_row_to_out_row_limit = data_row_to_out_row_start + data_row_lengths

        # For each value in `data.values`, find the position where it will be
        # aggregated in `output.values`.
        # Get the target output values index for each data values index.
        # NOTE(review): `range` is called on tensors and its result has
        # `.values`, so this is not Python's builtin `range`; presumably a
        # module-level ragged `range` shadows it -- confirm.
        data_val_to_out_val_index = range(data_row_to_out_row_start,
                                          data_row_to_out_row_limit).values

        # Recursively aggregate the values.
        output_values = _ragged_segment_aggregate(unsorted_segment_op, data.values,
                                                  data_val_to_out_val_index,
                                                  output_splits[-1])
        return ragged_tensor.RaggedTensor.from_row_splits(output_values,
                                                          output_splits)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:94,代码来源:ragged_math_ops.py
注：本文中的 tensorflow.python.ops.ragged.ragged_tensor.convert_to_tensor_or_ragged_tensor 函数示例由纯净天空整理自 Github/MSDocs 等源码及文档管理平台，相关代码片段筛选自各路编程大神贡献的开源项目，源码版权归原作者所有，传播和使用请参考对应项目的 License；未经允许，请勿转载。
请发表评论