PyTorch加载模型model.load_state_dict()问题及解决
作者:是否龙磊磊真的一无所有 发布时间:2022-11-08 07:03:53
PyTorch加载模型model.load_state_dict()问题
希望将训练好的模型加载到新的网络上。
如上面题目所描述的,PyTorch在加载之前保存的模型参数的时候,遇到了问题。
Unexpected key(s) in state_dict: "module.features. ...".,Expected ".features....". 直接原因是key值名字不对应。
表明了加载过程中,期望获得的key值为feature...,而不是module.features....。
这是由模型保存过程中导致的,模型应该是在DataParallel模式下面,也就是采用了多GPU训练模型,然后直接保存的。
You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without . You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back.
解决上面的问题有三个办法:
1. 对load的模型创建新的字典
去掉不需要的key值"module".
# original saved file with DataParallel
state_dict = torch.load('checkpoint.pt') # 模型可以保存为pth文件,也可以为pt文件。
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`,表面从第7个key值字符取到最后一个字符,正好去掉了module.
new_state_dict[name] = v #新字典的key值对应的value为一一对应的值。
# load params
model.load_state_dict(new_state_dict) # 从新加载这个模型。
2. 直接用空白''代替'module.'
model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('checkpoint.pt').items()})
# 相当于用''代替'module.'。
#直接使得需要的键名等于期望的键名。
3. 最简单的方法
加载模型之后,接着将模型DataParallel,此时就可以load_state_dict。
如果有多个GPU,将模型并行化,用DataParallel来操作。
这个过程会将key值加一个"module. ***"。
model = VGGNet()
params=model.state_dict() #获得模型的原始状态以及参数。
for k,v in params.items():
print(k) #只打印key值,不打印具体参数。
4. 总结
从出错显示的问题就可以看出,key值不匹配,因此可以选择多种方法,将模型参数加载进去。
这个方法通常会在load_state_dict过程中遇到。将训练好的一个网络参数,移植到另外一个网络上面,继续训练。
或者将训练好的网络checkpoint加载进模型,再次进行训练。可以打印出model state_dict来看出两者的差别。
model = VGGNet()
params=model.state_dict() #获得模型的原始状态以及参数。
for k,v in params.items():
print(k) #只打印key值,不打印具体参数。
features.0.0.weight
features.0.1.weight
features.1.conv.3.weight
features.1.conv.4.num_batches_tracked
model = VGGNet()
checkpoint = torch.load('checkpoint.pt', map_location='cpu')
# Load weights to resume from checkpoint。
# print('**************************************')
# 这个方法能够直接打印出你保存的checkpoint的键和值。
for k,v in checkpoint.items():
print(k)
print("*****************************************")
输出结果为:
module.features.0.0.weight",
"module.features.0.1.weight",
"module.features.0.1.bias
可以看出不匹配,模型的参数中,key值不同,多了module。
PS: 追加
在移植参数的过程中,对于出现 .total_ops和.total_params结尾的参数,可参考以下代码:
from collections import OrderedDict
checkpoint = torch.load(
pretrained_model_file_path,
map_location=(None if use_cuda and not remap_to_cpu else "cpu"))
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
if not k.endswith('total_ops') and not k.endswith('total_params'):
name = k[7:]
new_state_dict[name] = v
最后
来源:https://blog.csdn.net/qq_32998593/article/details/89343507


猜你喜欢
- 今天给大家介绍一个电商中常见的场景 —— MySQL 数据同步 Elasticsearch。商品检索
- 学习Python过程中,发现没有switch-case,过去写C习惯用Switch/Case语句,官方文档说通过if-elif实现。所以不妨
- 实验1.1 列表a = [1, 2, 3, 4]for i in a: print(i)  
- 摘要在Nginx和uWSGI还没配置时,单独在url.py使用apscheduler设置定时任务,使用python manage.py ru
- 本文实例讲述了Python实现获取本地及远程图片大小的方法。分享给大家供大家参考,具体如下:了解过Pillow的都知道,Pillow是一个非
- 前言现在在疫情阶段,想找一份不错的工作变得更为困难,很多人会选择去网上看招聘信息。可是招聘信息有一些是错综复杂的。而且不能把全部的信息全部罗
- 1.基本构架:mport PIL.Image 相关模块img=Image.open(img_name) 打开图片img.save(save_
- javascript作为一个动态语言,动态解析脚本的方法非常多,如万恶又万能的eval,低调的Function,IE独占的execScrip
- 首先看一下分页的基本原理:mysql> explain SELECT * FROM message ORDER BY id DESC
- 1、MySQL8.0.16解压其中dada文件夹和my.ini配置文件是解压后手动加入的,如下图所示2、新建配置文件my.ini放在D:\F
- 一、百度百科1、MySQLMySQL声称自己是最流行的开源数据库。LAMP中的M指的就是MySQL。构建在LAMP上的应用都会使用MySQL
- 目录1.自动移动鼠标,以便Skype / Lynk显示你在工作中处于活动状态2.使用Selenium自动化网站登录过程3.自动文件备份4.自
- 以下内容都是针对Pytorch 1.0-1.1介绍。很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解
- 列表(list)和元组(tuple)的一些基础list和tuple都是一个可以放置任意数据类型的有序集合,都是既可以存放数字、字符串、对象等
- 本文实例讲述了Python 网络编程之TCP客户端/服务端功能。分享给大家供大家参考,具体如下:demo.py(TCP客户端):import
- 这篇文章主要介绍了如何使用Python多线程测试并发漏洞,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的
- mysql最常用的索引结构是btree(O(log(n))),但是总有一些情况下我们为了更好的性能希望能使用别的类型的索引。hash就是其中
- 这篇文章主要讲TensorFlow中的Session的用法以及Variable。Session会话控制Session是TensorFlow为
- 我对这两种连接方式认识不够深,似乎朋友们对此也没有定论。请问哪一种更好呢?DSN是采用数据源的连接方式,其使用方法是: Conn.
- 一、为什么要安装虚拟环境 情景一、项目A需要某个库的1.0版本,项目B需要这个库的2.0版本。如果没有安装虚拟环境