If I understand correctly, what you're trying to do is stack the output mini-batches together into a single batch. My bet is that your last batch is only partially filled (16 elements instead of 32), so it does not have the same shape as the others.
Instead of using torch.stack (which creates a new axis and requires every tensor to have exactly the same shape), I would simply concatenate with torch.cat along the batch axis (axis 0). Assuming matrices is a list of torch.Tensor objects:
torch.cat(matrices).cpu().detach().numpy()
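This works because torch.cat concatenates along dim=0 by default, and it only requires the tensors to match in the remaining dimensions, so a shorter last batch is fine.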
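A minimal sketch of the difference, assuming the outputs are 2-D tensors of shape (batch, features) and that the last batch has 16 rows; the shapes and random data are made up for illustration only:

```python
import torch

# Hypothetical outputs: three full mini-batches of 32 rows plus a partial
# last batch of 16 rows, each row a 10-dimensional vector (illustrative shapes).
matrices = [torch.randn(32, 10) for _ in range(3)] + [torch.randn(16, 10)]

# torch.stack requires all tensors to have the same shape,
# so the partial last batch makes it raise a RuntimeError.
try:
    torch.stack(matrices)
    stack_failed = False
except RuntimeError:
    stack_failed = True

# torch.cat only requires the non-concatenated dimensions to match;
# it joins along dim=0 (the batch axis) by default.
combined = torch.cat(matrices).cpu().detach().numpy()
print(combined.shape)  # (112, 10): 3 * 32 + 16 rows
```

The result is a single NumPy array with one row per sample, regardless of how the last batch was filled.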