Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

pytorch - why torch.Tensor subtract works well when tensor size is different?

This example will make it easier to understand. The following fails:

import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6]])   # shape : (2, 3)
B = torch.tensor([[1, 2], [3, 4], [5, 6]]) # shape : (3, 2)
print((A - B).shape)

# RuntimeError: The size of tensor A (3) must match the size of tensor B (2) at non-singleton dimension 1
# ==================================================================
A = torch.tensor([[1, 2], [3, 4], [5, 6]]) # shape : (3, 2)
B = torch.tensor([[1, 2], [3, 4]])         # shape : (2, 2)
print((A - B).shape)

# RuntimeError: The size of tensor A (3) must match the size of tensor B (2) at non-singleton dimension 0

But the following works well:

a = torch.ones(8).unsqueeze(0).unsqueeze(-1).expand(4, 8, 7)  # shape : ( 4, 8, 7 )
a_temp = a.unsqueeze(2)                            # shape : ( 4, 8, 1, 7 )
b_temp = torch.transpose(a_temp, 1, 2)             # shape : ( 4, 1, 8, 7 )
print(a_temp-b_temp)                               # shape : ( 4, 8, 8, 7 )

Why does the latter work, but not the former?
How/why has the result shape been expanded?

question from: https://stackoverflow.com/questions/66059474/why-torch-tensor-subtract-works-well-when-tensor-size-is-different


1 Answer


This is well explained by the broadcasting semantics. The important part is:

Two tensors are “broadcastable” if the following rules hold:

  • Each tensor has at least one dimension.
  • When iterating over the dimension sizes, starting at the trailing dimension, the dimension sizes must either be equal, one of them is 1, or one of them does not exist.

In your case, (3,2) and (2,3) cannot be broadcast to a common shape (3 != 2 and neither are equal to 1), but (4,8,1,7), (4,1,8,7) and (4,8,8,7) are broadcast compatible.
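The rule above can be sketched in plain Python (a minimal illustration of the trailing-dimension check, not PyTorch's actual implementation; the function name `broadcast_shape` is my own):

```python
def broadcast_shape(shape_a, shape_b):
    """Compute the broadcast result shape, or raise if incompatible.

    Mirrors the rule quoted above: align the shapes at the trailing
    dimension; each pair of sizes must be equal or one of them must
    be 1 (a missing dimension is treated as 1).
    """
    result = []
    # Walk the dimensions from the trailing one backwards.
    for i in range(1, max(len(shape_a), len(shape_b)) + 1):
        da = shape_a[-i] if i <= len(shape_a) else 1
        db = shape_b[-i] if i <= len(shape_b) else 1
        if da == db or da == 1 or db == 1:
            result.append(max(da, db))
        else:
            raise ValueError(f"sizes {da} and {db} do not match")
    return tuple(reversed(result))

print(broadcast_shape((4, 8, 1, 7), (4, 1, 8, 7)))  # (4, 8, 8, 7)
# broadcast_shape((2, 3), (3, 2)) raises ValueError: the trailing
# sizes 3 and 2 are unequal and neither is 1 -- exactly your first error.
```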

This is basically what the error states: all dimensions must either be equal ("match") or be singletons (i.e. equal to 1).

When the shapes are broadcast, each tensor is first expanded to the common shape (here [4, 8, 8, 7]), and the subtraction is then performed as usual. Expansion duplicates your data along the singleton dimensions (as a view, without actually copying memory) to reach the required shape.
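PyTorch follows NumPy's broadcasting semantics, so the implicit expansion can be made explicit with `numpy.broadcast_to` (using NumPy here as a stand-in for torch; `torch.Tensor.expand` plays the same role):

```python
import numpy as np

a_temp = np.ones((4, 8, 1, 7))
b_temp = np.ones((4, 1, 8, 7))

# What broadcasting does behind the scenes: expand each operand's
# singleton dimension to the common shape (4, 8, 8, 7) as a view.
a_big = np.broadcast_to(a_temp, (4, 8, 8, 7))  # repeated along dim 1
b_big = np.broadcast_to(b_temp, (4, 8, 8, 7))  # repeated along dim 2

# The implicit and explicit versions agree.
assert (a_temp - b_temp).shape == (4, 8, 8, 7)
assert np.array_equal(a_temp - b_temp, a_big - b_big)
```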

