解决Pytorch修改预训练模型时遇到key不匹配的情况
作者:月亮不秃头 发布时间:2022-11-29 15:43:43
一、Pytorch修改预训练模型时遇到key不匹配
最近想着修改网络的预训练模型vgg.pth,但是发现当我加载预训练模型权重到新建的模型并保存之后。
在我使用新赋值的网络模型时出现了key不匹配的问题
#加载后保存(未修改网络)
base_weights = torch.load(args.save_folder + args.basenet)
ssd_net.vgg.load_state_dict(base_weights)
torch.save(ssd_net.state_dict(), args.save_folder + 'ssd_base' + '.pth')
# 将新保存的网络代替之前的预训练模型
ssd_net = build_ssd('train', cfg['min_dim'], cfg['num_classes'])
net = ssd_net
...
if args.resume:
...
else:
base_weights = torch.load(args.save_folder + args.basenet)
#args.basenet为ssd_base.pth
print('Loading base network...')
ssd_net.vgg.load_state_dict(base_weights)
此时会如下出错误:
Loading base network…
Traceback (most recent call last):
File “train.py”, line 264, in
train()
File “train.py”, line 110, in train
ssd_net.vgg.load_state_dict(base_weights)
…
RuntimeError: Error(s) in loading state_dict for ModuleList:
Missing key(s) in state_dict: “0.weight”, “0.bias”, … “33.weight”, “33.bias”.
Unexpected key(s) in state_dict: “vgg.0.weight”, “vgg.0.bias”, … “vgg.33.weight”, “vgg.33.bias”.
说明之前的预训练模型 key参数为"0.weight", “0.bias”,但是经过加载保存之后变为了"vgg.0.weight", “vgg.0.bias”
我认为是因为本身的模型定义文件里self.vgg = nn.ModuleList(base)这一句。
现在的问题是因为自己定义保存的模型key参数多了一个前缀。
可以通过如下语句进行修改,并加载
from collections import OrderedDict #导入此模块
base_weights = torch.load(args.save_folder + args.basenet)
print('Loading base network...')
new_state_dict = **OrderedDict()**
for k, v in base_weights.items():
name = k[4:] # remove `vgg.`,即只取vgg.0.weights的后面几位
new_state_dict[name] = v
ssd_net.vgg.load_state_dict(new_state_dict)
此时就不会再出错了。
参考了这个篇。修改一下就可以应用到自己的模型啦。
//www.jb51.net/article/214214.htm
二、pytorch加载预训练模型遇到的问题:KeyError: ‘bn1.num_batches_tracked‘
最近在使用pytorch1.0加载resnet预训练模型时,遇到的一个问题,在此记录一下。
KeyError: 'layer1.0.bn1.num_batches_tracked'
其实是使用的版本的问题,pytorch0.4.1之后在BN层加入了track_running_stats这个参数,
这个参数的作用如下:
训练时用来统计训练时的forward过的min-batch数目,每经过一个min-batch, track_running_stats+=1
如果没有指定momentum, 则使用1/num_batches_tracked 作为因数来计算均值和方差(running mean and variance).
其实,这个参数没啥用.但因为官方提供的预训练模型是pytorch0.3版本训练出来的,因此没有这个参数.
所以,只要过滤一下预训练权重字典中的关键字即可,‘num_batches_tracked'.代码例子,如下.
有问题的代码:
def load_specific_param(self, state_dict, param_name, model_path):
param_dict = torch.load(model_path)
for i in state_dict:
key = param_name + '.' + i
state_dict[i].copy_(param_dict[key])
del param_dict
对'num_batches_tracked进行过滤:
def load_specific_param(self, state_dict, param_name, model_path):
param_dict = torch.load(model_path)
param_dict = {k: v for k, v in param_dict.items() if 'num_batches_tracked' not in k}
for i in state_dict:
key = param_name + '.' + i
if 'num_batches_tracked' in key:
continue
state_dict[i].copy_(param_dict[key])
del param_dict
来源:https://blog.csdn.net/weixin_44039925/article/details/99447653


猜你喜欢
- 场景:由于自己的电脑A性能不足,需要转移到一台高性能的主机B上运行python程序,但是该主机不能连接互联网。问题:在个人电脑A上建立了一个
- 二元函数为y=x1^2+x2^2,x∈[-5,5]NIND=121; %初始种群的个数(Number of individual
- 本文实例为大家分享了pygame实现雷电游戏开发代码,供大家参考,具体内容如下源代码:stars.py#-*- coding=utf-8 -
- 本地路径的创建在做下载操作时,我们一般先把文件下载到本地指定的路径下,然后再做其他使用。为了防止程序出现异常,我们通常需要先判断本地是否存在
- 全球数据量的疯狂增长,使得市场对资深数据库管理员的需求也节节攀升。据统计,一直到2016美国IT市场对数据库管理员的需求量增长都将会超过所有
- 1.由于数据库设计问题造成SQL数据库新增数据时超时症状:Microsoft OLE DB Provider for SQL Server
- 前言SQL模式影响MySQL支持的SQL语法和执行的数据验证检查。MySQL服务器可以在不同的SQL模式下运行,并且可以针对不同的客户端以不
- 本文实例讲述了redis数据库及与python交互用法。分享给大家供大家参考,具体如下:redis数据操作1.string类型:主要存储字符
- 1. dataloader() 初始化函数def __init__(self, dataset, batch_size=1, shuffle
- 什么是进程进程就是操作系统中执行的一个程序,操作系统以进程为单位分配存储空间,每个进程都有自己的地址空间、数据栈以及其他用于跟踪进程执行的辅
- 本文实例讲述了javascript修改图片src的方法。分享给大家供大家参考。具体实现方法如下:<!DOCTYPE html>&
- 三维可视化系统的建立依赖于三维图形平台, 如 OpenGL、VTK、OGRE、OSG等, 传统的方法多采用OpenGL进行底层编程,即对其特
- 最简单的:<textarea name="A" cols="45" rows="2&
- Python中的三引号,3个单引号及3个双引号实际上3个单引号和3个双引号不经常用,但是在某些特殊格式的字符串下却有大用处。通常情况下我们用
- 一、首先我们来填个坑支付验签失败这个问题折磨了我两天,官方文档比较含糊不清。各种百度下来的方法试过之后也不尽人意,最后发现问题是没有二次签名
- 无意中看到百度的页面代码,想到了一种声明写法,需要的朋友可以参考下。<!DOCTYPE html> <!--[if IE]
- collections是实现了特定目标的容器,以提供Python标准内建容器 dict , list , set , 和 tuple 的替代
- Python 循环Python 有两个原始的循环命令:while 循环for 循环while 循环如果使用 while 循环,只要条件为真,
- 目录一、Python GUI 编程简介二、流行GUI框架总结三、代码演示四、界面一、Python GUI 编程简介Tkinter 模块(Tk
- 太长不看的简洁版本1.x = np.arange(start, end, steps)Values are generated within