The tensor `A_tensor` has shape [3,3,3], and I want to get values along the last axis by index.
How do I do that in TensorFlow?
For example, given this 2-D tensor:
import tensorflow as tf
A_tensor = tf.constant([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
How do I get the tensor [[1,2],[2,3],[3,4]], i.e. the values at indices 0 and 1 of the last axis?
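You can use `tf.gather()` on the last axis (`axis=-1`), passing a flat list of indices:
import tensorflow as tf
A_tensor = tf.constant([[1, 2, 3], [2, 3, 4], [3, 4, 5]])
# Select indices 0 and 1 along the last axis -> [[1, 2], [2, 3], [3, 4]]
result = tf.gather(A_tensor, [0, 1], axis=-1)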
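The output shape of `tf.gather` is the input shape with the gathered axis replaced by the shape of the indices, so the form of the index list matters. Here is a minimal sketch, assuming TensorFlow 2.x with eager execution (so `.numpy()` is available). `B_tensor` is a hypothetical [3,3,3] tensor built with `tf.range` purely to illustrate the 3-D case from the question.

```python
import tensorflow as tf

A_tensor = tf.constant([[1, 2, 3], [2, 3, 4], [3, 4, 5]])

# Flat index list: the last axis (size 3) is replaced by 2 -> shape [3, 2]
flat = tf.gather(A_tensor, [0, 1], axis=-1)

# Nested index list [[0, 1]] has shape [1, 2], which adds a dimension -> shape [3, 1, 2]
nested = tf.gather(A_tensor, [[0, 1]], axis=-1)

# Hypothetical [3, 3, 3] tensor (values 0..26) for the 3-D case -> shape [3, 3, 2]
B_tensor = tf.reshape(tf.range(27), [3, 3, 3])
B_last = tf.gather(B_tensor, [0, 1], axis=-1)

# For a contiguous range of indices, plain slicing on the last axis gives the same result
sliced = A_tensor[..., 0:2]

print(flat.numpy().tolist())    # [[1, 2], [2, 3], [3, 4]]
print(nested.shape.as_list())   # [3, 1, 2]
print(B_last.shape.as_list())   # [3, 3, 2]
```

Slicing is simpler when the indices are contiguous; `tf.gather` also handles arbitrary or reordered indices such as `[2, 0]`, which a single slice cannot express.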