pytorch载入预训练模型后,实现训练指定层
作者:慕白- 发布时间:2022-01-22 12:40:35
1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:
pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)
strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。
2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:
# 载入预训练模型参数后...
for name, value in model.named_parameters():
if name 满足某些条件:
value.requires_grad = False
# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)
将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。
3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:
# 载入预训练模型参数后...
for name, value in model.named_parameters():
print(name)
# 或
print(model.state_dict().keys())
假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:
'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',
假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:
ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
{'params':model.decoder.parameters()}
],
lr=1e-4, momentum=0.9)
代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。
在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。
参考:
pytorch官方文档
https://www.jb51.net/article/134943.htm
来源:https://blog.csdn.net/weixin_36049506/article/details/89522860
猜你喜欢
- 如果你细心跟踪一下SQL Server数据库服务器的登录过程,你会发现口令计算其实是非常脆弱的,SQL Server数据库的口令脆弱体现两方
- 今天下午,低一度博客受到攻击了,出现了大约一个小时的访问异常。庆幸的是,这帮无耻歹徒没能成功获取我的Access数据库,而只是象征性地给我注
- 需 求 分 析 1、读取指定目录下的所有文件2、读取指定文件,输出文件内容3、创建一个文件并保存到指定目录实 现 过 程Python写代码简
- Pandas查询数据的几种方法df.loc方法,根据行、列的标签值查询df.iloc方法,根据行、列的数字位置查询df.where方法df.
- asp按关键字查询XML的问题 '-------------------------------------------------
- Python是一门面向对象的语言,定义类时经常要用到继承,在类的继承中,子类继承父类中已经封装好的方法,不需要再次编写,如果子类如果重新定义
- 一、先让飞机在屏幕上飞起来吧。(一)实现飞机类class Plane: def __init__(self,fil
- 最近,在项目开发过程中,碰到了数据库死锁问题,在解决问题的过程中,笔者对MySQL InnoDB引擎锁机制的理解逐步加深。案例如下:在使用S
- 案例:爬取使用搜狗根据指定词条搜索到的页面数据(例如爬取词条为‘周杰伦'的页面数据)import urllib.request# 1
- 本文实例为大家分享了python实现简易学生信息管理系统的具体代码,供大家参考,具体内容如下一、系统功能1.录入学生信息2.查找学生信息3.
- 前言众所周知,python拥有丰富的内置库,还支持众多的第三方库,被称为胶水语言,随机函数库random,就是python自带的标准库,他的
- 目录一、简介1、优势2、劣势二、安装三、locust的库和方法介绍1、from locust import task2、from locus
- 本文实例讲述了PHPExcel冻结(锁定)表头的简单实现方法。分享给大家供大家参考,具体如下:PHPExcel是一款功能比较强大的操作微软e
- 实现对下一个单词的预测RNN 原理自己找,这里只给出简单例子的实现代码import tensorflow as tfimport numpy
- 本文实例为大家分享了python实现学生信息管理系统的具体代码,含代码注释、增删改查、排序、统计显示学生信息,供大家参考,具体内容如下运行如
- 英文原文:http://www.smashingmagazine.com/2008/08/18/译文原文:http://blog.bingo
- 前言:以往看到我博客的小伙伴可能都知道,我的前言一般都是吐槽和讲废话环节,哈哈哈哈。今天难得休息,最近可真是太忙了,博主已经连续一年都在99
- 今天介绍Python当中十大可视化工具,每一个都独具特色,惊艳一方。MatplotlibMatplotlib 是 Python 的一个绘图库
- 阅读《YUI学习笔记(1)》《YUI学习笔记(2)》YAHOO.lang.later,YAHOO.lang.trim,YAHOO.lang.
- 前端代码要做到简洁易读、高效,还要考虑后端嵌套的方便性。前段时间做了一个导航,把整个制作过程重现,希望对大家有帮助。看到这样的导航,你会怎么