pytorch fine-tune 预训练的模型操作
作者:This is bill 发布时间:2023-05-02 01:05:25
之一:
torchvision 中包含了很多预训练好的模型,这样就使得 fine-tune 非常容易。本文主要介绍如何 fine-tune torchvision 中预训练好的模型。
安装
pip install torchvision
如何 fine-tune
以 resnet18 为例:
from torchvision import models
from torch import nn
from torch import optim
resnet_model = models.resnet18(pretrained=True)
# pretrained 设置为 True,会自动下载模型 所对应权重,并加载到模型中
# 也可以自己下载 权重,然后 load 到 模型中,源码中有 权重的地址。
# 假设 我们的 分类任务只需要 分 100 类,那么我们应该做的是
# 1. 查看 resnet 的源码
# 2. 看最后一层的 名字是啥 (在 resnet 里是 self.fc = nn.Linear(512 * block.expansion, num_classes))
# 3. 在外面替换掉这个层
resnet_model.fc= nn.Linear(in_features=..., out_features=100)
# 这样就 哦了,修改后的模型除了输出层的参数是 随机初始化的,其他层都是用预训练的参数初始化的。
# 如果只想训练 最后一层的话,应该做的是:
# 1. 将其它层的参数 requires_grad 设置为 False
# 2. 构建一个 optimizer, optimizer 管理的参数只有最后一层的参数
# 3. 然后 backward, step 就可以了
# 这一步可以节省大量的时间,因为多数的参数不需要计算梯度
for para in list(resnet_model.parameters())[:-2]:
para.requires_grad=False
optimizer = optim.SGD(params=[resnet_model.fc.weight, resnet_model.fc.bias], lr=1e-3)
...
为什么
这里介绍下 运行resnet_model.fc= nn.Linear(in_features=..., out_features=100)时 框架内发生了什么
这时应该看 nn.Module 源码的 __setattr__ 部分,因为 setattr 时都会调用这个方法:
def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
if name in d:
del d[name]
首先映入眼帘就是 remove_from 这个函数,这个函数的目的就是,如果出现了 同名的属性,就将旧的属性移除。 用刚才举的例子就是:
预训练的模型中 有个 名字叫fc 的 Module。
在类定义外,我们 将另一个 Module 重新 赋值给了 fc。
类定义内的 fc 对应的 Module 就会从 模型中 删除。
之二:
前言
这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。
参数初始化
参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。
所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:
def weight_init(m):
# 使用isinstance来判断m属于什么类型
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播
m.weight.data.fill_(1)
m.bias.data.zero_()
Finetune
往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。
局部微调
有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)
# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
全局微调
有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:
ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,
model.parameters())
optimizer = torch.optim.SGD([
{'params': base_params},
{'params': model.fc.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。
之三:
pytorch finetune模型
文章主要讲述如何在pytorch上读取以往训练的模型参数,在模型的名字已经变更的情况下又如何读取模型的部分参数等。
pytorch 模型的存储与读取
其中在模型的保存过程有存储模型和参数一起的也有单独存储模型参数的
单独存储模型参数
存储时使用:
torch.save(the_model.state_dict(), PATH)
读取时:
the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))
存储模型与参数
存储:
torch.save(the_model, PATH)
读取:
the_model = torch.load(PATH)
模型的参数
fine-tune的过程是读取原有模型的参数,但是由于模型的所要处理的数据集不同,最后的一层class的总数不同,所以需要修改模型的最后一层,这样模型读取的参数,和在大数据集上训练好下载的模型参数在形式上不一样。需要我们自己去写函数读取参数。
pytorch模型参数的形式
模型的参数是以字典的形式存储的。
model_dict = the_model.state_dict(),
for k,v in model_dict.items():
print(k)
即可看到所有的键值
如果想修改模型的参数,给相应的键值赋值即可
model_dict[k] = new_value
最后更新模型的参数
the_model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是一样的
我们可以通过下列算法进行读取模型
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
# 1. filter out unnecessary keys
diff = {k: v for k, v in model_dict.items() if \
k in pretrained_dict and pretrained_dict[k].size() == v.size()}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
pretrained_dict.update(diff)
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是一样的
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path)
keys = []
for k,v in pretrained_dict.items():
keys.append(k)
i = 0
for k,v in model_dict.items():
if v.size() == pretrained_dict[keys[i]].size():
print(k, ',', keys[i])
model_dict[k]=pretrained_dict[keys[i]]
i = i + 1
model.load_state_dict(model_dict)
如果模型的key值和在大数据集上训练时的key值是不一样的,但是顺序是也不一样的
自己找对应关系,一个key对应一个key的赋值
来源:https://blog.csdn.net/Scythe666/article/details/82809615


猜你喜欢
- CREATE TABLE table1( [ID] [bigint] IDENTITY(1,1) NOT NULL, [Name] [nva
- 一、避免Firefox 背景图不显示的兼容问题,定义background 属性,先后顺序不能随意变动。background : backgr
- 模块的作用是:允许从任何文件里得到任何一行或几行,并且使用缓存进行优化。有几个API接口linecache.getlines(filenam
- 本文将梳理github上最火的wechat_jump_game的实现思路,并解析其图像处理部分源码首先废话少说先看效果 核心思想获取棋子到下
- 如下所示:import threadingimport timesem=threading.Semaphore(4) #限制线程的最大数量为
- Python 字符串字符串是 Python 中最常用的数据类型。我们可以使用引号来创建字符串。创建字符串很简单,只要为变量分配一个值即可。例
- 本文实例讲述了Python基于dom操作xml数据的方法。分享给大家供大家参考,具体如下:1、xml的内容为del.xml,如下<?x
- 如下所示:list = [‘a','b','c']想用for循环输出list的元素以及对应的索引。代
- 一、概念 1. 数据库 (Database)什么是数据库?数据库是依照某种数据模型组织起来并存放二级存储器中的数据集合。这种数据集合具有如下
- 概述从今天开始, 小白我将带领大家一起来补充一下 数据库的知识.自连接自连接 (Self Join) 是一种特殊的表连接. 自连接指相互连接
- 这一款是用原生javascript实现的分页插件pagenav,页码显示jquery插件,只需要存在#pageNav,则会在其中显示页码,调
- 具体代码如下所示:#字符串反转def reverse (s): rt = '' for i in r
- 在机器学习过程中,通常会通过pandas读取csv文件,保持成dadaframe格式,然而有时候需要对dataframe中的时间字段进行数据
- 七牛云存储的 Python 语言版本 SDK(本文以下称 Python-SDK)是对七牛云存储API协议的一层封装,以提供一套对于 Pyth
- 不知道用ASP写代码的朋友是不是和我有一样的感受,ASP中最头疼的就是调试程序的时候不方便,我想可能很多朋友都会用这样的方法&ldq
- 在进行单个爬虫抓取的时候,我们不可能按照一次抓取一个url的方式进行网页抓取,这样效率低,也浪费了cpu的资源。目前python上面进行并发
- pytest官方文档fixtures调用既然fixtures是给执行测试做准备工作的,那么pytest如何知道哪些测试函数 或者 fixtu
- <%@ Language=VBScript %><%Option Explicit %><%Dim strUR
- 本文实例为大家分享了opencv矿石图片检测矿石数量的具体代码,供大家参考,具体内容如下原始矿石图片此类图片是高躁图,二值化后图像如下采用膨
- 一、问题描述最近遇到一个问题,也就是使用分区表进行数据查询/加载的时候比普通表的性能下降了约50%,主要瓶颈出现在CPU,既然是CPU瓶颈理