pytorch加载预训练模型与自己模型不匹配的解决方案
作者:找不到服务器1703 发布时间:2023-06-17 14:22:24
标签:pytorch,加载,预训练,模型
pytorch中如果自己搭建网络并且加载别人的与训练模型的话,如果模型和参数不严格匹配,就可能会出问题,接下来记录一下我的解决方法。
两个有序字典找不同
模型的参数和pth文件的参数都是有序字典(OrderedDict),把字典中的键转为列表就可以在for循环里迭代找不同了。
model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
err = 1
自己搭建模型的注意事项
搭网络时要对照pth文件的字典顺序搭,字典顺序、权重尺寸(shape)和变量命名必须与pth文件完全一致。如果仅仅是变量命名不同,可采用类似的方法对模型的权重重新赋值。
model = ResNet18(1)
model_dict1 = torch.load('resnet18.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
minlen = min(len1, len2)
for n in range(minlen):
if model_dict1[model_list1[n]].shape != model_dict2[model_list2[n]].shape:
continue
model_dict1[model_list1[n]] = model_dict2[model_list2[n]]
model.load_state_dict(model_dict2)
完整的代码见自己搭建resnet18网络并加载torchvision自带权重
新增的改进代码
model_dict1 = torch.load('yolov5.pth')
model_dict2 = model.state_dict()
model_list1 = list(model_dict1.keys())
model_list2 = list(model_dict2.keys())
len1 = len(model_list1)
len2 = len(model_list2)
m, n = 0, 0
while True:
if m >= len1 or n >= len2:
break
layername1, layername2 = model_list1[m], model_list2[n]
w1, w2 = model_dict1[layername1], model_dict2[layername2]
if w1.shape != w2.shape:
continue
model_dict2[layername2] = model_dict1[layername1]
m += 1
n += 1
model.load_state_dict(model_dict2)
如果因为模型不匹配,运行第14行语句后,可看自己情况手动对m或n加上1。
补充:pytorch的一些坑:用预训练的vgg模型的部分层的特征报错,如张量不匹配
看代码吧~
#打算取VGG19的第二个全连接层的输出,那么就需要构建一个类,这个类要包含VGG的全部卷积层,
#以及到第二个全连接层的全部网络还有他们对应的参数
class Classification_att(nn.Module):
def __init__(self, rgb_range):
super(Classification_att, self).__init__()
self.vgg19 =models.vgg19(pretrained=True)
vgg = models.vgg19(pretrained=True).features
conv_modules = [m for m in vgg]
self.vgg_conv = nn.Sequential(*conv_modules[:37])
classfi = models.vgg19(pretrained=True).classifier
classif_modules = [n for n in classfi]
self.vgg_class = nn.Sequential(*classif_modules[:4])
vgg_mean = (0.485, 0.456, 0.406)
vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std)
for p in self.vgg_conv.parameters():
p.requires_grad = False
for p in self.vgg_class.parameters():
p.requires_grad = False
self.classifi = nn.Sequential(
nn.Linear(4096, 1024),
nn.ReLU(True),
nn.Linear(1024, 256),
nn.ReLU(True),
nn.Linear(256, 64),
)
def forward(self, x):
x = F.interpolate(x, size=[224, 224], scale_factor=None, mode='bilinear',
align_corners=False)
x = self.sub_mean(x)
x = self.vgg_conv(x)
x = self.vgg_class(x) #执行这部报错,说张量不匹配
原因是因为卷积层的输出不能直接连接全连接层,即使输出的张量的总的大小是一致的
查看vgg的pytorch源码发现是
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
#自己的代码没有torch.flatten(x, 1)这步
所以自己的少了一步
x = torch.flatten(x, 1)
补上就好了!
来源:https://blog.csdn.net/qq_34288751/article/details/114160725


猜你喜欢
- ndarray的转置(transpose)对于A是由np.ndarray表示的情况:可以直接使用命令A.T。也可以使用命令A.transpo
- 之前一直在windows环境使用pycharm加上virtualenv方式开发,最近由于本地多个virtualenv比较混乱,所以尝试切换a
- requests接口测试的介绍requests是一个很实用的Python HTTP客户端库,编写爬虫和测试服务器响应数据时经常会用到,Req
- 题目:1. 利用拉格朗日乘子法#导入sympy包,用于求导,方程组求解等等from sympy import * #设置变量x1 = sym
- 1.前言对于数据库引擎来说,内存是一个性能提升的重要解决手段。把数据缓存起来,可以避免在查询或更新数据时花费多余的时间,而这时间通常是从磁盘
- 本文实例讲述了python实现的多线程端口扫描功能。分享给大家供大家参考,具体如下:下面的程序给出了对给定的ip主机进行多线程扫描的Pyth
- 1.申明一个数组 var a[2] int 或者 a:=[2]int{1,2}2.数组索引数组就是索引的来建立如下图我们再来一个测试3.go
- 几天前,想把上个月校园招聘的餐旅费报销一下。结果在公司内网的报销系统折腾了三个半小时才搞定。看看自己报销的金额:802块。觉得挺无奈,花了三
- MongoDB是一个文档型数据库,是NOSQL家族中最重要的成员之一,以下代码封装了MongoDB的基本操作。MongoDBConfig.j
- golang扩容规则举个例子来演示下package mainimport ("fmt")func main() {arr
- 用鼠标创建小球,一个蹦来蹦去的解压小游戏…… 本次需要的外置包:pygame,pymu
- 写在前面其实我之前写过一个简单的识别手写数字的程序,但是因为逻辑比较简单,而且要求比较严苛,是在50x50大小像素的白底图上手写黑色数字,并
- 一、介绍数据库的约束是对表中数据进行的一种限制,为了保证数据的正确性、有效性、完整性。无论是在添加数据还是在删除数据的时候,都能提供帮助。所
- Multiplexer根据URL将请求路由给指定的Handler。Handler用于处理请求并给予响应。更严格地说,用来读取请求体、并将请求
- 这就意味着数据库和表名在 Windows 中是大小写不敏感的,而在大多数类型的 Unix 系统中是大小写敏感的。一个特例是 Mac OS X
- 一、预备知识1.1、JS数据类型基本数据类型:Boolean、String、Number、null、undefined引用数据类型:Obje
- <form action="calscore.asp?action=do" met
- PC端项目中经常会出现大量的数据列表页面,涉及到下拉框选择筛选条件;当时用到bootstrap-select下拉框时该如何点击重置按钮就清除
- 不要使用Logrus这其实和泛型有关。因为Go语言是一门强类型的静态语言,所以你不可能像NodeJS或者PHP那样绕过数据类型。那如果我们还
- 本文实例讲述了PHP获取当前相对于域名目录的方法。分享给大家供大家参考。具体如下:http://127.0.0.1/dev/classd/i