Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Why doesn't my simple PyTorch network work on a GPU device?

I built a simple network from a tutorial and I got this error:

RuntimeError: Expected object of type torch.cuda.FloatTensor but found type torch.FloatTensor for argument #4 'mat1'

Any help? Thank you!

import torch
import torchvision

device = torch.device("cuda:0")
root = '.data/'

dataset = torchvision.datasets.MNIST(root, transform=torchvision.transforms.ToTensor(), download=True)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=4)


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.out = torch.nn.Linear(28*28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.out(x)
        return x

net = Net()
net.to(device)

for i, (inputs, labels) in enumerate(dataloader):
    inputs.to(device)
    out = net(inputs)

1 Answer

TL;DR
This is the fix

inputs = inputs.to(device)  
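A minimal sketch of the corrected loop follows. It makes two assumptions that are not in the original code: it uses random tensors shaped like MNIST batches (via TensorDataset) instead of downloading the dataset, and it falls back to the CPU when no GPU is available. It also adds a loss (hypothetical here, since the question stops at the forward pass) to show that labels need the same reassignment as inputs.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumption: fall back to the CPU so the sketch runs without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Assumption: stand-in for MNIST, 16 random 1x28x28 "images" with labels 0-9.
images = torch.rand(16, 1, 28, 28)
targets = torch.randint(0, 10, (16,))
dataloader = DataLoader(TensorDataset(images, targets), batch_size=4)


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.out = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten each image to a 784-vector
        return self.out(x)


net = Net()
net.to(device)  # Module.to() moves the parameters in place; no reassignment needed
loss_fn = torch.nn.CrossEntropyLoss()

for inputs, labels in dataloader:
    # Tensor.to() returns a tensor on `device`; keep the returned value.
    inputs = inputs.to(device)
    labels = labels.to(device)
    out = net(inputs)
    loss = loss_fn(out, labels)
```

Writing net = net.to(device) would also work, because Module.to() returns the module itself, so using the assignment form everywhere is a simple habit that avoids this bug.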

Why?!
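There is a subtle difference between torch.nn.Module.to() and torch.Tensor.to(): Module.to() modifies the module in place (it moves the module's parameters and buffers, then returns the module itself), while Tensor.to() does not modify the tensor it is called on. Therefore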

net.to(device)

changes net itself: its parameters and buffers now live on device. On the other hand,

inputs.to(device)

does not change inputs; it returns a copy of inputs that resides on device (or inputs itself, if it is already on device). To use that on-device copy, you need to assign the result to a variable, hence

inputs = inputs.to(device)
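The difference can be checked without a GPU by converting the dtype instead of the device, since both Module.to() and Tensor.to() also accept a dtype and follow the same in-place versus returned-copy behavior. This is an illustrative sketch; the names layer, t and converted are arbitrary.

```python
import torch

layer = torch.nn.Linear(2, 2)
t = torch.zeros(3)  # float32 by default

# Module.to() converts the module in place and returns the same object.
returned_module = layer.to(torch.float64)
print(returned_module is layer)  # True
print(layer.weight.dtype)        # torch.float64

# Tensor.to() leaves the original tensor untouched and returns a converted copy.
converted = t.to(torch.float64)
print(t.dtype)                   # torch.float32
print(converted.dtype)           # torch.float64

# When no conversion is needed, Tensor.to() returns the tensor itself (no copy).
print(t.to(torch.float32).data_ptr() == t.data_ptr())  # True
```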

...