Given an N-dimensional tensor, I would like to create another tensor of the same shape containing only 0s and 1s, where a 1 marks the position of the maximum element of the original tensor along some dimension.
One constraint: if the maximum value appears more than once along that axis, only the first occurrence should be marked.
For simplicity, I will use a 2-D example.
>>> import tensorflow as tf
>>> x = tf.constant([[7, 2, 3],
...                  [5, 0, 1],
...                  [3, 8, 2]], dtype=tf.float32)
>>> tf.reduce_max(x, axis=-1)
tf.Tensor([7. 5. 8.], shape=(3,), dtype=float32)
What I want is:
tf.Tensor(
[[1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]], shape=(3, 3), dtype=float32)
What I've tried (and realized was wrong):
# Works fine when there are no duplicates:
>>> tf.cast(tf.equal(x, tf.reduce_max(x, axis=-1, keepdims=True)), dtype=tf.float32)
tf.Tensor(
[[1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]], shape=(3, 3), dtype=float32)
# Fails when the maximum value appears more than once along the axis:
>>> y = tf.zeros([3, 3])
>>> tf.cast(tf.equal(y, tf.reduce_max(y, axis=-1, keepdims=True)), dtype=tf.float32)
tf.Tensor(
[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]], shape=(3, 3), dtype=float32)
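A minimal sketch of one way to repair the equality-mask attempt above, assuming TensorFlow 2.x in eager mode: keep the all-maxima mask, take a running count with tf.cumsum along the same axis, and keep only the position where that count first reaches 1. The helper name first_max_mask and its axis parameter are hypothetical, not part of the question or of TensorFlow. Unlike tf.one_hot(tf.argmax(...)), whose tie-breaking the tf.math.argmax documentation says is not guaranteed, this picks the first maximum explicitly and needs no sort.

```python
import tensorflow as tf


def first_max_mask(x, axis=-1):
    """Hypothetical helper: 0/1 float32 mask of the first maximum of x along axis."""
    # 1 wherever an element equals the maximum of its slice (all ties included).
    is_max = tf.cast(tf.equal(x, tf.reduce_max(x, axis=axis, keepdims=True)), tf.int32)
    # Running count of maxima seen so far along the axis.
    running = tf.cumsum(is_max, axis=axis)
    # Keep a 1 only where the element is a maximum AND it is the first one seen.
    first = is_max * tf.cast(tf.equal(running, 1), tf.int32)
    return tf.cast(first, tf.float32)


x = tf.constant([[7, 2, 3],
                 [5, 0, 1],
                 [3, 8, 2]], dtype=tf.float32)
print(first_max_mask(x).numpy().tolist())
# [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
print(first_max_mask(tf.zeros([3, 3])).numpy().tolist())
# [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
```

Because the running count only increases, positions after the first maximum either are not maxima (is_max is 0) or have a count above 1, so exactly one 1 survives per slice.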
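Edit: Solved. Ranking the elements of each row with a double argsort and marking rank 0 selects the first maximum (stable=True keeps tied values in their original order, so the earliest maximum gets rank 0):
>>> tf.cast(tf.equal(tf.argsort(tf.argsort(x, axis=1, direction='DESCENDING', stable=True), axis=1), 0), tf.float32)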
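The same double-argsort idea can be wrapped so it works along any axis of an N-dimensional tensor. This is a sketch assuming TensorFlow 2.x; the function name first_max_mask_argsort is hypothetical. The first argsort returns the indices that would sort each slice in descending order; argsorting that permutation again inverts it, giving each element's rank, and rank 0 is the first maximum. Each slice is sorted, so the cost is O(n log n) per slice rather than linear.

```python
import tensorflow as tf


def first_max_mask_argsort(x, axis=-1):
    """Hypothetical helper: 0/1 float32 mask of the first maximum of x along axis."""
    # Indices that sort each slice in descending order. stable=True keeps tied
    # values in their original order, so the earliest maximum comes first.
    order = tf.argsort(x, axis=axis, direction='DESCENDING', stable=True)
    # Argsorting a permutation inverts it: ranks[i] is the sorted position of
    # element i, so rank 0 marks the first maximum.
    ranks = tf.argsort(order, axis=axis, stable=True)
    return tf.cast(tf.equal(ranks, 0), tf.float32)


# A 3-D example: mask the first maximum along the last axis.
t = tf.constant([[[1., 4., 4.], [2., 2., 0.]],
                 [[0., 0., 0.], [3., 1., 3.]]])
print(first_max_mask_argsort(t).numpy().tolist())
# [[[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]], [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]]
```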
question from:
https://stackoverflow.com/questions/66048913/tensorflow-binary-mask-of-max-values-along-tensor-axis