pytorch如何冻结某层参数的实现
作者:Pr4da 发布时间:2021-02-03 11:49:36
标签:pytorch,冻结,参数
在迁移学习finetune时我们通常需要冻结前几层的参数不参与训练,在Pytorch中的实现如下:
class Model(nn.Module):
def __init__(self):
super(Transfer_model, self).__init__()
self.linear1 = nn.Linear(20, 50)
self.linear2 = nn.Linear(50, 20)
self.linear3 = nn.Linear(20, 2)
def forward(self, x):
pass
假如我们想要冻结linear1层,需要做如下操作:
model = Model()
# 这里是一般情况,共享层往往不止一层,所以做一个for循环
for para in model.linear1.parameters():
para.requires_grad = False
# 假如真的只有一层也可以这样操作:
# model.linear1.weight.requires_grad = False
最后我们需要将需要优化的参数传入优化器,不需要传入的参数过滤掉,所以要用到filter()函数。
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)
其它的博客中都没有讲解filter()函数的作用,在这里我简单讲一下有助于更好的理解。
filter(function, iterable)
function: 判断函数
iterable: 可迭代对象
filter() 函数用于过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 list() 来转换。
该接收两个参数,第一个为函数,第二个为序列,序列的每个元素作为参数传递给函数进行判,然后返回 True 或 False,最后将返回 True 的元素放到新列表中。
filter()函数将requires_grad = True的参数传入优化器进行反向传播,requires_grad = False的则被过滤掉。
来源:https://blog.csdn.net/qq_40210586/article/details/103878155


猜你喜欢
- 详解Python import方法引入模块的实例在Python用import或者from…import或者from…import…as…来导
- Python for 和其他语言一样,也可以用来循环遍历对象,本文章向大家介绍Python for 循环的使用方法和实例,需要的朋友可与参考
- 1.第一种方法<table><tr><td>当前时间:</td><td id=&quo
- 本文简单介绍如何使用 Python 的 pyautogui 模块实现鼠标的自动移动以及键盘的自行输入. 该模块不是 Python 自带的,
- 目录一.准备工作二.预览1.启动2.添加城市3.展示多个城市天气三.设计流程1.获取城市天气信息过程四.源代码1.Weather_Tool-
- 1.在官网下载MySQL5.7安装包:mysql-5.7.20-linux-glibc2.12-x86_64.tar.gz。下载地址:htt
- 这里还以前面的微博为例,我们知道拖动刷新的内容由Ajax加载,而且页面的URL没有变化,那么应该到哪里去查看这些Ajax请求呢?1. 查看请
- 对数字货币的崛起感到新奇的我们,并且想知道其背后的技术——区块链是怎样实现的。 但是完全搞懂区块链并非易事,我喜欢在实践中学习,通
- 目录一、_用于临时变量1.1 REPL1.2 for循环中的_1.3 元组拆包中的_1.4 国际化函数1.5 大数字表示形式二、var_用于
- 本文更多将会介绍三思在日常中经常会用到的,或者虽然很少用到,但是感觉挺有意思的一些函数。分二类介绍,分别是: 著名函数篇-经常用到的函数 非
- 内容摘要:asp使用最多的就是ACCESS数据库和ms sql server数据库,本文列出了asp连接这两个数据库的方
- Python字符串字符串或串(String)是由数字、字母、下划线组成的一串字符。一般记为 :s="a1a2···an"
- Python argparse中的action=store_true用法前言Python的命令行参数解析模块学习。示例参数解析模块支持act
- 数据集数据集为Barcelona某段时间内的气象数据,其中包括温度、湿度以及风速等。本文将利用CNN来对风速进行预测。特征构造对于风速的预测
- 结合工作中的内容和大家分享一次Left Jon优化的过程,希望能给同学们新的思路。【功能背景】 我们需要
- Matplotlib编程实现import matplotlib.pyplot as pltimport numpy as npfrom ma
- 一、Sql Server中的日期与时间函数 1. 当前系统日期、时间 select getdate() 2. dateadd 在向指定日期加
- 最近在看python脚本语言,脚本语言是一种解释性的语言,不需要编译,可以直接用,由解释器来负责解释。python语言很强大,而且写起来很简
- COOKIE函数库:cookie.inc.php3 <?php if (!isset($__cookie_inc__)){ $__co
- 1. 背景golang 原生 json 包,在处理 json 对象的字段的时候,是需要严格匹配类型的。但是,实际上,当我们与一些老系统或者脚