Pytorch中实现只导入部分模型参数的方式
作者:咆哮的阿杰 发布时间:2023-01-24 05:53:25
我们在做迁移学习,或者在分割,检测等任务想使用预训练好的模型,同时又有自己修改之后的结构,使得模型文件保存的参数,有一部分是不需要的(don't expected)。我们搭建的网络对保存文件来说,有一部分参数也是没有的(missed)。如果依旧使用torch.load(model.state_dict())的办法,就会出现 xxx expected,xxx missed类似的错误。那么在这种情况下,该如何导入模型呢?
好在Pytorch中的模型参数使用字典保存的,键是参数的名称,值是参数的具体数值。我们使用model.state_dict()获得这个字典,之后就能利用参数名称来实现导入。
请看下面的一个例子。
我们先搭建一个小小的网络。
import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(3,32,3,1)
self.conv2 = nn.Conv2d(32,3,3,1)
self.w = nn.Parameter(t.randn(3,10))
for p in self.children():
nn.init.xavier_normal_(p.weight.data)
nn.init.constant_(p.bias.data, 0)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(x)
out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
out = F.linear(out,weight=self.w)
return out
然后我们保存这个网络的初始值。
model = Net()
t.save(model.state_dict(),'xxx.pth')
现在我们将Net修改一下,多加几个卷积层,但并不加入到forward中,仅仅出于少些几行的目的。
import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 3, 3, 1)
self.conv3 = nn.Conv2d(3,64,3,1)
self.conv4 = nn.Conv2d(64,32,3,1)
for p in self.children():
nn.init.xavier_normal_(p.weight.data)
nn.init.constant_(p.bias.data, 0)
self.w = nn.Parameter(t.randn(3, 10))
def forward(self, x):
out = self.conv1(x)
out = self.conv2(x)
out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
out = F.linear(out, weight=self.w)
return out
我们现在试着导入之前保存的模型参数。
path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
'''
RuntimeError: Error(s) in loading state_dict for Net:
Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias".
'''
出现了没有在模型文件中找到error中的关键字的错误。
现在我们这样导入模型
path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)
看看上面的代码,很容易弄明白。其中model_dict.update的作用是更新代码中搭建的模型参数字典。为啥更新我其实并不清楚,但这一步骤是必须的,否则还会报错。
为了弄清楚为什么要更新model_dict,我们不妨分别输出state_dict和model_dict的关键值看一看。
for k in state_dict.keys():
print(k)
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
print(k)
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''
这个结果也是预料之中的,所以我猜测,update之后,model_dict和state_dict中具有相同键的值已经同步了。updata的目的就是使model_dict带有state_dict中都具有的那一部分参数的值,对于model_dict中有的,但是save_dict中没有的参数,值不改变,参数仍然使用初始值。
来源:https://blog.csdn.net/qq_34914551/article/details/87871134
猜你喜欢
- 如果想设置相同的初值和想要的长度>>> a=[None]*4>>> print(a)[None, Non
- '定义变量 Dim cn,rs,Sql Sql = "sel
- 1. 张量的拼接(1) numpy.concatenatenp.concatenate((a1,a2,a3,…), axis=0)张量的拼接
- Python则是通过缩进来识别代码块的。缩进Python最具特色的是用缩进来标明成块的代码。我下面以if选择结构来举例。if后面跟随条件,如
- 1. 问题使用PyCharm 创建完Django 项目 想登录admin 页面 却不知道用户名和密码。 用的默认sqlit2.解决办法2.1
- 前言最近学完Python,写了几个爬虫练练手,网上的教程有很多,但是有的已经不能爬了,主要是网站经常改,可是爬虫还是有通用的思路的,即下载数
- 一、读写txt文件1、打开txt文件file_handle=open('1.txt',mode='w')上述
- 1-删除模型变量del model_define2-清空CUDA cachetorch.cuda.empty_cache()3-步骤2(异步
- 今天为大家介绍几个Python“装逼”实例代码,python绘制樱花、玫瑰、圣诞树代码实例,主要使用了turtle库Python绘制樱花代码
- 简介网上流传的部分可以百度关键词“Python”和“word”后查看文章学习,以下内容为个人实践,修正了不能运行出错的情况。代码示例impo
- 阅读上一篇:Freshow工具使用方法一. eval加密是在网马解密中最常见的,eval在jscript脚本中实际上是一个函数,简单可以理解
- 在使用Django做前端后端项目时,登陆认证方法往往使用的是jwt_token,但是想自定义登陆成功和失败的返回体。1.当用户名和密码正确就
- django settings.py 配置文件import osBASE_DIR = os.path.dirname(os.path.dir
- 简介最近在整理我们项目代码的时候,发现有很多活动的代码在结构和提供的功能上都非常相似。为了方便今后的开发,我花了一点时间编写了一个生成代码框
- 【简 介】熟悉网页设计的网友就知道,调用Style的方法很多,我们可以单击鼠标右键选择Custon Style来调用Style标准,也可以在
- 安装pyecharts1.8.0版本后导入pyecharts模块绘图时报错: “所有图表类型将在 v1.9.0 版本开始强制使用 Chart
- numpy中有一个掩码数组的概念,需要通过子模块numpy.ma来创建,基本的创建方式如下>>> import numpy
- 在密码学中,ElGamal加密算法是一个基于迪菲-赫尔曼密钥交换的非对称加密算法。它在1985年由塔希尔·盖莫尔提出。GnuPG和PGP等很
- 我考虑到了x的所有n次的情况,下面的代码有可能是不完美的,但是肯定是对的。def aaa(x,n): A=isinstance(
- 这篇文章主要介绍了Python for循环通过序列索引迭代过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价