# You can combine `tf.where` and `tf.gather_nd` to keep only the values above
# a threshold in each element of a `tf.data.Dataset`:
import tensorflow as tf
import numpy as np
def map_func(a, threshold=0.5):
    """Return the elements of ``a`` that are strictly greater than ``threshold``.

    ``tf.where(mask)`` returns the coordinates of every ``True`` entry of the
    boolean mask, and ``tf.gather_nd`` reads the values at those coordinates.
    The result therefore has a length that varies from element to element.

    Args:
        a: Tensor to filter (in this example, one row of ``inputs``).
        threshold: Values must be strictly greater than this to be kept.
            Defaults to 0.5, the cut-off used in the original example.

    Returns:
        A 1-D tensor of the kept values, ordered as ``tf.where`` returns
        their coordinates (row-major).
    """
    return tf.gather_nd(a, tf.where(a > threshold))
# Ten rows of five uniform random values in [0, 1). No seed is set, so the
# numbers differ from run to run (and from the sample output shown here).
inputs = np.random.rand(10, 5)
# A bare ``np.round(inputs, 3)`` only displays in an interactive session;
# print its repr so a script shows the same ``array([...])`` view.
print(repr(np.round(inputs, 3)))
# Example output (random values; yours will differ):
# array([[0.952, 0.329, 0.786, 0.714, 0.819],
#        [0.048, 0.98 , 0.363, 0.03 , 0.078],
#        [0.779, 0.833, 0.368, 0.216, 0.669],
#        [0.807, 0.332, 0.217, 0.594, 0.254],
#        [0.787, 0.453, 0.943, 0.915, 0.76 ],
#        [0.047, 0.014, 0.555, 0.57 , 0.422],
#        [0.195, 0.167, 0.077, 0.562, 0.586],
#        [0.693, 0.434, 0.055, 0.213, 0.021],
#        [0.459, 0.34 , 0.785, 0.938, 0.979],
#        [0.08 , 0.667, 0.781, 0.092, 0.644]])
# Each dataset element is one row of ``inputs`` (shape (5,)). map_func turns
# it into a variable-length tensor holding only the values above 0.5.
ds = tf.data.Dataset.from_tensor_slices(inputs)
ds = ds.map(map_func)
# as_numpy_iterator() yields NumPy arrays directly, so no per-element
# .numpy() call is needed before rounding.
for row in ds.as_numpy_iterator():
    print(np.round(row, 3))
# Example output for the rows above (one line per row, values > 0.5 only):
# [0.952 0.786 0.714 0.819]
# [0.98]
# [0.779 0.833 0.669]
# [0.807 0.594]
# [0.787 0.943 0.915 0.76 ]
# [0.555 0.57 ]
# [0.562 0.586]
# [0.693]
# [0.785 0.938 0.979]
# [0.667 0.781 0.644]