pytorch中Schedule与warmup_steps的用法说明
作者:Bingoyear 发布时间:2023-07-07 00:18:14
1. lr_scheduler相关
lr_scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
其中args.warmup_steps可以认为是耐心系数
num_train_optimization_steps为模型参数的总更新次数
一般来说:
num_train_optimization_steps = int(total_train_examples / args.train_batch_size / args.gradient_accumulation_steps)
Schedule用来调节学习率,拿线性变换调整来说,下面代码中,step是当前迭代次数。
def lr_lambda(self, step):
# 线性变换,返回的是某个数值x,然后返回到类LambdaLR中,最终返回old_lr*x
if step < self.warmup_steps: # 增大学习率
return float(step) / float(max(1, self.warmup_steps))
# 减小学习率
return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
在实际运行中,lr_scheduler.step()先将lr初始化为0. 在第一次参数更新时,此时step=1,lr由0变为初始值initial_lr;在第二次更新时,step=2,上面代码中生成某个实数alpha,新的lr=initial_lr *alpha;在第三次更新时,新的lr是在initial_lr基础上生成,即新的lr=initial_lr *alpha。
其中warmup_steps可以认为是lr调整的耐心系数。
由于有warmup_steps存在,lr先慢慢增加,超过warmup_steps时,lr再慢慢减小。
在实际中,由于训练刚开始时,训练数据计算出的grad可能与期望方向相反,所以此时采用较小的lr,随着迭代次数增加,lr线性增大,增长率为1/warmup_steps;迭代次数等于warmup_steps时,学习率为初始设定的学习率;迭代次数超过warmup_steps时,学习率逐步衰减,衰减率为1/(total-warmup_steps),再进行微调。
2. gradient_accumulation_steps相关
gradient_accumulation_steps通过累计梯度来解决本地显存不足问题。
假设原来的batch_size=6,样本总量为24,gradient_accumulation_steps=2
那么参数更新次数=24/6=4
现在,减小batch_size=6/2=3,参数更新次数不变=24/3/2=4
在梯度反传时,每gradient_accumulation_steps次进行一次梯度更新,之前照常利用loss.backward()计算梯度。
补充:pytorch学习笔记 -optimizer.step()和scheduler.step()
optimizer.step()和scheduler.step()的区别
optimizer.step()通常用在每个mini-batch之中,而scheduler.step()通常用在epoch里面,但是不绝对,可以根据具体的需求来做。只有用了optimizer.step(),模型才会更新,而scheduler.step()是对lr进行调整。
通常我们有
optimizer = optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size = 100, gamma = 0.1)
model = net.train(model, loss_function, optimizer, scheduler, num_epochs = 100)
在scheduler的step_size表示scheduler.step()每调用step_size次,对应的学习率就会按照策略调整一次。
所以如果scheduler.step()是放在mini-batch里面,那么step_size指的是经过这么多次迭代,学习率改变一次。
来源:https://blog.csdn.net/angel_hben/article/details/104538634


猜你喜欢
- 在SQL Server中Count(*)或者Count(1)或者Count([列])或许是最常用的聚合
- 前面简单介绍了Python字符串基本操作,这里再来简单讲述一下Python列表相关操作1. 基本定义与判断>>> dir(
- 方法来源于土豆网的导航,在这里纪录一下实现的思路。主要是利用 position 属性的 absolute 和 relative 配
- 本文实例为大家分享了H5+css3+js搭建带验证码的登录页面,供大家参考,具体内容如下login.html<!DOCTYPE HTM
- 首先是建表语句 CREATE TABLE `t_address_province` ( `id` INT AUTO_INCREMENT PR
- 模版结构优化引入模版有时候一些代码是在许多模版中都用到的。如果我们每次都重复的去拷贝代码那肯定不符合项目的规范。一般我们可以把这些重复性的代
- 今天写Python程序上传图片需要用到PIL库,于是到http://www.pythonware.com/products/pil/#pil
- orm表单使用csrfa. 基本应用form表单中添加{% csrf_token %}b. 全站禁用# 'django.middle
- 简介:轮廓发现是基于图像边缘提取的基础寻找对象轮廓的方法,所以边缘提取的阈值选定会影响最终轮廓发现结果。代码如下:import cv2 as
- 上篇文章讲了js中的一些概念(词法结构) 和 数据类型(部分)。这章我们 继续.然后了解下js中操作数据 和 函数的 作用域。1,对象跟基本
- 前言最近发现有些东西长时间不用就要忘了,坚持每天复习总结一个小知识点吧~异常是什么呢?就是在代码执行过程中非预期的执行结果,随着代码越来越复
- 本文实例为大家分享了python3实现人脸识别的具体代码,供大家参考,具体内容如下第一种:import cv2import numpy as
- 换了N种字符串连接的方法,终于连接上去了。 共享下用的 Provider=SQLOLEDB.1; User ID=sa; Password=
- 前言:keras是一个十分便捷的开发框架,为了更好的追踪网络训练过程中的损失函数loss和准确率accuracy,我们有几种处理方式,第一种
- 本文实例讲述了Python实现的寻找前5个默尼森数算法。分享给大家供大家参考,具体如下:找前5个默尼森数。若P是素数且M也是素数,并且满足等
- permute(dims)将tensor的维度换位。参数:参数是一系列的整数,代表原来张量的维度。比如三维就有0,1,2这些dimensio
- 首先,在数据库中创建一个表,用于存放图片:CREATE TABLE Images(Id INT PRIMARY KEY AUTO_INCRE
- 本文实例分析了python对json的相关操作。分享给大家供大家参考,具体如下:什么是json:JSON(JavaScript Object
- 本文实例讲述了python函数局部变量用法。分享给大家供大家参考。具体分析如下:当你在函数定义内声明变量的时候,它们与函数外具有相同名称的其
- 限流器是服务中非常重要的一个组件,在网关设计、微服务、以及普通的后台应用中都比较常见。它可以限制访问服务的频次和速率,防止服务过载,被刷爆。