怎样保存模型权重和checkpoint
作者:取个名字真难呐 发布时间:2023-04-12 00:45:00
概述
在pytorch中有两种方式可以保存推理模型,第一种是只保存模型的参数,比如parameters和buffers;另外一种是保存整个模型;
1.保存模型 - 权重参数
我们可以用torch.save()函数来保存model.state_dict();state_dict()里面包含模型的parameters&buffers;这种方法只保存模型中必要的训练参数。
你可以用pytorch中的pickle来保存模型;使用这种方法可以生成最直观的语法,并涉及最少的代码;这种方法的缺点是,序列化的数据被绑定到特定的类和保存模型时使用的确切的目录结构。
这样做的原因是pickle并不保存模型类本身。相反,它保存包含类的文件的路径,在加载期间使用;因此,当在其他项目中使用或重构后,您的代码可能以各种方式中断。
我们将探讨如何保存和加载模型进行推断的两种方法。
步骤:
(1)导入所有必要的库来加载我们的数据
(2)定义和初始化神经网络
(3)初始化优化器
(4)保存并通过state_dict加载模型
(5)保存并加载整个模型
1.1代码
# -*- coding: utf-8 -*-
# @Project: zc
# @Author: zc
# @File name: Neural_Network_test
# @Create time: 2022/3/19 15:33
# 1.导入相关数据库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
# 2.定义神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 3. 实例化神经网络
net = Net()
# 4. 实例化优化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 5. 保存模型参数
# Specify a path
PATH = "state_dict_model.pt"
# 6. 保存模型的参数字典:parameters and buffers
torch.save(net.state_dict(), PATH)
# 7. 实例化新的模型
model = Net()
# 8. 给新的实例加载之前的模型参数
model.load_state_dict(torch.load(PATH))
# 9. 设置模型为评估模式
model.eval()
注意(1):
pytorch中常用的惯例是将model.state_dict()保存为"state_dict_model.pt",即文件的格式一般是.pt或者.pth格式文件;注意load_state_dict加载的是一个字典,而不是路径。
注意(2):
模型参数在推理阶段一定要设置model.eval();这样可以让dropout和batchnorm失效,如果没设置推理模式,会得到不一样的结果。
2.保存模型 - 整个模型
将模型所有的内容都保存下来。
# Specify a path
PATH = "entire_model.pt"
# Save
torch.save(net, PATH)
# Load
model = torch.load(PATH)
model.eval()
3.保存模型 - checkpoints
我们按照checkpoints模式来保存模型,本质上就是按照字典的模式进行分门别类的保存,我们可以通过键值进行加载。
epoch
:训练周期model_state_dict
:模型可训练参数optimizer_state_dict
:模型优化器参数loss
:模型的损失函数
# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4
torch.save({
'epoch': EPOCH,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS,
}, PATH)
保存和加载通用的检查点模型以进行推断或恢复训练,这有助于您从上一个地方继续进行。
当保存一个常规检查点时,您必须保存模型的state_dict之外的更多信息。
保存优化器的state_dict也很重要,因为它包含缓冲区和参数,随着模型的运行而更新。
您可能希望保存的其他项目是您离开的时期,最新记录的训练损失,外部torch.nn.嵌入层,以及更多,基于自己的算法
3.1代码
# 1.导入相关数据库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
# 2. 定义神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 3. 实例化神经网络
net = Net()
# 4. 实例化优化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# Additional information
# 5. 定义超参数
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4
# 6. 以checkpoints形式保存模型的相关数据
torch.save({
'epoch': EPOCH,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': LOSS,
}, PATH)
# 7. 重新实例化一个模型
model = Net()
# 8. 实例化优化器
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# 9. 加载以前的checkpoint
checkpoint = torch.load(PATH)
# 10. 通过键值来加载相关参数
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# 11.设置推理模式
model.eval()
# - or -
model.train()
4.保存双模型
当保存有多个神经网络模型组成的神经网络时,比如GAN对抗模型,sequence-to-sequence序列到序列模型,或者一个组合模型,你必须为每一个模型保存状态字典state_dict()和其对应的优化器参数optimizer.state_dict();您还可以保存任何其他项目,可能会帮助您恢复训练,只需将它们添加到字典;为了加载模型,第一步是初始化神经网络模型和优化器,然后用torch.load()去加载checkpoint对应的数据,因为checkpoints是字典,所以我们可以通过键值进行查询导入;
4.1相关步骤
(1)导入所有相关的数据库
(2)定义和实例化神经网络模型
(3)初始化优化器
(4)保存多重模型
(5)加载多重模型
# 1.导入相关数据库
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F
# 2. 定义神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 3. 实例化神经网络A,B
netA = Net()
netB = Net()
# 4. 实例化优化器A,B
optimizerA = optim.SGD(netA.parameters(), lr=0.001, momentum=0.9)
optimizerB = optim.SGD(netB.parameters(), lr=0.001, momentum=0.9)
# 5. 保存模型
# Specify a path to save to
PATH = "model.pt"
torch.save({
'modelA_state_dict': netA.state_dict(),
'modelB_state_dict': netB.state_dict(),
'optimizerA_state_dict': optimizerA.state_dict(),
'optimizerB_state_dict': optimizerB.state_dict(),
}, PATH)
# 6.重新实例化新的网络模型A,B
modelA = Net()
modelB = Net()
# 7. 重新实例化新的网络模型A,B
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)
# 8. 将以前模型的参数重新加载到新的模型A,B中
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
# 9. 开启预测模式
modelA.eval()
modelB.eval()
# - or -
# 10.开启训练模式
modelA.train()
modelB.train()
5.机器学习流程图
6.机器学习常用库
来源:https://blog.csdn.net/scar2016/article/details/123618089


猜你喜欢
- 对于个人站长来说,如何能使自己的网站与众不同、充满个性,一直是不懈努力的目标。除了尽量提高页面的视觉效果、互动功能以外,如果能在打开网页的同
- 代码如下:--程序员们在编写一个雇员报表,他们需要得到每个雇员当前及历史工资状态的信息, --以便生成报表。报表需要显示每个人的晋升日期和工
- 矩阵增加行np.row_stack() 与 np.column_stack()import numpy as npa = np.array(
- 背景客户最近有这样的需求,想通过统计Oracle数据库活跃会话数,并记录在案,利用比对历史的活跃会话的方式,实现对系统整体用户并发量有大概的
- 阅读上一篇:成为一个顶级设计师的第一准则限制你的色彩成为一个顶级设计师的7个简单原则的第二部分限制使用你的色彩。好象上个准则是让你限制用你的
- 正在看的ORACLE教程是:Oracle与SQL Server在企业应用的比较。在我供职的公司不仅仅拥有Oracle数据库,同时还拥有SQL
- 添加设置页面-add_menu_page函数add_menu_page(),这个函数是往后台添加顶级菜单先,也就是和“外观”、“插件”等一样
- 一场大雪,覆盖了华北、华东。天地连成一片,城市银装素裹,处处诗情画意、人人兴高采烈。朋友圈被雪景图和调侃路滑摔跤的段子刷屏,气氛比过年还要热
- 前言过去公司都是用的5.7 系列的MySQL,随着8.0的发版,也想试着升级一下。遇到了两个小错误,记录在此。在开始之前,如果对MySQL8
- 前言本文主要给大家总结介绍了关于Python的一些基础技巧,分享出来供大家参考学习,下面话不多说了,来一起看看详细的介绍吧。1.starts
- Go语言也称 Golang,兼具效率、性能、安全、健壮等特性。Go语言从底层原生支持并发,无须第三方库、开发者的编程技巧和开发经验就可以轻松
- 来需求了。。干活啦。。需求内容部分时候由于缓存刷新、验证码显示不出来或者浏览器打不开或者打开速度很慢等原因,导致部分测试同事不想使用浏览器登
- 一、基本使用最近研究了一下 el-upload组件 踩了一些小坑 写起来大家学习一下很经常的一件事情 经常会去直接拷贝 elem
- PyQt5不规则窗口实现动画效果实例import sysfrom PyQt5.QtCore import *from PyQt5.QtGui
- 如何用ASP来识别操作系统是vista的?我在网上找了个函数,但是不能判断是vista系统,希望大家帮忙. 这个是我在网上找的函数: Fun
- 一、序言本文承接[Mybatis缓存体系探究],提供基于MybatisPlus技术可用于生产环境下的二级缓存解决方案。1、前置条件掌握MyB
- 数据库:30万条,有ID列但无主键,在要搜索的“分类”字段上建有非聚集索引过程T-SQL: /* 用户自定义函数:执行时间在115
- 1.事件简介事件(event)是MySQL在相应的时刻调用的过程式数据库对象。一个事件可调用一次,也可周期性的启动,它由一个特定的线程来管理
- MHA介绍MHA是一位日本MySQL大牛用Perl写的一套MySQL故障切换方案,来保证数据库系统的高可用.在宕机的时间内(通常10—30秒
- 前言众所周知字典(dict)对象是 Python 最常用的数据结构,社区曾有人开玩笑地说:"Python企图用字典装载整个世界&q