TL;DR: you can't; all sequences must have the same size along a given axis.
Take this simplified example:
>>> import torch
>>> inputs_tokens = torch.tensor([[  1, 101,  18, 101,   9],
...                               [  1,   2, 101, 101, 101]])
>>> inputs_tokens.shape
torch.Size([2, 5])

>>> cls_tokens = inputs_tokens == 101
>>> cls_tokens
tensor([[False,  True, False,  True, False],
        [False, False,  True,  True,  True]])
Indexing inputs_tokens with the cls_tokens mask comes down to reducing inputs_tokens to the positions where cls_tokens is True. In the general case, where the number of True values differs per batch, keeping the shape is impossible.
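You can see this directly with the tensors above: boolean-mask indexing returns a flattened 1D tensor of the selected elements, precisely because a ragged second axis can't exist:

>>> inputs_tokens[cls_tokens]        # 2 + 3 = 5 selected elements, flattened
tensor([101, 101, 101, 101, 101])
>>> inputs_tokens[cls_tokens].shape
torch.Size([5])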
Following the above example, here is seq_A:
>>> seq_A = torch.rand(2, 5, 1)
>>> seq_A
tensor([[[0.4644],
         [0.7656],
         [0.3951],
         [0.6384],
         [0.1090]],

        [[0.6754],
         [0.0144],
         [0.7154],
         [0.5805],
         [0.5274]]])
According to your example, you would expect an output shape of (2, N, 1). What would N be? 3? What about the first batch, which only has 2 True values? The resulting tensor can't have different sizes (2 and 3 on axis=1). Hence: "all sequences on axis=1 must have the same size".
If, however, you expect each batch to contain the same number of 101 tokens, then you can get away with reshaping the indexed tensor:
>>> inputs_tokens = torch.tensor([[  1, 101, 101, 101,   9],
...                               [  1,   2, 101, 101, 101]])
>>> inputs_tokens.shape
torch.Size([2, 5])

>>> cls_tokens = inputs_tokens == 101
>>> N = cls_tokens[0].sum()
>>> N
tensor(3)
Remember, here I'm assuming that you have:
>>> assert all(cls_tokens.sum(axis=1) == N)
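If that assumption doesn't hold, the reshape below fails rather than silently padding or truncating. As a sanity check (re-creating the first, ragged mask from earlier, named old_cls_tokens here purely for illustration), the error would look something like:

>>> old_cls_tokens = torch.tensor([[False,  True, False,  True, False],
...                                [False, False,  True,  True,  True]])
>>> seq_A[old_cls_tokens].reshape(seq_A.size(0), N, -1)  # 5 elements can't fill (2, 3, -1)
RuntimeError: shape '[2, 3, -1]' is invalid for input of size 5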
Therefore the desired output (with shape (2, 3, 1)) is:
>>> seq_A[cls_tokens].reshape(seq_A.size(0), N, -1)
tensor([[[0.7656],
         [0.3951],
         [0.6384]],

        [[0.7154],
         [0.5805],
         [0.5274]]])
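As a side note, you can also let PyTorch infer N by passing -1 on axis=1, which avoids computing it explicitly (still assuming equal counts per batch):

>>> seq_A[cls_tokens].reshape(seq_A.size(0), -1, seq_A.size(-1))
tensor([[[0.7656],
         [0.3951],
         [0.6384]],

        [[0.7154],
         [0.5805],
         [0.5274]]])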
Edit: if you really want to do this anyway (with the original cls_tokens, where the number of True values differs per batch), you can use a list comprehension:
>>> [seq_A[i, cls_tokens[i]] for i in range(cls_tokens.size(0))]
[tensor([[0.7656],
         [0.6384]]),
 tensor([[0.7154],
         [0.5805],
         [0.5274]])]
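And if you then need a single rectangular tensor out of that ragged list, one option is to zero-pad the shorter sequences with torch.nn.utils.rnn.pad_sequence (a minimal sketch, with ragged being the list above and still using the original cls_tokens; whether padding is acceptable depends on your downstream use):

>>> from torch.nn.utils.rnn import pad_sequence
>>> ragged = [seq_A[i, cls_tokens[i]] for i in range(cls_tokens.size(0))]
>>> pad_sequence(ragged, batch_first=True)  # pads the first batch with a zero row
tensor([[[0.7656],
         [0.6384],
         [0.0000]],

        [[0.7154],
         [0.5805],
         [0.5274]]])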