pytorch模型保存与加载中的一些问题实战记录
作者:colourmind 发布时间:2021-09-03 21:41:50
前言
最近使用pytorch训练模型,保存模型后再次加载使用出现了一些问题。记录一下解决方案!
一、torch中模型保存和加载的方式
1、模型参数和模型结构保存和加载
torch.save(model,path)
torch.load(path)
2、只保存模型的参数和加载——这种方式比较安全,但是比较稍微麻烦一点点
torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)
二、torch中模型保存和加载出现的问题
1、单卡模型下保存模型结构和参数后加载出现的问题
模型保存的时候会把模型结构定义文件路径记录下来,加载的时候就会根据路径解析它然后装载参数;当把模型定义文件路径修改以后,使用torch.load(path)就会报错。
把model文件夹修改为models后,再加载就会报错。
import torch
from model.TextRNN import TextRNN
load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)
这种保存完整模型结构和参数的方式,一定不要改动模型定义文件路径。
2、多卡机器单卡训练模型保存后在单卡机器上加载会报错
在多卡机器上有多张显卡0号开始,现在模型在n>=1上的显卡训练保存后,拷贝在单卡机器上加载
import torch
from model.TextRNN import TextRNN
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)
会出现cuda device不匹配的问题——你保存的模代码段 小部件型是使用的cuda1,那么采用torch.load()打开的时候,会默认的去寻找cuda1,然后把模型加载到该设备上。这个时候可以直接使用map_location来解决,把模型加载到CPU上即可。
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))
3、多卡训练模型保存模型结构和参数后加载出现的问题
当用多GPU同时训练模型之后,不管是采用模型结构和参数一起保存还是单独保存模型参数,然后在单卡下加载都会出现问题
a、模型结构和参数一起保然后在加载
torch.distributed.init_process_group(backend='nccl')
模型训练的时候采用上述多进程的方式,所以你在加载的时候也要声明,不然就会报错。
b、单独保存模型参数
model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)
同样会出现问题,不过这里出现的问题是参数字典的key和模型定义的key不一样
原因是多GPU训练下,使用分布式训练的时候会给模型进行一个包装,代码如下:
model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin')
print(model)
model.cuda(args.local_rank)
。。。。。。
model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True)
print('model',model)
包装前的模型结构:
包装后的模型
在外层多了DistributedDataParallel以及module,所以才会导致在单卡环境下加载模型权重的时候出现权重的keys不一致。
三、正确的保存模型和加载的方法
if gpu_count > 1:
torch.save(model.module.state_dict(),save_path)
else:
torch.save(model.state_dict(),save_path)
model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load(save_path)
model.load_state_dict(state_dict)
这样就是比较好的范式,加载不会出错。
来源:https://blog.csdn.net/HUSTHY/article/details/115199280


猜你喜欢
- 我们也可以来做一个,但这个“定时器”的工作时间范围应控制在1个小时至100 毫秒之间: <%sub StartTi
- 1.intersect为取多个查询结果的交集;2.查询两个基本时间段内表记录的SQL语句;select * from shengjibiao
- 详细代码见仓库github地址:github.com/nerkeler/account重要提示程序默认密码:password密钥位置:./r
- 一、卷积神经网络卷积神经网络(ConvolutionalNeuralNetwork,CNN)最初是为解决图像识别等问题设计的,CNN现在的应
- 经常会有小朋友问我,“我想做个黑客,我该学什么编程语言?”,或者有的小朋友会说:“我要学c,我要做病毒”。其实对于这些小朋友而言他们基本都没
- 常有新手问我该怎么备份数据库,下面介绍3种备份数据库的方法:(1)备份数据库文件MySQL中的每一个数据库和数据表分别对应文件系统中的目录和
- 大家都知道,IE中的现代事件绑定(attachEvent)与W3C标准的(addEventListener)相比存在很多问题,例如:内存泄漏
- 前言读取站点资料数据对站点数据进行插值,插值到规则网格上绘制EOF第一模态和第二模态的空间分布图绘制PC序列关于插值,这里主要提供了两个插值
- 引言通过前面的文章我们已经了解到OpenCV 是一个用于计算机视觉和机器学习的开源 python 库。它主要针对实时计算机视觉和图像处理。它
- 因为主键可以唯一标识某一行记录,所以可以确保执行数据更新、删除的时候不会出现张冠李戴的错误。当然,其它字段可以辅助我们在执行这些操作时消除共
- Tensorflow数据读取有三种方式:Preloaded data: 预加载数据Feeding: Python产生数据,再把数据喂给后端。
- 最近看到好多人说到tns或者数据库不能登录等问题,就索性总结了下面的文档。首先来说Oracle的网络结构,往复杂处说能加上加密、LDAP等等
- 前言总结一下最近看的关于opencv图像几何变换的一些笔记. 这是原图: 1.平移import cv2import numpy as npi
- 在ASP中,你可通过VBScript和其他方式调用自程序。实例:调用使用VBScript的子程序如何从ASP调用以VBScript编写的子程
- 把昨天做的高级查询界面完善了一下,支持动态添加多个查询条件、定义逻辑关系,支持整形、浮点、字符串、日期、布尔值、自定义选择列表的录入,通过E
- 为cd2sc.com网站功能而开发,代码为本人原创,生成速度一般。 (出于众所周知的原因,涉及到数据库的数据字段名称做了改动,并且为了代码明
- 协程的特点1.该任务的业务代码主动要求切换,即主动让出执行权限2.发生了IO,导致执行阻塞(使用channel让协程阻塞)与线程本质的不同C
- 引言事情是这样的,最近在做开源软件供应链安全相关的项目,之前没了解这方面知识的时候感觉服务器被黑,数据库被删,网站被攻,这些东西都离我们太遥
- SQL Server 数据库定时自动备份,供大家参考,具体内容如下在SQL Server中出于数据安全的考虑,所以需要定期的备份数据库。而备
- 一、SQL注入简介SQL注入是比较常见的网络攻击方式之一,它不是利用操作系统的BUG来实现攻击,而是针对程序员编写时的疏忽,通过SQL语句,