• 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

在PyTorch中保存训练模型的最佳方法?

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

我正在寻找在PyTorch中保存训练模型的不同方法。到目前为止,发现了两种选择。

  1. torch.save()用于保存模型,torch.load()用于加载模型。

  2. model.state_dict()用于保存训练模型,model.load_state_dict()用于加载保存的模型。

我已经遇到过这个discussion,其中推荐方法2优于方法1。

我的问题是,为什么第二种方法更受欢迎?是否因为torch.nn模块具有这两个功能所以鼓励使用它们?

最佳解决思路

我在github repo上找到了this page,我只是在这里原样粘贴复制内容过来:


保存模型的推荐方法

序列化和恢复模型有两种主要方法。

第一个(推荐)保存并仅加载模型参数:

torch.save(the_model.state_dict(), PATH)

然后:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

第二个保存并加载整个模型:

torch.save(the_model, PATH)

然后:

the_model = torch.load(PATH)

但是在这种情况下,序列化数据绑定到特定的类和使用的确切目录结构,因此当在其他项目中使用时,或者在一些严重的重构之后,它可能以各种奇怪的方式中断活出问题。

次佳解决思路

这取决于你想做什么。

案例#1:保存模型以自行使用它进行预测:保存模型,恢复模型,然后将模型更改为评估模式。这样做是因为通常有BatchNormDropout图层,默认情况下在构造时处于训练模式:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

案例#2:保存模型以便以后恢复训练:如果您需要继续训练您将要保存的模型,则需要保存的不仅仅是模型。您还需要保存优化器,迭代轮次(epochs),分数等相关的状态。您可以这样做:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

要恢复训练,您可以执行以下操作:state = torch.load(filepath),然后,恢复每个对象的状态,如下所示:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

由于您正在恢复训练,因此在加载时恢复状态后请勿调用model.eval()

案例#3:其他人无法访问您的代码使用的模型:在Tensorflow中,您可以创建一个.pb文件,该文件定义了模型的体系结构和权重。这非常方便,特别是在使用Tensorflow serve时。在Pytorch中执行此操作的等效方法是:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

这种方式不太健壮,因为pytorch已经经历了很多变化,我们不太推荐这种方式。

参考资料

  • Best way to save a trained model in PyTorch?


鲜花

握手

雷人

路过

鸡蛋
专题导读
上一篇:
TensorFlow的tf.nn.max_pool中'SAME'和'VALID'填充有什么区别?发布时间:2022-05-14
下一篇:
PyTorch入门简介发布时间:2022-05-14
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap