pytorch fine-tune 预训练的模型操作
作者:This is bill 发布时间:2023-05-02 01:05:25
之一:
torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。
安装
pip install torchvision
如何 fine-tune
以 resnet18 为例:
from torchvision import models
from torch import nn
from torch import optim
resnet_model = models.resnet18(pretrained=True)
# pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中
# 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。
# 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是
# 1. 查看 resnet 的源码
# 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替换掉这个层
resnet_model.fc= nn.Linear(in_features=..., out_features=100)
# 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。
# 如果只想训练 最后一层的话,应该做的是:
# 1. 将其它层的参数 requires_grad 设置为 False
# 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数
# 3. 然后 backward, step 就可以了
# 这一步可以节省大量的时间,因为多数的参数不需要计算梯度
for para in list(resnet_model.parameters())[:-2]:
para.requires_grad=False
optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)
...
为什么
这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么
这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:
def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
if name in d:
del d[name]
首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:
预训练的模型中 有个 名字叫fc 的 Module。
在类定义外,我们 将另一个 Module 重新 赋值给了 fc。
类定义内的 fc 对应的 Module 就会从 模型中 删除。
之二:
前言
这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。
参数初始化
参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。
所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:
def weight_init(m):
# 使用isinstance来判断m属于什么类型
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
m.weight.data.fill_(1)
m.bias.data.zero_()
Finetune
往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。
局部微调
有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
全局微调
有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:
ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
model.parameters())
optimizer = torch.optim.SGD([
{'params': base_params},
{'params': model.fc.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。
之三:
pytorch finetune模型
文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。
pytorch 模型的存储与读取
其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的
单独存储模型参数
存储时使用:
torch.save(the_model.state_dict(), PATH)
读取时:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
存储模型与参数
存储:
torch.save(the_model, PATH)
读取:
the_model = torch.load(PATH)
模型的参数
fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。
pytorch模型参数的形式
模型的参数是以字典的形式存储的。
model_dict = the_model.state_dict(),
for k,v in model_dict.items():
print(k)
即可看到所有的键值
如果想修改模型的参数,给相应的键值赋值即可
model_dict[k] = new_value
最后更新模型的参数
the_model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是一样的
我们可以通过下列算法进行读取模型
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
# 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
keys.append(k)
i = 0
for k,v in model_dict.items():
if v.size() == pretrained_dict[keys[i]].size():
print(k, ',', keys[i])
model_dict[k]=pretrained_dict[keys[i]]
i = i + 1
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的
自己找对应关系,一个key对应一个key的赋值
来源:https://blog.csdn.net/Scythe666/article/details/82809615
猜你喜欢
- 本文实例讲述了CentOS环境下安装Redis3.0及phpredis扩展测试。分享给大家供大家参考,具体如下:线上的统一聊天及推送系统re
- 在我们的生活中,只要你睁开眼睛就能看到各种各样的视觉。不同的视觉能给你不同的视觉暗示,同样能给你不同的心理感受。视觉这个话题太泛了,大自然中
- 对于任何一个开发项目来说最大的错误可能就是没有计划。最近,有些人认为开始前无需计划,一个优秀的开发者需要的是随机应变。我敢肯定这样的做法最后
- 数据库复制:简单来说,数据库复制就是由两台服务器,主服务器和备份服务器,主服务器修改后,备份服务器自动修改。复制的模式有两种:推送模式和请求
- 引言语音端点检测最早应用于电话传输和检测系统当中,用于通信信道的时间分配,提高传输线路的利用效率.端点检测属于语音处理系统的前端操作,在语音
- Turtle库是Python内置的图形化模块,属于标准库之一,位于Python安装目录的lib文件夹下,常用函数有以下几种:画笔控制函数pe
- 前言我们在django-rest-framework 自定义swagger 文章中编写了接口, 调通了接口文档. 接口文档可以直接填写参数进
- 问题:如何保护自己的ASP源代码不泄露? 答:下载微软的Windows Script Encoder,对ASP的脚本和客户端javascri
- 通常,当一个页面有太多信息要显示,而一页塞又不下所有信。为了请求速度、美观以及其他的各种理由,分页就会被我们请过来。让我们的用户可以选择是否
- <html> <head> <title>Untitled Document</title>
- 回车和换行的历史:机械打字机有回车和换行两个键作用分别是:换行就是把滚筒卷一格,不改变水平位置。 (即移到下一行,但不是行首,而是和上一行水
- 在ASP中,直接使用“Insert into” 语句与使用ADO中AddNew方法有什么区别?哪一种更好呢?AddNew方法的实质就是封装了
- 首先获取ip:<% userip=Request.ServerVariables(&qu
- 这份数据集来源于Kaggle,数据集有12500只猫和12500只狗。在这里简单介绍下整体思路处理数据设计神经网络进行训练测试1. 数据处理
- SQL2005的存储过程: set ANSI_NULLS ON set QUOTED_IDENTIFIER ON go ALTER PROC
- 使用MySQL进行数据库备份,有很正规的数据库备份方法,同其他的数据库服务器有相同的概念,但有没有想过,MySQL会有更简捷的使用文件目录的
- 英文原文:http://www.usabilitypost.com/2009/04/15/8-characteristics-of-succ
- 网上有很多提供在线按钮制作、文字标题制作、Logo制作服务的网站,它们可以非常方便了让大家轻松的获得效果出色的各类图标型的图片,下面就快来看
- 一、 软件介绍 DB2MYSQL是一个可以自动将ACCESS数据库文件转化为对应的SQL代码的软件。可广泛应用于ACCESS数据库转换为MY
- 开发工具**Python版本:**3.6.4相关模块:pyecharts模块;以及一些Python自带的模块。环境搭建安装Python并添加