Pytorch模型的保存/复用/迁移实现代码
作者:信海 发布时间:2023-12-12 10:37:43
标签:Pytorch,模型,保存,迁移
本文整理了Pytorch框架下模型的保存、复用、推理、再训练和迁移等实现。
模型的保存与复用
模型定义和参数打印
# 定义模型结构
class LenNet(nn.Module):
def __init__(self):
super(LenNet, self).__init__()
self.conv = nn.Sequential( # [batch, 1, 28, 28]
nn.Conv2d(1, 8, 5, 2), # [batch, 1, 28, 28]
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # [batch, 8, 14, 14]
nn.Conv2d(8, 16, 5), # [batch, 16, 10, 10]
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # [batch, 16, 5, 5]
)
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(16*5*5, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 64),
nn.ReLU(inplace=True),
nn.Linear(64, 10)
)
def forward(self, X):
return self.fc(self.conv(X))
# 查看模型参数
# 网络模型中的参数model.state_dict()是以字典形式保存(实质上是collections模块中的OrderedDict)
model = LenNet()
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
# 参数名中的fc和conv前缀是根据定义nn.Sequential()时的名字所确定。
# 参数名中的数字表示每个Sequential()中网络层所在的位置。
print(model.state_dict().keys()) # 打印键
print(model.state_dict().values()) # 打印值
# 优化器optimizer的参数打印类似
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
模型保存
import os
# 指定保存的模型名称时Pytorch官方建议的后缀为.pt或者.pth
model_save_dir = './model_logs/'
model_save_path = os.path.join(model_save_dir, 'LeNet.pt')
torch.save(model.state_dict(), model_save_path)
# 在训练过程中保存某个条件下的最优模型,可以如下操作
best_model_state = deepcopy(model.state_dict())
torch.save(best_model_state, model_save_path)
# 下面这种方法是错误的,因为best_model_state只是model.state_dict()的引用,会随着训练的改变而改变
best_model_state = model.state_dict()
torch.save(best_model_state, model_save_path)
模型推理
def inference(data_iter, device, model_save_dir):
model = LeNet() # 初始化现有模型的权重参数
model.to(device)
model_save_path = os.path.join(model_save_dir, 'LeNet.pt')
# 如果本地存在模型,则加载本地模型参数覆盖原有模型
if os.path.exists(model_save_path):
loaded_paras = torch.load(model_save_path)
model.load_state_dict(loaded_paras)
model.eval()
with torch.no_grad(): # 开始推理
acc_sum, n = 0., 0
for x, y in data_iter:
x, y = x.to(device), y.to(device)
logits = model(x)
acc_sum += (logits.argmax(1) == y).float().sum().item()
n += len(y)
print("Accuracy in test data is : ", acc_sum / n)
模型再训练
class MyModel:
def __init__(self,
batch_size=64,
epochs=5,
learning_rate=0.001,
model_save_dir='./MODEL'):
self.batch_size = batch_size
self.epochs = epochs
self.learning_rate = learning_rate
self.model_save_dir = model_save_dir
self.model = LeNet()
def train(self):
train_iter, test_iter = load_dataset(self.batch_size)
# 在训练过程中只保存网络权重,在再训练时只载入网络权重参数初始化网络训练。这里是核心部分,开始。
if not os.path.exists(self.model_save_dir):
os.makedirs(self.model_save_dir)
model_save_path = os.path.join(self.model_save_dir, 'model.pt')
if os.path.exists(model_save_path):
loaded_paras = torch.load(model_save_path)
self.model.load_state_dict(loaded_paras)
print("#### 成功载入已有模型,进行再训练...")
# 结束
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(device)
for epoch in range(self.epochs):
for i, (x, y) in enumerate(train_iter):
x, y = x.to(device), y.to(device)
loss, logits = self.model(x)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % 100 == 0:
acc = (logits.argmax(1) == y).float().mean()
print("Epochs[{}/{}]---batch[{}/{}]---acc {:.4}---loss {:.4}".format(
epoch, self.epochs, len(train_iter), i, acc, loss.item()))
print("Epochs[{}/{}]--acc on test {:.4}".format(epoch, self.epochs,
self.evaluate(test_iter, self.model, device)))
torch.save(self.model.state_dict(), model_save_path)
@staticmethod
def evaluate(data_iter, model, device):
with torch.no_grad():
acc_sum, n = 0.0, 0
for x, y in data_iter:
x, y = x.to(device), y.to(device)
logits = model(x)
acc_sum += (logits.argmax(1) == y).float().sum().item()
n += len(y)
return acc_sum / n
# 在保存参数的时候,将优化器参数、损失值等可一同保存,然后在恢复模型时连同其它参数一起恢复
model_save_path = os.path.join(model_save_dir, 'LeNet.pt')
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, model_save_path)
# 加载方式如下
checkpoint = torch.load(model_save_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
模型迁移
# 定义新模型NewLeNet 和LeNet区别在于新增了一个全连接层
class NewLenNet(nn.Module):
def __init__(self):
super(NewLenNet, self).__init__()
self.conv = nn.Sequential( # [batch, 1, 28, 28]
nn.Conv2d(1, 8, 5, 2), # [batch, 1, 28, 28]
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # [batch, 8, 14, 14]
nn.Conv2d(8, 16, 5), # [batch, 16, 10, 10]
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # [batch, 16, 5, 5]
)
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(16*5*5, 128),
nn.ReLU(inplace=True),
nn.Linear(128, 64), # 这层以前和LeNet结构一致 可以用LeNet的参数来进行替换
nn.ReLU(inplace=True),
nn.Linear(64, 32),
nn.ReLU(inplace=True),
nn.Linear(32, 10)
)
def forward(self, X):
return self.fc(self.conv(X))
# 定义替换函数 匹配两个网络 size相同处地方进行参数替换
def para_state_dict(model, model_save_dir):
state_dict = deepcopy(model.state_dict())
model_save_path = os.path.join(model_save_dir, 'model.pt')
if os.path.exists(model_save_path):
loaded_paras = torch.load(model_save_path)
for key in state_dict: # 在新的网络模型中遍历对应参数
if key in loaded_paras and state_dict[key].size() == loaded_paras[key].size():
print("成功初始化参数:", key)
state_dict[key] = loaded_paras[key]
return state_dict
# 更新一下模型迁移后的训练代码
def train(self):
train_iter, test_iter = load_dataset(self.batch_size)
if not os.path.exists(self.model_save_dir):
os.makedirs(self.model_save_dir)
model_save_path = os.path.join(self.model_save_dir, 'model_new.pt')
old_model = os.path.join(self.model_save_dir, 'LeNet.pt')
if os.path.exists(old_model):
state_dict = para_state_dict(self.model, self.model_save_dir) # 调用迁移代码 将LeNet的前几层参数迁移到NewLeNet
self.model.load_state_dict(state_dict)
print("#### 成功载入已有模型,进行再训练...")
optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model.to(device)
for epoch in range(self.epochs):
for i, (x, y) in enumerate(train_iter):
x, y = x.to(device), y.to(device)
loss, logits = self.model(x)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % 100 == 0:
acc = (logits.argmax(1) == y).float().mean()
print("Epochs[{}/{}]---batch[{}/{}]---acc {:.4}---loss {:.4}".format(
epoch, self.epochs, len(train_iter), i, acc, loss.item()))
print("Epochs[{}/{}]--acc on test {:.4}".format(epoch, self.epochs,
self.evaluate(test_iter, self.model, device)))
torch.save(self.model.state_dict(), model_save_path)
# 这里更新未进行训练的推理
def inference(data_iter, device, model_save_dir='./MODEL'):
model = NewLeNet() # 初始化现有模型的权重参数
print("初始化参数 conv.0.bias 为:", model.state_dict()['conv.0.bias'])
model.to(device)
state_dict = para_state_dict(model, model_save_dir) # 迁移模型参数
model.load_state_dict(state_dict)
model.eval()
print("载入本地模型重新初始化 conv.0.bias 为:", model.state_dict()['conv.0.bias'])
with torch.no_grad():
acc_sum, n = 0.0, 0
for x, y in data_iter:
x, y = x.to(device), y.to(device)
logits = model(x)
acc_sum += (logits.argmax(1) == y).float().sum().item()
n += len(y)
print("Accuracy in test data is :", acc_sum / n)
参考文献
[1] https://github.com/moon-hotel/DeepLearningWithMe
来源:https://www.cnblogs.com/yuxuliang/p/Mytorch_01.html


猜你喜欢
- Github 项目主页 工具源码分析结果:total : 15981 1568.0 == Backspace 1103.0 == Tab 1
- if (context.Request.UserAgent.ToLower().IndexOf(&qu
- 前言不知道什么是版本库的,扇自己两个大嘴巴;知道但不用的,扇自己四个大嘴巴。快扇去。你真扇了?那你是个大傻瓜。扇什么扇,有扇自己的功夫,还不
- 本文实例为大家分享了js实现QQ邮箱邮件拖拽删除的具体代码,供大家参考,具体内容如下步骤分析:根据数据结构生成HTML结构全选和单选功能的实
- 前言Go 1.3 的sync包中加入一个新特性:Pool。这个类设计的目的是用来保存和复用临时对象,以减少内存分配,降低CG压力。type
- 页级的典型代表引擎为BDB。 表级的典型代表引擎为MyISAM,MEMORY以及很久以前的ISAM。 行级的典型代表引擎为INNODB。 -
- 设计原理从结构上来说,一个简单的图形界面,需要由界面组件、组件的事件 * (响应各类事件的逻辑)和具体的事件处理逻辑组成。界面实现的主要工作
- 一、pyqt5中动画的继承关系图二、关于QAbstractAnimation父类的认识1、主要作用继承此类, 实现一些自定义动画所有动画共享
- python实现四舍五入""" 四舍五入 :param
- 1、半开放socket利用shutdown()函数使socket双向数据传输变为单向数据传输。shutdown()需要一个单独的参数,该参数
- 1、打开本地企业管理器,先创建一个SQL Server注册来远程连接服务器端口SQL Server。步骤如下图:图1:2、弹出窗口后输入内容
- 本文实例为大家分享了JS编写简单选项卡的具体代码,供大家参考,具体内容如下<!DOCTYPE html><html lan
- 把value插入dataframe的指定位置loc中,若插入的数据value已在DataFrame中,则返回 错误ValueError,如想
- Golang中Array是值类型而slice是引用类型。因此两者之间的赋值或拷贝有些差异,本文带你了解各自的差异。1. 拷贝array前面提
- 在这个教材中,我们假定你已经安装了Scrapy。假如你没有安装,你可以参考这个安装指南。我们将会用开放目录项目(dmoz)作为我
- 这篇文章主要介绍了使用python远程操作linux过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需
- 一、前言最近忙着在服务器上跑代码学习积累了一些经验技巧这里用来记录分享给大家二、创建虚拟环境用来跑代码下面我会以一个实例为模板,学习完之后,
- MySQL BETWEEN 用法MySQL BETWEEN 语法BETWEEN 运算符用于 WHERE 表达式中,选取介于两个值之间的数据范
- 1.delete不能使自动编号返回为起始值。但是truncate能使自动增长的列的值返回为默认的种子 2.truncate只能一次清空,不能
- 运行代码时,出现诸如这样的文件的权限有可能出问题,不过更多是路径本身有问题。比如,你的文件名是否正确,路径是否正确,路径后面是不是多了什么奇