本文整理汇总了Python中tensorflow.python.ops.ragged.ragged_tensor.is_ragged函数的典型用法代码示例。如果您正苦于以下问题:Python is_ragged函数的具体用法?Python is_ragged怎么用?Python is_ragged使用的例子?那么恭喜您, 这里精选的函数代码示例或许可以为您提供帮助。
在下文中一共展示了is_ragged函数的20个代码示例,这些例子默认根据受欢迎程度排序。您可以为喜欢或者感觉有用的代码点赞,您的评价将有助于我们的系统推荐出更棒的Python代码示例。
示例1: _broadcast_elementwise_args
def _broadcast_elementwise_args(elementwise_args):
  """Broadcasts the values of `elementwise_args` to have compatible shapes.

  Only a limited form of broadcasting is supported: every argument must be
  either ragged or a scalar (rank 0). When two or more ragged arguments are
  present, they must all have the same statically known rank, and their row
  splits are checked for equality with runtime assertions.

  Args:
    elementwise_args: A dictionary whose values are potentially ragged
      tensors.

  Returns:
    A tuple `(broadcast_args, broadcast_splits, checks)` where:
      * `broadcast_args` is a dictionary with the same keys as
        `elementwise_args` (currently the input dictionary itself).
      * `broadcast_splits` is the nested row splits shared by the ragged
        arguments, or `()` if no argument is ragged.
      * `checks` is a possibly empty collection of assertion operations that
        should be added as control dependencies.

  Raises:
    ValueError: If broadcasting fails.
  """
  values = list(elementwise_args.values())
  ragged_values = [t for t in values if ragged_tensor.is_ragged(t)]

  # Nothing ragged (this also covers an empty dictionary): nothing to do.
  if not ragged_values:
    return elementwise_args, (), ()

  # A single argument never needs broadcasting.
  if len(values) == 1:
    return elementwise_args, ragged_values[0].nested_row_splits, ()

  # Limited broadcasting: ragged tensors may only be mixed with scalars.
  # Full broadcasting support will be added later.
  if not all(ragged_tensor.is_ragged(t) or t.shape.ndims == 0
             for t in values):
    raise ValueError('Ragged elementwise ops do not support broadcasting yet')

  splits_lists = [t.nested_row_splits for t in ragged_values]
  if len(splits_lists) == 1:
    # Exactly one ragged tensor plus scalars: no splits to reconcile.
    return elementwise_args, splits_lists[0], ()

  ranks = [t.shape.ndims for t in values]
  if None in ranks:
    raise ValueError('Ragged elementwise ops require that rank (number '
                     'of dimensions) be statically known.')
  if len(set(ranks)) != 1:
    raise ValueError('Ragged elementwise ops do not support '
                     'broadcasting yet')
  checks = ragged_util.assert_splits_match(splits_lists)
  return elementwise_args, splits_lists[0], checks
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:58,代码来源:ragged_elementwise_ops.py
示例2: assertRaggedAlmostEqual
def assertRaggedAlmostEqual(self, a, b, places=7):
  """Asserts that two potentially ragged values are almost equal.

  Both values are converted to nested Python lists with `self._GetPyList` and
  compared to `places` decimal places. When neither argument is a plain
  Python list or tuple, their ragged ranks (0 for non-ragged values) must
  match as well.
  """
  self.assertNestedListAlmostEqual(
      self._GetPyList(a), self._GetPyList(b), places, context='value')

  if isinstance(a, (list, tuple)) or isinstance(b, (list, tuple)):
    return

  def _ragged_rank_of(value):
    return value.ragged_rank if ragged_tensor.is_ragged(value) else 0

  self.assertEqual(_ragged_rank_of(a), _ragged_rank_of(b))
开发者ID:aritratony,项目名称:tensorflow,代码行数:9,代码来源:ragged_test_util.py
示例3: _unicode_decode
def _unicode_decode(input, input_encoding, errors, replacement_char,
                    replace_control_characters, with_offsets):
  """Decodes each string into a sequence of codepoints.

  Args:
    input: A `Tensor` or `RaggedTensor` of strings whose rank is statically
      known.
    input_encoding, errors, replacement_char, replace_control_characters:
      Forwarded unchanged to the underlying `gen_string_ops` decode op.
    with_offsets: If true, also return the byte offset at which each decoded
      character starts (`char_to_byte_starts`).

  Returns:
    `codepoints`, or `(codepoints, offsets)` when `with_offsets` is true.
    For a scalar `input` these are the op's flat `char_values` /
    `char_to_byte_starts` outputs. Otherwise they are `RaggedTensor`s with
    one row per input string, nested like `input`.

  Raises:
    ValueError: If the rank of `input` is not statically known.
  """
  input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input, name="input")
  input_ndims = input.shape.ndims
  if input_ndims is None:
    raise ValueError("Rank of `input` must be statically known.")

  if input_ndims > 1:
    # Convert to a ragged tensor with ragged_rank = input_ndims - 1, so that
    # `flat_values` is a vector of individual strings.
    if not ragged_tensor.is_ragged(input):
      input = ragged_tensor.RaggedTensor.from_tensor(
          input, ragged_rank=input_ndims - 1)
    elif input.ragged_rank < input_ndims - 1:
      # `flat_values` has rank (input_ndims - input.ragged_rank). It needs
      # (input_ndims - 1) - input.ragged_rank extra ragged dimensions for the
      # total ragged rank to reach input_ndims - 1. (Using `+ 1` here asked
      # `from_tensor` for more ragged dimensions than the tensor has.)
      input = input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              input.flat_values,
              ragged_rank=input_ndims - input.ragged_rank - 1))

  # Reshape the input to a flat vector, and apply the gen_string_ops op.
  if ragged_tensor.is_ragged(input):
    flat_input = array_ops.reshape(input.flat_values, [-1])
  else:
    flat_input = array_ops.reshape(input, [-1])

  if with_offsets:
    decode_op = gen_string_ops.unicode_decode_with_offsets
  else:
    decode_op = gen_string_ops.unicode_decode
  flat_result = decode_op(
      input=flat_input,
      input_encoding=input_encoding,
      errors=errors,
      replacement_char=replacement_char,
      replace_control_characters=replace_control_characters)

  if input_ndims == 0:
    codepoints = flat_result.char_values
    if with_offsets:
      offsets = flat_result.char_to_byte_starts
  else:
    # One row per decoded string; restore the outer (ragged) structure of
    # `input` when it had more than one dimension.
    codepoints = ragged_tensor.RaggedTensor.from_row_splits(
        flat_result.char_values, flat_result.row_splits, validate=False)
    if input_ndims > 1:
      codepoints = input.with_flat_values(codepoints)
    if with_offsets:
      offsets = ragged_tensor.RaggedTensor.from_row_splits(
          flat_result.char_to_byte_starts, flat_result.row_splits,
          validate=False)
      if input_ndims > 1:
        offsets = input.with_flat_values(offsets)

  if with_offsets:
    return codepoints, offsets
  return codepoints
开发者ID:aritratony,项目名称:tensorflow,代码行数:56,代码来源:ragged_string_ops.py
示例4: assertRaggedEqual
def assertRaggedEqual(self, a, b):
  """Asserts that two potentially ragged tensors are equal.

  The values are compared as nested Python lists (via `self._GetPyList`).
  Unless either side is a plain Python list or tuple, their ragged ranks
  (0 for non-ragged values) must also be equal.
  """
  expected_list = self._GetPyList(a)
  actual_list = self._GetPyList(b)
  self.assertEqual(expected_list, actual_list)

  either_is_python_seq = (isinstance(a, (list, tuple)) or
                          isinstance(b, (list, tuple)))
  if either_is_python_seq:
    return

  ranks = [t.ragged_rank if ragged_tensor.is_ragged(t) else 0 for t in (a, b)]
  self.assertEqual(ranks[0], ranks[1])
开发者ID:aritratony,项目名称:tensorflow,代码行数:10,代码来源:ragged_test_util.py
示例5: handle
def handle(self, args, kwargs):
  """Dispatches a binary op to a ragged-aware implementation if applicable.

  The two operands are taken from `args` positionally, falling back to the
  keyword arguments named `self._x` / `self._y`. Any remaining positional and
  keyword arguments are forwarded to `self._original_op`.

  Args:
    args: Positional arguments of the original op call.
    kwargs: Keyword arguments of the original op call (not mutated).

  Returns:
    `self.NOT_SUPPORTED` if neither operand is ragged, or if a non-ragged
    operand cannot be converted to a tensor. Otherwise, the result of
    `with_flat_values` on the (broadcast) ragged operand, whose flat values
    are `self._original_op` applied to both operands' flat values.
  """
  # Extract the binary args.
  if len(args) > 1:
    x = args[0]
    y = args[1]
    args = args[2:]
  elif args:
    # Copy before popping so the caller's kwargs dict is left untouched.
    kwargs = kwargs.copy()
    x = args[0]
    y = kwargs.pop(self._y, None)
    args = args[1:]
  else:
    kwargs = kwargs.copy()
    x = kwargs.pop(self._x, None)
    y = kwargs.pop(self._y, None)
  # Bail if we don't have at least one ragged argument.
  x_is_ragged = ragged_tensor.is_ragged(x)
  y_is_ragged = ragged_tensor.is_ragged(y)
  if not (x_is_ragged or y_is_ragged):
    return self.NOT_SUPPORTED
  # Convert args to tensors. Bail if conversion fails.
  # At each conversion the other operand is known to be a (ragged) tensor,
  # so its dtype can be used as the preferred dtype.
  try:
    if not x_is_ragged:
      x = ops.convert_to_tensor(x, name=self._x, preferred_dtype=y.dtype)
    if not y_is_ragged:
      y = ops.convert_to_tensor(y, name=self._y, preferred_dtype=x.dtype)
  except (TypeError, ValueError):
    return self.NOT_SUPPORTED
  if x_is_ragged and y_is_ragged:
    x, y = ragged_tensor.match_row_splits_dtypes(x, y)
  # Broadcast when both operands are ragged, or when the ragged operand's
  # flat values have rank <= the dense operand's rank.
  # NOTE(review): if a static rank is unknown (`ndims is None`), the `<=`
  # comparison raises TypeError under Python 3 — confirm callers guarantee
  # known ranks.
  if ((x_is_ragged and y_is_ragged) or
      (x_is_ragged and x.flat_values.shape.ndims <= y.shape.ndims) or
      (y_is_ragged and y.flat_values.shape.ndims <= x.shape.ndims)):
    bcast_shape = ragged_tensor_shape.broadcast_dynamic_shape(
        ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(x),
        ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(y))
    x = ragged_tensor_shape.broadcast_to(
        x, bcast_shape, broadcast_inner_dimensions=False)
    y = ragged_tensor_shape.broadcast_to(
        y, bcast_shape, broadcast_inner_dimensions=False)
  # Apply the original op to the dense (flat) values of both operands.
  x_values = x.flat_values if ragged_tensor.is_ragged(x) else x
  y_values = y.flat_values if ragged_tensor.is_ragged(y) else y
  mapped_values = self._original_op(x_values, y_values, *args, **kwargs)
  # Re-attach the ragged structure of whichever operand is ragged.
  if ragged_tensor.is_ragged(x):
    return x.with_flat_values(mapped_values)
  else:
    return y.with_flat_values(mapped_values)
开发者ID:aritratony,项目名称:tensorflow,代码行数:52,代码来源:ragged_dispatch.py
示例6: rank
def rank(input, name=None):  # pylint: disable=redefined-builtin
  """Returns the rank of a RaggedTensor.

  Returns a 0-D `int32` `Tensor` representing the rank of `input`. For a
  ragged input this is `input.ragged_rank` plus the rank of
  `input.flat_values`. Any other input is passed to `array_ops.rank`.

  For example:

  ```python
  # shape of tensor 't' is [2, None, None]
  t = tf.ragged.constant([[[1], [2, 2]], [[3, 3, 3], [4, 4, 4, 4]]])
  tf.rank(t)  # 3
  ```

  Args:
    input: A `RaggedTensor` (non-ragged tensors are accepted too).
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.
  """
  with ops.name_scope(name, 'RaggedRank', [input]) as scope_name:
    if ragged_tensor.is_ragged(input):
      flat_rank = array_ops.rank(input.flat_values)
      return input.ragged_rank + flat_rank
    return array_ops.rank(input, scope_name)
开发者ID:aritratony,项目名称:tensorflow,代码行数:25,代码来源:ragged_array_ops.py
示例7: normalize_tensors
def normalize_tensors(tensors):
  """Converts a nested structure of tensor-like objects to tensors.

  Each leaf of `tensors` is converted as follows:
    * `SparseTensor`-like inputs are converted to `SparseTensor`.
    * Ragged inputs are converted with
      `ragged_tensor.convert_to_tensor_or_ragged_tensor`.
    * `TensorArray` inputs are passed through.
    * Everything else is converted to a dense `Tensor`.

  Leaves converted by the last two conversion routines are named
  "component_<i>", where <i> is the leaf's position in `nest.flatten` order.

  Args:
    tensors: A nested structure of tensor-like, list, `SparseTensor`,
      `SparseTensorValue`, ragged, or `TensorArray` objects.

  Returns:
    A nested structure, shaped like `tensors`, of tensor, `SparseTensor`,
    ragged, or `TensorArray` objects.
  """
  def _normalize_leaf(position, leaf):
    component_name = "component_%d" % position
    if sparse_tensor_lib.is_sparse(leaf):
      return sparse_tensor_lib.SparseTensor.from_value(leaf)
    if ragged_tensor.is_ragged(leaf):
      return ragged_tensor.convert_to_tensor_or_ragged_tensor(
          leaf, name=component_name)
    if isinstance(leaf, tensor_array_ops.TensorArray):
      return leaf
    return ops.convert_to_tensor(leaf, name=component_name)

  leaves = nest.flatten(tensors)
  with ops.name_scope("normalize_tensors"):
    normalized = [_normalize_leaf(i, leaf) for i, leaf in enumerate(leaves)]
  return nest.pack_sequence_as(tensors, normalized)
开发者ID:aritratony,项目名称:tensorflow,代码行数:29,代码来源:structure.py
示例8: _ragged_tile_axis
def _ragged_tile_axis(rt_input, axis, repeats, row_splits_dtype):
  """Tile a dimension of a RaggedTensor to match a ragged shape.

  Args:
    rt_input: A `Tensor` or `RaggedTensor`. A dense input is first converted
      to a `RaggedTensor` with ragged_rank=1.
    axis: The dimension to tile; must be > 0.
    repeats: Repeat counts for dimension `axis`. They become the outermost
      row lengths of the result when `axis == 1`. Presumably a 1-D integer
      tensor with one entry per row of `rt_input` — TODO confirm against
      callers.
    row_splits_dtype: dtype for the row splits when a dense input is
      converted.

  Returns:
    A `RaggedTensor` in which dimension `axis` has been tiled according to
    `repeats`.
  """
  # Note: `assert` is stripped under `python -O`; this is an internal sanity
  # check, not input validation.
  assert axis > 0  # Outermost dimension may not be ragged.
  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, ragged_rank=1, row_splits_dtype=row_splits_dtype)
  if axis > 1:
    # Recurse into `values`, which moves the target axis one level inward.
    return rt_input.with_values(
        _ragged_tile_axis(rt_input.values, axis - 1, repeats,
                          row_splits_dtype))
  else:
    src_row_splits = rt_input.nested_row_splits
    src_row_lengths = rt_input.nested_row_lengths()
    splits = src_row_splits[0]

    # The outermost row lengths of the result are `repeats` itself.
    dst_row_lengths = [repeats]
    for i in range(1, len(src_row_lengths)):
      # Repeat the inner row lengths over the ranges delimited by the
      # current `splits` (semantics of `ragged_util.repeat_ranges` are not
      # shown here).
      dst_row_lengths.append(
          ragged_util.repeat_ranges(src_row_lengths[i], splits, repeats))
      # Map `splits` down one nesting level for the next iteration.
      splits = array_ops.gather(src_row_splits[i], splits)
    dst_values = ragged_util.repeat_ranges(rt_input.flat_values, splits,
                                           repeats)
    return ragged_tensor.RaggedTensor.from_nested_row_lengths(
        dst_values, dst_row_lengths, validate=False)
开发者ID:aritratony,项目名称:tensorflow,代码行数:26,代码来源:ragged_tensor_shape.py
示例9: _eval_tensor
def _eval_tensor(self, tensor):
  """Evaluates `tensor`, recursing into ragged tensors.

  A ragged tensor is evaluated component-wise (first `values`, then
  `row_splits`) into a `RaggedTensorValue`. Anything else is delegated to
  `test_util.TensorFlowTestCase._eval_tensor`.
  """
  if not ragged_tensor.is_ragged(tensor):
    return test_util.TensorFlowTestCase._eval_tensor(self, tensor)
  evaluated_values = self._eval_tensor(tensor.values)
  evaluated_splits = self._eval_tensor(tensor.row_splits)
  return ragged_tensor_value.RaggedTensorValue(evaluated_values,
                                               evaluated_splits)
开发者ID:aritratony,项目名称:tensorflow,代码行数:7,代码来源:ragged_test_util.py
示例10: _replace_ragged_with_flat_values
def _replace_ragged_with_flat_values(value, nested_splits_lists):
  """Replace RaggedTensors with their flat_values, and record their splits.

  Returns a copy of `value` in which every nested ragged value is replaced by
  its `flat_values` tensor. Lists, tuples and dicts are searched
  recursively; they are rebuilt as plain `list`, `tuple` and `dict`
  respectively. Any other value is returned unchanged. For each replaced
  ragged value, its `nested_row_splits` is appended to `nested_splits_lists`,
  in traversal order.

  Args:
    value: The value that should be transformed by replacing `RaggedTensors`.
    nested_splits_lists: An output parameter used to record the
      `nested_row_splits` of every replaced `RaggedTensor`.

  Returns:
    A copy of `value` with nested `RaggedTensors` replaced by their
    `flat_values`.
  """
  if ragged_tensor.is_ragged(value):
    as_ragged = ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
    nested_splits_lists.append(as_ragged.nested_row_splits)
    return as_ragged.flat_values

  def _recurse(item):
    return _replace_ragged_with_flat_values(item, nested_splits_lists)

  if isinstance(value, list):
    return [_recurse(item) for item in value]
  if isinstance(value, tuple):
    return tuple(_recurse(item) for item in value)
  if isinstance(value, dict):
    return {key: _recurse(item) for key, item in value.items()}
  return value
开发者ID:aritratony,项目名称:tensorflow,代码行数:34,代码来源:ragged_functional_ops.py
示例11: ragged_op
def ragged_op(*args, **kwargs):
  """Ragged version of `op`.

  This function closes over `op` and `elementwise_arg_infos`, which are
  defined in an enclosing scope not visible here. It works in five steps:
    1. Gather the elementwise arguments (positional or keyword, single or
       list-valued).
    2. Convert them to tensors or ragged tensors and broadcast them.
    3. Replace each ragged argument with its `inner_values`.
    4. Call `op`.
    5. Re-wrap the result with the broadcast row splits.
  """
  args = list(args)
  # Collect all of the elementwise arguments, and put them in a single
  # dict whose values are the (potentially ragged) tensors that need to
  # be broadcast to a common shape. The keys of this dict are tuples
  # (argkey, index), where argkey is an int for positional args or a string
  # for keyword args; and index is None for non-list args and the index of the
  # tensor for list args.
  elementwise_args = {}
  for (name, position, is_list) in elementwise_arg_infos.values():
    if position < len(args):
      if is_list:
        # Copy the list so the in-place replacement further below does not
        # mutate the caller's list.
        args[position] = list(args[position])
        for (index, arg) in enumerate(args[position]):
          elementwise_args[position, index] = arg
      else:
        elementwise_args[position, None] = args[position]
    elif name in kwargs:
      if is_list:
        kwargs[name] = list(kwargs[name])
        for (i, arg) in enumerate(kwargs[name]):
          elementwise_args[name, i] = arg
      else:
        elementwise_args[name, None] = kwargs[name]
  with ops.name_scope(None, op.__name__, elementwise_args.values()):
    # Convert all inputs to tensors or ragged tensors.
    # (Reassigning existing keys while iterating is safe: the dict's size
    # does not change.)
    for ((key, index), tensor) in elementwise_args.items():
      # NOTE(review): `key` is an int for positional args and a str for
      # keyword args, so this lookup requires `elementwise_arg_infos` to be
      # indexable by both kinds of key — confirm where it is built.
      argname = elementwise_arg_infos[key].name
      converted = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
          tensor, name=argname)
      elementwise_args[key, index] = converted
    # Broadcast tensors to have compatible shapes.
    broadcast_args, result_splits, broadcast_check_ops = \
        _broadcast_elementwise_args(elementwise_args)
    # Replace tensor arguments with their dense values.
    # Only ragged arguments are written back; non-ragged arguments keep the
    # caller's original (unconverted) values in `args` / `kwargs`.
    for ((key, index), tensor) in broadcast_args.items():
      if ragged_tensor.is_ragged(tensor):
        if isinstance(key, int) and index is None:
          args[key] = tensor.inner_values
        elif isinstance(key, int) and index is not None:
          args[key][index] = tensor.inner_values
        elif isinstance(key, str) and index is None:
          kwargs[key] = tensor.inner_values
        else:
          assert isinstance(key, str) and index is not None
          kwargs[key][index] = tensor.inner_values
    # Call the elementwise op on the broadcasted dense values.
    with ops.control_dependencies(broadcast_check_ops):
      result_values = op(*args, **kwargs)
    # Restore any ragged dimensions that we stripped off, and return the
    # result.
    return ragged_factory_ops.from_nested_row_splits(result_values,
                                                     result_splits)
开发者ID:aeverall,项目名称:tensorflow,代码行数:60,代码来源:ragged_elementwise_ops.py
示例12: assertDatasetsEqual
def assertDatasetsEqual(self, dataset1, dataset2):
  """Checks that datasets are equal. Supports both graph and eager mode.

  The two datasets must have mutually compatible structures and yield the
  same number of elements. Each pair of corresponding components is compared
  with the check that fits its kind:
    * `assertSparseValuesEqual` for sparse values.
    * `assertRaggedEqual` for ragged values.
    * Exact `assertAllEqual` for string components.
    * `assertAllClose` for everything else.
  """
  self.assertTrue(dataset_ops.get_structure(dataset1).is_compatible_with(
      dataset_ops.get_structure(dataset2)))
  self.assertTrue(dataset_ops.get_structure(dataset2).is_compatible_with(
      dataset_ops.get_structure(dataset1)))

  flattened_types = nest.flatten(
      dataset_ops.get_legacy_output_types(dataset1))

  next1 = self.getNext(dataset1)
  next2 = self.getNext(dataset2)
  while True:
    try:
      op1 = self.evaluate(next1())
    except errors.OutOfRangeError:
      # dataset1 is exhausted, so dataset2 must be exhausted as well.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(next2())
      break
    op2 = self.evaluate(next2())

    op1 = nest.flatten(op1)
    op2 = nest.flatten(op2)
    # A test assertion (rather than a bare `assert`) still runs under
    # `python -O` and reports both lengths on failure.
    self.assertEqual(len(op1), len(op2))
    for i, (component1, component2) in enumerate(zip(op1, op2)):
      if sparse_tensor.is_sparse(component1):
        self.assertSparseValuesEqual(component1, component2)
      elif ragged_tensor.is_ragged(component1):
        self.assertRaggedEqual(component1, component2)
      elif flattened_types[i] == dtypes.string:
        self.assertAllEqual(component1, component2)
      else:
        self.assertAllClose(component1, component2)
开发者ID:aritratony,项目名称:tensorflow,代码行数:33,代码来源:test_base.py
示例13: reduce_mean
def reduce_mean(input_tensor, axis=None, keepdims=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceMean', [input_tensor, axis]):
    total = reduce_sum(input_tensor, axis, keepdims)

    # The element count is the sum of a ones tensor that has the same
    # (ragged) structure as `input_tensor`.
    if not ragged_tensor.is_ragged(input_tensor):
      ones = array_ops.ones_like(input_tensor)
    else:
      ones = ragged_tensor.RaggedTensor.from_nested_row_splits(
          array_ops.ones_like(input_tensor.flat_values),
          input_tensor.nested_row_splits)
    count = reduce_sum(ones, axis, keepdims)

    if not ragged_tensor.is_ragged(total):
      return total / count
    mean_values = total.flat_values / count.flat_values
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        mean_values, total.nested_row_splits)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:16,代码来源:ragged_math_ops.py
示例14: reduce_mean
def reduce_mean(rt_input, axis=None, name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  with ops.name_scope(name, 'RaggedReduceMean', [rt_input, axis]):
    total = reduce_sum(rt_input, axis)

    # The element count is the sum of a ones tensor that has the same
    # (ragged) structure as `rt_input`.
    if not ragged_tensor.is_ragged(rt_input):
      ones = array_ops.ones_like(rt_input)
    else:
      ones = ragged_factory_ops.from_nested_row_splits(
          array_ops.ones_like(rt_input.inner_values),
          rt_input.nested_row_splits)
    count = reduce_sum(ones, axis)

    if not ragged_tensor.is_ragged(total):
      return total / count
    mean_values = total.inner_values / count.inner_values
    return ragged_factory_ops.from_nested_row_splits(
        mean_values, total.nested_row_splits)
开发者ID:JonathanRaiman,项目名称:tensorflow,代码行数:16,代码来源:ragged_math_ops.py
示例15: testToBatchedTensorList
def testToBatchedTensorList(self, value_fn, element_0_fn):
  """Checks batching and unbatching through `Structure` tensor lists.

  Every non-variant batched tensor must have a batch dimension of 2. The 0th
  slice, rebuilt through the unbatched structure, must equal
  `element_0_fn()`.
  """
  batched_value = value_fn()
  s = structure.Structure.from_value(batched_value)
  batched_tensor_list = s._to_batched_tensor_list(batched_value)

  # The batch dimension is 2 for all of the test cases.
  # NOTE(mrry): `tf.shape()` does not currently work for the DT_VARIANT
  # tensors in which we store sparse tensors.
  shape_checkable = [
      t for t in batched_tensor_list if t.dtype != dtypes.variant
  ]
  for t in shape_checkable:
    self.assertEqual(2, self.evaluate(array_ops.shape(t)[0]))

  # Test that the 0th element from the unbatched tensor is equal to the
  # expected value.
  expected_element_0 = self.evaluate(element_0_fn())
  unbatched_structure = s._unbatch()
  first_slices = [t[0] for t in batched_tensor_list]
  actual_element_0 = unbatched_structure._from_tensor_list(first_slices)

  for want, got in zip(nest.flatten(expected_element_0),
                       nest.flatten(actual_element_0)):
    if sparse_tensor.is_sparse(want):
      self.assertSparseValuesEqual(want, got)
    elif ragged_tensor.is_ragged(want):
      self.assertRaggedEqual(want, got)
    else:
      self.assertAllEqual(want, got)
开发者ID:aritratony,项目名称:tensorflow,代码行数:27,代码来源:structure_test.py
示例16: to_sparse
def to_sparse(rt_input, name=None):
  """Converts a `RaggedTensor` into a sparse tensor.

  Example:

  ```python
  >>> rt = ragged.constant([[1, 2, 3], [4], [], [5, 6]])
  >>> ragged.to_sparse(rt).eval()
  SparseTensorValue(indices=[[0, 0], [0, 1], [0, 2], [1, 0], [3, 0], [3, 1]],
                    values=[1, 2, 3, 4, 5, 6],
                    dense_shape=[4, 3])
  ```

  Args:
    rt_input: The input `RaggedTensor`.
    name: A name prefix for the returned tensors (optional).

  Returns:
    A SparseTensor with the same values as `rt_input`.

  Raises:
    TypeError: If `rt_input` is not ragged.
  """
  if not ragged_tensor.is_ragged(rt_input):
    received_type = type(rt_input).__name__
    raise TypeError('Expected RaggedTensor, got %s' % received_type)

  with ops.name_scope(name, 'RaggedToSparse', [rt_input]):
    rt = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
        rt_input, name='rt_input')
    sparse_parts = gen_ragged_conversion_ops.ragged_tensor_to_sparse(
        rt.nested_row_splits, rt.inner_values, name=name)
    return sparse_tensor.SparseTensor(
        indices=sparse_parts.sparse_indices,
        values=sparse_parts.sparse_values,
        dense_shape=sparse_parts.sparse_dense_shape)
开发者ID:abhinav-upadhyay,项目名称:tensorflow,代码行数:29,代码来源:ragged_conversion_ops.py
示例17: eval_to_list
def eval_to_list(self, tensor):
  """Evaluates `tensor` and converts the result to nested Python lists.

  Ragged results are converted with `to_list()` and numpy arrays with
  `tolist()`. Any other evaluated value is returned as-is.
  """
  result = self.evaluate(tensor)
  if ragged_tensor.is_ragged(result):
    return result.to_list()
  if isinstance(result, np.ndarray):
    return result.tolist()
  return result
开发者ID:aritratony,项目名称:tensorflow,代码行数:8,代码来源:ragged_test_util.py
示例18: testFromTensorSlicesMixedRagged
def testFromTensorSlicesMixedRagged(self):
  """Slices a tuple mixing dense, sparse and ragged components."""
  components = (
      np.tile(np.array([[1], [2], [3]]), 20),
      np.tile(np.array([[12], [13], [14]]), 22),
      np.array([37.0, 38.0, 39.0]),
      sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0], [1, 0], [2, 0]]),
          values=np.array([0, 0, 0]),
          dense_shape=np.array([3, 1])),
      sparse_tensor.SparseTensorValue(
          indices=np.array([[0, 0], [1, 1], [2, 2]]),
          values=np.array([1, 2, 3]),
          dense_shape=np.array([3, 3])),
      ragged_factory_ops.constant_value([[[0]], [[1]], [[2]]]),
  )
  dataset = dataset_ops.Dataset.from_tensor_slices(components)
  get_next = self.getNext(dataset)

  def _expected_non_dense(row):
    """Expected sparse/sparse/ragged components of slice number `row`."""
    return (
        sparse_tensor.SparseTensorValue(
            indices=np.array([[0]]),
            values=np.array([0]),
            dense_shape=np.array([1])),
        sparse_tensor.SparseTensorValue(
            indices=np.array([[row]]),
            values=np.array([row + 1]),
            dense_shape=np.array([3])),
        ragged_factory_ops.constant_value([[row]]),
    )

  dense_rows = list(zip(*components[:3]))
  for row in range(3):
    results = self.evaluate(get_next())
    expected_row = dense_rows[row] + _expected_non_dense(row)
    for want, got in zip(expected_row, results):
      if sparse_tensor.is_sparse(want):
        self.assertSparseValuesEqual(want, got)
      elif ragged_tensor.is_ragged(want):
        self.assertRaggedEqual(want, got)
      else:
        self.assertAllEqual(want, got)

  with self.assertRaises(errors.OutOfRangeError):
    self.evaluate(get_next())
开发者ID:aritratony,项目名称:tensorflow,代码行数:58,代码来源:from_tensor_slices_test.py
示例19: _increase_ragged_rank_to
def _increase_ragged_rank_to(rt_input, ragged_rank):
  """Adds ragged dimensions to `rt_input` so it has the desired ragged rank.

  When `ragged_rank` is positive, the steps are:
    1. A dense input is converted with `ragged_conversion_ops.from_tensor`.
    2. If its ragged rank is still below `ragged_rank`, its `values` are
       recursively given one fewer ragged dimension.
  Otherwise `rt_input` is returned unchanged.
  """
  if ragged_rank <= 0:
    return rt_input

  if ragged_tensor.is_ragged(rt_input):
    result = rt_input
  else:
    result = ragged_conversion_ops.from_tensor(rt_input)

  if result.ragged_rank >= ragged_rank:
    return result
  deeper_values = _increase_ragged_rank_to(result.values, ragged_rank - 1)
  return result.with_values(deeper_values)
开发者ID:Wajih-O,项目名称:tensorflow,代码行数:9,代码来源:ragged_concat_ops.py
示例20: _increase_ragged_rank_to
def _increase_ragged_rank_to(rt_input, ragged_rank, row_splits_dtype):
  """Adds ragged dimensions to `rt_input` so it has the desired ragged rank.

  When `ragged_rank` is positive, the steps are:
    1. A dense input is converted with `RaggedTensor.from_tensor` using
       `row_splits_dtype`.
    2. If its ragged rank is still below `ragged_rank`, its `values` are
       recursively given one fewer ragged dimension.
  Otherwise `rt_input` is returned unchanged.
  """
  if ragged_rank <= 0:
    return rt_input

  if ragged_tensor.is_ragged(rt_input):
    result = rt_input
  else:
    result = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, row_splits_dtype=row_splits_dtype)

  if result.ragged_rank >= ragged_rank:
    return result
  deeper_values = _increase_ragged_rank_to(result.values, ragged_rank - 1,
                                           row_splits_dtype)
  return result.with_values(deeper_values)
开发者ID:aritratony,项目名称:tensorflow,代码行数:11,代码来源:ragged_concat_ops.py
注:本文中的tensorflow.python.ops.ragged.ragged_tensor.is_ragged函数示例由纯净天空整理自Github/MSDocs等源码及文档管理平台,相关代码片段筛选自各路编程大神贡献的开源项目,源码版权归原作者所有,传播和使用请参考对应项目的License;未经允许,请勿转载。
请发表评论