1 小时入门 PyTorch:从张量到多 GPU 神经网络训练
Sebastian Raschka(极客时间编译)
ML/AI Research Engineer
32 人已学习
免费领取
1 小时入门 PyTorch:从张量到多 GPU 神经网络训练
15
15
1.0x
00:00/00:00
登录|注册

08|保存和加载模型

在前一节中,我们成功训练了一个模型。现在让我们看看如何保存训练好的模型以便以后重用。
以下是我们在 PyTorch 中保存和加载模型的推荐方法:
torch.save(model.state_dict(), "model.pth")
模型的 state_dict 是一个 Python 字典对象,它将模型中的每个层映射到其可训练参数(权重和偏置)。请注意,"model.pth" 是保存到磁盘的模型文件的任意文件名。我们可以为其指定任何喜欢的名称和文件后缀,不过 .pth.pt 是最常见的命名惯例。
一旦我们保存了模型,就可以按照以下方式从磁盘恢复它:
model = NeuralNetwork(2, 2) # needs to match the original model exactly
model.load_state_dict(torch.load("model.pth", weights_only=True))
<All keys matched successfully>
torch.load("model.pth")  函数会读取文件  "model.pth",并重建包含模型参数的 Python 字典对象,而  model.load_state_dict()  则将这些参数应用到模型上,从而有效地恢复我们保存时的训练状态。
确认放弃笔记?
放弃后所记笔记将不保留。
新功能上线,你的历史笔记已初始化为私密笔记,是否一键批量公开?
批量公开的笔记不会为你同步至部落
公开
同步至部落
取消
完成
0/2000
荧光笔
直线
曲线
笔记
复制
AI
  • 深入了解
  • 翻译
    • 英语
    • 中文简体
    • 法语
    • 德语
    • 日语
    • 韩语
    • 俄语
    • 西班牙语
  • 解释
  • 总结

1. 在PyTorch中保存和加载模型的推荐方法是使用`torch.save(model.state_dict(), "model.pth")`来保存模型的状态字典。 2. 模型的`state_dict`是一个Python字典对象,将模型中的每个层映射到其可训练参数(权重和偏置)。 3. 通过`model.load_state_dict(torch.load("model.pth", weights_only=True))`可以从磁盘恢复模型,需要确保新建的网络结构与原先保存的模型架构完全一致。 4. 在保存模型的同一个会话中运行加载模型的代码时,需要先在内存中创建一个模型实例,才能加载保存的参数。

该试读文章来自《1 小时入门 PyTorch:从张量到多 GPU 神经网络训练》,如需阅读全部文章,
请先领取课程
免费领取
登录 后留言

精选留言

由作者筛选后的优质留言将会公开显示,欢迎踊跃留言。
收起评论
显示
设置
留言
收藏
沉浸
阅读
分享
手机端
快捷键
回顶部