原因
Torch有两个版本,一个就叫Torch一个专门给Python用的Pytorch,它们训练完之后保存下来的模型是不一样的.
说到这问题就很清楚了.OpenCV的ReadNetFromTorch 支持的是前者...
解决方法
那么有没有解决办法呢,答案是有的.
PyTorch支持把模型保存为ONNX格式.而这个格式在opencv是支持的.
操作如下:
import torch
import torch.onnx
from torch.autograd import Variable
# ~~~~~~~~~~~~~~~~初始化与训练模型过程~~~~~~~~~~~~~
# 这是普通的pytorch模型保存方式:
torch.save(net.state_dict(), "torch.pt")
# 这是保存为ONNX的方法:
# 由于PyTorch的模型,是动态调整大小的,这里需要初始化一个指定格式的数据,用来调整模型大小
# 就是和你训练模型的时候用的数据一样的格式就行
dummy_input = Variable(torch.randn(1, 1, 28, 28)).to(device)
# 保存模型
torch.onnx.export(net, dummy_input, "torch.onnx")
注意,这里还有个坑!
虽然模型保存成了ONNX格式 ,但是OpenCV的ReadTensorFromONNX 并不能加载! 需要用ReadNet 方法加载!
|
请发表评论