Environment: PyTorch 1.7.1, CUDA 11.0, RTX 2080 Ti.
TL;DR: transpose + 2D conv is faster than the 3D conv with a (1, 3, 3) kernel (in this environment and for the data shapes tested).
Code (modified from here):
import torch
import torch.nn as nn
import time
b = 4
c = 64
t = 4
h = 256
w = 256
raw_data = torch.randn(b, c, t, h, w).cuda()
def time2D():
    conv2d = nn.Conv2d(c, c, kernel_size=3, padding=1).cuda()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        # Fold the time dimension into the batch, run the 2D conv, then fold back.
        data = raw_data.transpose(1, 2).reshape(b * t, c, h, w).detach()
        out = conv2d(data)
        out = out.view(b, t, c, h, w).transpose(1, 2).contiguous()
        out.mean().backward()
    torch.cuda.synchronize()
    end = time.time()
    print(" --- %s --- " % (end - start))
def time3D():
    conv3d = nn.Conv3d(c, c, kernel_size=(1, 3, 3), padding=(0, 1, 1)).cuda()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(100):
        # Apply the 3D conv directly; the (1, 3, 3) kernel does no temporal mixing.
        out = conv3d(raw_data.detach())
        out.mean().backward()
    torch.cuda.synchronize()
    end = time.time()
    print(" --- %s --- " % (end - start))
print("Initializing cuda state")
time2D()
print("going to time2D")
time2D()
print("going to time3D")
time3D()
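
As a side note, the two ops should produce the same numbers when their weights match, because a (1, 3, 3) kernel mixes nothing across the temporal dimension. A minimal sketch to check that (reusing b, c, t, h, w from the script above; the tolerance is an arbitrary choice):

# Sanity check: Conv3d with a (1, 3, 3) kernel equals per-frame Conv2d
# when its weights are the Conv2d weights with a singleton time dim inserted.
conv2d = nn.Conv2d(c, c, kernel_size=3, padding=1).cuda()
conv3d = nn.Conv3d(c, c, kernel_size=(1, 3, 3), padding=(0, 1, 1)).cuda()
with torch.no_grad():
    conv3d.weight.copy_(conv2d.weight.unsqueeze(2))  # (c, c, 3, 3) -> (c, c, 1, 3, 3)
    conv3d.bias.copy_(conv2d.bias)

x = torch.randn(b, c, t, h, w).cuda()
out3d = conv3d(x)
out2d = conv2d(x.transpose(1, 2).reshape(b * t, c, h, w))
out2d = out2d.view(b, t, c, h, w).transpose(1, 2)
print(torch.allclose(out3d, out2d, atol=1e-5))  # expected: True, up to float error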
Timings, in seconds for 100 forward+backward iterations (shapes given as b*c*t*h*w):

For shape = 4*64*4*256*256:
2D: 1.8675172328948975
3D: 4.384545087814331

For shape = 8*512*16*64*64:
2D: 37.95961904525757
3D: 49.730860471725464

For shape = 4*128*128*16*16:
2D: 0.6455907821655273
3D: 1.8380646705627441
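
time.time() around a Python loop gets the job done here, but if anyone wants to re-run this more carefully, a sketch using torch.utils.benchmark (available since PyTorch 1.6; as far as I know it handles warmup and CUDA synchronization itself) over the same three shapes could look like this (bench is just a helper name I made up):

import torch
import torch.nn as nn
from torch.utils import benchmark

def bench(b, c, t, h, w, iters=100):
    x = torch.randn(b, c, t, h, w, device="cuda")
    conv2d = nn.Conv2d(c, c, kernel_size=3, padding=1).cuda()
    conv3d = nn.Conv3d(c, c, kernel_size=(1, 3, 3), padding=(0, 1, 1)).cuda()

    def run2d():
        # Fold time into the batch, convolve, fold back.
        y = conv2d(x.transpose(1, 2).reshape(b * t, c, h, w))
        y.view(b, t, c, h, w).transpose(1, 2).contiguous().mean().backward()

    def run3d():
        conv3d(x).mean().backward()

    for label, fn in [("2D", run2d), ("3D", run3d)]:
        timer = benchmark.Timer(stmt="fn()", globals={"fn": fn})
        print(label, timer.timeit(iters))

# Same shapes as the results above (b, c, t, h, w).
for shape in [(4, 64, 4, 256, 256), (8, 512, 16, 64, 64), (4, 128, 128, 16, 16)]:
    bench(*shape)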