pytorch载入预训练模型后,实现训练指定层
作者:慕白- 发布时间:2022-01-22 12:40:35
1、有了已经训练好的模型参数,对这个模型的某些层做了改变,如何利用这些训练好的模型参数继续训练:
pretrained_params = torch.load('Pretrained_Model')
model = The_New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(), strict=False)
strict=False 使得预训练模型参数中和新模型对应上的参数会被载入,对应不上或没有的参数被抛弃。
2、如果载入的这些参数中,有些参数不要求被更新,即固定不变,不参与训练,需要手动设置这些参数的梯度属性为Fasle,并且在optimizer传参时筛选掉这些参数:
# 载入预训练模型参数后...
for name, value in model.named_parameters():
if name 满足某些条件:
value.requires_grad = False
# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)
将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数帅选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。
3、如果载入的这些参数中,所有参数都更新,但要求一些参数和另一些参数的更新速度(学习率learning rate)不一样,最好知道这些参数的名称都有什么:
# 载入预训练模型参数后...
for name, value in model.named_parameters():
print(name)
# 或
print(model.state_dict().keys())
假设该模型中有encoder,viewer和decoder两部分,参数名称分别是:
'encoder.visual_emb.0.weight',
'encoder.visual_emb.0.bias',
'viewer.bd.Wsi',
'viewer.bd.bias',
'decoder.core.layer_0.weight_ih',
'decoder.core.layer_0.weight_hh',
假设要求encode、viewer的学习率为1e-6, decoder的学习率为1e-4,那么在将参数传入优化器时:
ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{'params':base_params,'lr':1e-6},
{'params':model.decoder.parameters()}
],
lr=1e-4, momentum=0.9)
代码的结果是除decoder参数的learning_rate=1e-4 外,其他参数的额learning_rate=1e-6。
在传入optimizer时,和一般的传参方法torch.optim.Adam(model.parameters(), lr=xxx) 不同,参数部分用了一个list, list的每个元素有params和lr两个键值。如果没有 lr则应用Adam的lr属性。Adam的属性除了lr, 其他都是参数所共有的(比如momentum)。
参考:
pytorch官方文档
https://www.jb51.net/article/134943.htm
来源:https://blog.csdn.net/weixin_36049506/article/details/89522860


猜你喜欢
- mitmproxy有3中监听请求与响应的方式:mitmproxy控制台方式mitmdump与Python对接的方式mitmweb可视化方式前
- JavaScript的对象都是实例化了的,只可以使用而不能够创建继承于这些对象的新的子类. window对象为所有对象的Parent win
- 1. 首先导入一些python画图的包,读取txt文件,假设我现在有两个模型训练结果的records.txt文件import numpy a
- 前言在开始本文之前,先来介绍一下相关内容,大家都知道一些防护SSRF漏洞的代码一般使用正则来判断访问IP是否为内部IP,比如下面这段网上比较
- 1、简介AI 聊天机器人使用自然语言处理 (NLP) 来帮助用户通过文本、图形或语音与 Web 服务或应用进行交互。聊天机器人可以理解自然人
- pycharm常用快捷键1、编辑(Editing)Ctrl + Space基本的代码完成(类、方法、属性)Ctrl + Alt + Spac
- 本节要讲解如下图所示的滑块验证码(更为复杂的滑动拼图验证码在下一篇介绍)。这种验证码机制比较简单:将滑块拖动到滑轨的最右端即可完成验证,如下
- items()方法返回字典的(键,值)元组对的列表语法以下是items()方法的语法:dict.items()参数 &
- 内容摘要:在网页制作中,有许多的术语,例如:CSS、HTML、DHTML、XHTML等等。在下面的文章中我们将会用到一些有关于HTML的基本
- MySQL查询语句大家都在用,但是应该如何设计高效合理的MySQL查询语句呢?下面就教您MySQL查询语句的合理设计方法,分享给大家学习学习
- 本文实例讲述了Python单体模式的几种常见实现方法。分享给大家供大家参考,具体如下:这里python实现的单体模式,参考了:https:/
- 本篇介绍在执行MySQL线上变更时遇到的问题,表现为"更新JSON字段时,实际更新的值与SQL语句中的值不一致,JSON格式错误&
- 在python中,文件使用十分频繁,本文将向大家介绍python文件路径的操作:得到指定文件路径、得到当前文件名、判断文件路径是否存在、获得
- opencv-python打开USB或笔记本前置摄像头代码其中video_index是摄像头编号,一般前置摄像头为0,USB摄像头为1或2.
- 目录索引模型B+Tree索引选择索引优化索引选择性覆盖索引最左前缀原则+索引下推前缀索引唯一索引索引失效总结索引模型哈希表适用于只有等值查询
- Python是静态作用域语言,但是它自身是一个动态语言。在Python中变量的作用域是由变量在代码中的位置决定的,与C语言有些相似,但不是完
- 本文实例讲述了PHP实现获取第一个中文首字母并进行排序的方法。分享给大家供大家参考,具体如下:最近在做储值结算,需求里结算首页需要按门店的首
- 在机器学习或者深度学习中,我们常常碰到一个问题是数据集的切分。比如在一个比赛中,举办方给我们的只是一个带标注的训练集和不带标注的测试集。其中
- 本文实例讲述了Python数据结构之图的应用。分享给大家供大家参考,具体如下:一、图的结构二、代码# -*- coding:utf-8 -*
- 一个不错的二级联动下拉菜单源码,您一定会用得到的。运行代码:<html><head><title>Lis