解决Pytorch修改预训练模型时遇到key不匹配的情况
作者:月亮不秃头 发布时间:2022-11-29 15:43:43
一、Pytorch修改预训练模型时遇到key不匹配
最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。
在我使用新赋值的网络模型时出现了key不匹配的问题
#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights)
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
net = ssd_net
...
if args.resume:
...
else:
base_weights = torch.load(args.save_folder + args.basenet)
#args.basenet为ssd_base.pth
print('Loading base network...')
ssd_net.vgg.load_state_dict(base_weights)
此时会如下出错误:
Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)
…
RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.
说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”
我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。
现在的问题是因为自己定义保存的模型key参数多了一个前缀。
可以通过如下语句进行修改,并加载
from collections import OrderedDict #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**
for k, v in base_weights.items():
name = k[4:] # remove `vgg.`,即只取vgg.0.weights的后面几位
new_state_dict[name] = v
ssd_net.vgg.load_state_dict(new_state_dict)
此时就不会再出错了。
参考了这个篇。修改一下就可以应用到自己的模型啦。
//www.jb51.net/article/214214.htm
二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘
最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。
KeyError: 'layer1.0.bn1.num_batches_tracked'
其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,
这个参数的作用如下:
训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1
如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).
其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.
所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.
有问题的代码:
def load_specific_param(self, state_dict, param_name, model_path):
param_dict = torch.load(model_path)
for i in state_dict:
key = param_name + '.' + i
state_dict[i].copy_(param_dict[key])
del param_dict
对'num_batches_tracked进行过滤:
def load_specific_param(self, state_dict, param_name, model_path):
param_dict = torch.load(model_path)
param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
for i in state_dict:
key = param_name + '.' + i
if 'num_batches_tracked' in key:
continue
state_dict[i].copy_(param_dict[key])
del param_dict
来源:https://blog.csdn.net/weixin_44039925/article/details/99447653
猜你喜欢
- 下面展示一下非瀑布流的item布局情况,每个item的高度都是一样的,所以 他的index就是左右左右,position所对应的itemVi
- map( )函数在算法题目里面经常出现,map( )会根据提供的函数对指定序列做映射,在写返回值等需要转换的时候比较常用。关于映射map,可
- python使用utf8编码,mysql也是utf8编码,是什么问题?后来查了一下,使用一个简单的办法即可:vsql = "ins
- https://discuss.pytorch.org/t/how-to-modify-the-final-fc-layer-based-o
- 在这篇文章(不敢妄称教程,最多称之为学习笔记)里,我会从头开始实现客户端模板的效果。不过你不要期望能够在这里找到可以直接拿去使用直接复用灵活
- 作为一个新手,你需要以下3个步骤:1、用户注册 > 2、获取token > 3、调取数据数据内容:包含股票、基金、期货、债券、外
- 在Python的网络编程中,getservbyport()函数和getservbyname()函数是socket模块中的两个函数,因此在使用
- 一、什么是数组数组就是一组数据的集合,把一系列数据组织起来,形成一个可操作的整体。数组的每个实体都包含两项:键和值。二、声明数据在PHP中声
- 设计原理从结构上来说,一个简单的图形界面,需要由界面组件、组件的事件 * (响应各类事件的逻辑)和具体的事件处理逻辑组成。界面实现的主要工作
- 这篇文章主要介绍了python解析命令行参数的三种方法详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要
- 音乐播放器可让您快速轻松地管理和收听所有音乐文件。在本文中,我将带您了解如何使用 Python 创建音乐播放器 GUI。如何使用 Pytho
- asp使用SQL语句,查询数据库中的第10-20条记录的l方法,两种sql语句写法如下:1、select top 10 * from tab
- 所谓定时器,是指间隔特定时间执行特定任务的机制。几乎所有的编程语言,都有定时器的实现。比如,Java有util.Timer和util.Tim
- 为什么要引入线程池如果在程序中经常要用到线程,频繁的创建和销毁线程会浪费很多硬件资源,所以需要把线程和任务分离。线程可以反复利用,省去了重复
- 本文实例讲述了Python加pyGame实现的简单拼图游戏。分享给大家供大家参考。具体实现方法如下:import pygame, sys,
- 目前,各大搜索引擎如google、百度、雅虎已经对动态页面诸如asp,php有着不错的支持了,只要动态页面后面的参数不要太长,如控制在3个参
- Python 语言与 Perl,C 和 Java 等语言有许多相似之处,也有一定的差异性,以下是Python语言获取文件后缀名和文件名的方法
- 在一般问题的优化中,最速下降法和共轭梯度法都是非常有用的经典方法,但最速下降法往往以”之”字形下降,速度较慢,不能很快的达到最优值,共轭梯度
- 创建main.py文件并粘贴下面代码点击右键运行Debug 'main'后,下方的Debug窗口会出现ImportError
- Java一直标榜一句老话叫“编写一次,到处运行(Write Once,Run Anywhere)”,CSS也差一点点做到了。但就是为了差的一