pytorch模型的保存加载与续训练详解
作者:秃头小苏 发布时间:2022-03-27 08:45:08
前面
最近,看到不少小伙伴问pytorch如何保存和加载模型,其实这部分pytorch官网介绍的也是很清楚的,感兴趣的点击了解详情🥁🥁🥁
但是肯定有很多人是不愿意看官网的,所以我还是花一篇文章来为大家介绍介绍。当然了,在介绍中我会加入自己的一些理解,让大家有一个更深的认识。如果准备好了的话,就让我们开始吧。⏳⏳⏳
模型保存与加载
pytorch中介绍了几种不同的模型保存和加载方式,我会在下文一一为大家介绍。首先先让我们来随便定义一个模型,如下:【用的是pytorch官网的例子】
# 模型定义
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
定义好模型结构后,我们可以实例化这个模型:
#模型初始化
model = TheModelClass()
模型初始化过后,我们就一起来看看模型保存和加载的方式吧。🍄🍄🍄
方式1
方式1是官方推荐的一种方式,我们直接来看代码好了,如下:
# 保存模型
torch.save(model.state_dict(), './model/model_state_dict.pth')
该方法后面的参数'./model/model_state_dict.pth'
为模型的保存路径,模型后缀名官方推荐使用.pth
和.pt
,当然了,你取别的后缀名也是完全可行的。☘☘☘
介绍了模型的保存,下面就来看看方式1是如何加载模型的。【这里我说明一点,模型保存往往是在训练中进行的,而模型加载多数用在模型推理中,它们存在两个文件中,故我们在推理过程中要先实列化模型】
# 加载模型
model_test1 = TheModelClass() # 加载模型时应先实例化模型
# load_state_dict()函数接收一个字典,所以不能直接将'./model/model_state_dict.pth'传入,而是先使用load函数将保存的模型参数反序列化
model_test1.load_state_dict(torch.load('./model/model_state_dict.pth'))
model_test1.eval() # 模型推理时设置
在上述的代码注释中我有写到,我们使用load_state_dict()
加载模型时先需要使用load方法将保存的模型参数==反序列化==,load后的结果是一个字典,这时就可以通过load_state_dict()
方法来加载了。
这里我来简单说一下我理解的反序列化,其和序列化是相对应的一个概念。序列化就是把内存中的数据保存到磁盘中,像我们使用torch.save()
方法保存模型就是序列化;而反序列化则是将硬盘中的数据加载到内存当中,显然我们加载模型的过程就是反序列化过程。【大致的意思如下图所示,偶然在水群的时候看到一个画图软件,是不是还挺好看的🍧🍧🍧】
方式2
方式2非常简单,直接上代码:
# 保存模型
torch.save(model, './model/model.pt') #这里我们保存模型的后缀名取.pt
# 加载模型
model_test2 = torch.load('./model/model.pt')
model_test2.eval() # 模型推理时设置
但是这种方式是不推荐使用的,因为你使用这种方式保存模型,然后再加载时会遇到各种各样的错误。为了加深大家理解,我们来看这样的一个例子。文件的结构如下图所示:
models.py
文件中存储的是模型的定义,其位于文件夹models下。save_model.py
文件中写的是保存模型的代码,如下:
from models.models import TheModelClass
from torch import optim
import torch
#模型初始化
model = TheModelClass()
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# ## 保存加载方式2——save/load
# # 保存模型
# torch.save(models, './models/models.pt')
执行此文件后,会生成models.pt
文件,我们在执行load_mode.py
文件即可实现加载,load_mode.py
内容如下:
from models.models import TheModelClass
import torch
## 加载方式2
# 加载模型
model_test2 = TheModelClass()
model_test2 = torch.load('./models/models.pt')
model_test2.eval() # 模型推理时设置
print(model_test2)
此时我们可以正常加载。但如果我们将models文件夹修改为model,如下:
此时我们在使用如下代码加载模型的话就会出现错误:
from models.models import TheModelClass
import torch
## 加载方式2
# 加载模型
model_test2 = TheModelClass()
model_test2 = torch.load('./model/models.pt') #这里需要修改一下文件路径
model_test2.eval() # 模型推理时设置
print(model_test2)
出现这种错误的原因是使用方式2进行模型保存的时候会把模型结构定义文件路径记录下来,加载的时候就会根据路径解析它然后装载参数;当把模型定义文件路径修改以后,使用torch.load(path)就会报错。
其实使用方式2进行模型的保存和加载还会存在各种问题,感兴趣的可以看看这篇博文。总之,在我们今后的使用中,尽量不要用方式2来加载模型。🌱🌱🌱
方式3
pytorch还为我们提供了一种模型保存与加载的方式——checkpoint。这种方式保存的是一个字典,如果我们程序在运行中由于某种原因异常中止,那么这种方式可以很方便的让我们接着上次训练,正因为这样,我非常推荐大家使用这种方式进行模型的保存与加载。下面就让我们一起来看看方式3是如何使用的吧!!!🍥🍥🍥
首先,我们同样使用torch.save
来保存模型,但是这里保存的是一个字典,里面可以填入你需要保存的参数,如下:
# 保存checkpoint
torch.save({
'epoch':epoch,
'model_state_dict':model.state_dict(),
'optimizer_state_dict':optimizer.state_dict(),
'loss':loss
}, './model/model_checkpoint.tar' #这里的后缀名官方推荐使用.tar
)
接着我们来看看如何加载checkpoint,代码如下:
# 加载checkpoint
model_checkpoint = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load('./model/model_checkpoint.tar') # 先反序列化模型
model_checkpoint.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
看了我上文的介绍,大家是否知道如何使用checkpoint
了呢,我想大家都会觉得这个不是很难,但要自己写可能还是不好把握,那么第一次就让我来带领大家看看如何在代码中使用checkpoint
吧!!!🍵🍵🍵
这节我采用cifar10数据集实现物体分类的例子,我的这篇博文对其进行了详细介绍,那么这里介绍checkpoint
我将利用这个demo来为大家讲解。首先我们直接来看模型保存的完整代码,如下:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#1、准备数据集
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)
#2、加载数据集
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)
#3、搭建神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.model1 = nn.Sequential(
nn.Conv2d(3, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, padding=2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1024, 64),
nn.Linear(64, 10)
)
def forward(self, input):
input = self.model1(input)
return input
#4、创建网络模型
net = Net()
#5、设置损失函数、优化器
#损失函数
loss_fun = nn.CrossEntropyLoss() #交叉熵
loss_fun = loss_fun.to(device)
#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(net.parameters(), learning_rate) #SGD:梯度下降算法
#6、设置网络训练中的一些参数
total_train_step = 0 #记录总计训练次数
total_test_step = 0 #记录总计测试次数
Max_epoch = 10 #设计训练轮数
#7、开始进行训练
for epoch in range(Max_epoch):
print("---第{}轮训练开始---".format(epoch))
net.train() #开始训练,不是必须的,在网络中有BN,dropout时需要
#由于训练集数据较多,这里我没用训练集训练,而是采用测试集(test_dataset_loader)当训练集,但思想是一致的
for data in test_dataset_loader:
imgs, targets = data
targets = targets.to(device)
outputs = net(imgs)
#比较输出与真实值,计算Loss
loss = loss_fun(outputs, targets)
#反向传播,调整参数
optimizer.zero_grad() #每次让梯度重置
loss.backward()
optimizer.step()
total_train_step += 1
if total_train_step % 50 == 0:
print("---第{}次训练结束, Loss:{})".format(total_train_step, loss.item()))
if (epoch+1) % 2 == 0:
# 保存checkpoint
torch.save({
'epoch': epoch,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss
}, './model/model_checkpoint_epoch_{}.tar'.format(epoch) # 这里的后缀名官方推荐使用.tar
)
if epoch > 5:
print("---意外中断---")
break
整个流程和这篇文章基本一致,不清楚的建议先花几分钟阅读一下哈。🍍🍍🍍主要区别就是在最后保存模型的时候我使用了checkpoint
进行保存,且两个epoch保存一次。当epoch=6时,我设置了一个break模拟程序意外中断,中断后可以来看一下终端的输出信息,如下图所示:
我们可以看到在进行第6轮循环时,程序中断了,此时最新的保存的模型是第五次训练结果,如下:
同时注意到第5次训练结束的loss在2.0左右,如果我们下次接着训练,损失应该是在2.0附近。🍊🍊🍊
好了,上面由于一些糟糕的原因导致程序中断了,现在我想接着上次训练的结果继续训练,我该怎么办呢?代码如下:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#1、准备数据集
train_dataset = torchvision.datasets.CIFAR10("./data", train=True, transform=torchvision.transforms.ToTensor(), download= True)
test_dataset = torchvision.datasets.CIFAR10("./data", train=False, transform=torchvision.transforms.ToTensor(), download= True)
#2、加载数据集
train_dataset_loader = DataLoader(dataset=train_dataset, batch_size=100)
test_dataset_loader = DataLoader(dataset=test_dataset, batch_size=100)
#3、搭建神经网络
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.model1 = nn.Sequential(
nn.Conv2d(3, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 32, 5, padding=2),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, 5, padding=2),
nn.MaxPool2d(2),
nn.Flatten(),
nn.Linear(1024, 64),
nn.Linear(64, 10)
)
def forward(self, input):
input = self.model1(input)
return input
#4、创建网络模型
net = Net()
#5、设置损失函数、优化器
#损失函数
loss_fun = nn.CrossEntropyLoss() #交叉熵
loss_fun = loss_fun.to(device)
#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(net.parameters(), learning_rate) #SGD:梯度下降算法
#6、设置网络训练中的一些参数
total_train_step = 0 #记录总计训练次数
total_test_step = 0 #记录总计测试次数
Max_epoch = 10 #设计训练轮数
##########################################################################################
# 加载checkpoint
checkpoint = torch.load('./model/model_checkpoint_epoch_5.tar') # 先反序列化模型
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']
##########################################################################################
#7、开始进行训练
for epoch in range(start_epoch+1, Max_epoch):
print("---第{}轮训练开始---".format(epoch))
net.train() #开始训练,不是必须的,在网络中有BN,dropout时需要
for data in test_dataset_loader:
imgs, targets = data
targets = targets.to(device)
outputs = net(imgs)
#比较输出与真实值,计算Loss
loss = loss_fun(outputs, targets)
#反向传播,调整参数
optimizer.zero_grad() #每次让梯度重置
loss.backward()
optimizer.step()
total_train_step += 1
if total_train_step % 50 == 0:
print("---第{}次训练结束, Loss:{})".format(total_train_step, loss.item()))
if (epoch+1) % 2 == 0:
# 保存checkpoint
torch.save({
'epoch': epoch,
'model_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss
}, './model/model_checkpoint_epoch_{}.tar'.format(epoch) # 这里的后缀名官方推荐使用.tar
)
这里的代码相较之前的多了一个加载checkpoint
的过程,我将其截取出来,如下图所示:
通过加载checkpoint
我们就保存了之前训练的参数,进而实现断点续训练,我们直接来看执行此代码的结果,如下图所示:
从上图可以看出我们的训练是从第6轮开始的,并且初始的loss为1.99,和2.0接近。这就说明了我们已经实现了中断后恢复训练的操作。
🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸
这里我简单的说两句,上文介绍checkpoint
的用法时,训练中断和训练恢复我是放在两个文件中的进行的,但是在实际中我们肯定是在一个文件中运行,那这该怎么办呢?其实方法很简单啦,我们只需要设置一个if条件将加载checkpoint
的部分放在训练文件中,然后设置一个参数来控制if条件的执行即可。具体细节我就不给大家介绍了,如果有不明白的评论区见吧!!!🌿🌿🌿🌿
🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸🌸
来源:https://juejin.cn/post/7164004790416965640


猜你喜欢
- 之前用Crystal做了一个数字转English Word的Formula刚刚心血来潮, 大半个晚上写了JS版本的数字转换, 由于JS的Bu
- 1.垂直影像拼接 vconcat# -*- coding: utf-8 -*-import cv2image = cv2.imread(&q
- 有这样的情形,django个人头像在model中是:class UserProfile(AbstractUser): ""
- window.onresize = baiduResizeDiv; window.onerror = function(){} var di
- 本文实例为大家分享了python实现教务管理系统,供大家参考,具体内容如下mysql+python构成教务管理系统,提供系统管理员,教职工,
- 因为是看书自学的python,开始后不久就遇到了这个引入的模块函数,且一直在IDLE上编辑了后运行,试图从结果发现它的用途,然而结果一直都是
- 每个矿工将从先前创建的交易池中获取交易.要跟踪已挖掘的消息数量,我们必须创建一个全局变量 :last_transaction_index =
- 遇到一个难题,在无物理键盘情况下,通过页面软键盘在页面文本框输入汉字,不知道51js的各位大牛有没有遇到过这种需求,如果遇到过是如何解决的,
- 孤立帐户,就是某个数据库的帐户只有用户名而没有登录名,这样的用户在用户库的sysusers系统表中存在,而在master数据库的syslog
- 前言最近公司项目中在使用 Echarts 绘制图表时,由于默认的 label 标签不能满足设计稿需求,所以研究了对 label标签进行格式化
- 1 椭圆肤色检测模型原理:将RGB图像转换到YCRCB空间,肤色像素点会聚集到一个椭圆区域。先定义一个椭圆模型,然后将每个RGB像素点转换到
- Python中的最大整数Python中可以通过sys模块来得到int的最大值. python2中使用的方法是import sysmax =
- 在pandas里面常用value_counts确认数据出现的频率。1. Series 情况下:pandas 的 value_counts()
- 原理请查看前面几篇文章。1、数据源SH600519.csv 是用 tushare 模块下载的 SH600519 贵州茅台的日 k 线数据,本
- 一、平稳序列建模步骤假如某个观察值序列通过序列预处理可以判定为平稳非白噪声序列,就可以利用ARMA模型对该序列进行建模。建模的基本步骤如下:
- 起步Django 是个同步框架,本文并不是 让 Django 变成异步框架。而是对于在一个 view 中需要请求多次 http api 的场
- 1.zip用法简介在python 3.x系列中,zip方法返回的为一个zip object可迭代对象。class zip(object):&
- 什么是TensorboardXTensorboard 是 TensorFlow 的一个附加工具,可以记录训练过程的数字、图像等内容,以方便研
- 我想要向您介绍能想像到的开始 GUI 编程的最简单方法,就是使用 Scriptics 的 TK 和 Tkinter 封装器。我们将与 dev
- 如下所示:# -*- coding=utf-8 -*- import urllib2import socketimport timeurls