聊聊pytorch中Optimizer与optimizer.step()的用法
作者:wang xiang 发布时间:2022-03-16 22:45:34
当我们想指定每一层的学习率时:
optim.SGD([
{'params': model.base.parameters()},
{'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
这意味着model.base的参数将会使用1e-2的学习率,model.classifier的参数将会使用1e-3的学习率,并且0.9的momentum将会被用于所有的参数。
进行单次优化
所有的optimizer都实现了step()方法,这个方法会更新所有的参数。它能按两种方式来使用:
optimizer.step()
这是大多数optimizer所支持的简化版本。一旦梯度被如backward()之类的函数计算好后,我们就可以调用这个函数。
例子
for input, target in dataset:
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
optimizer.step(closure)
一些优化算法例如Conjugate Gradient和LBFGS需要重复多次计算函数,因此你需要传入一个闭包去允许它们重新计算你的模型。这个闭包应当清空梯度,计算损失,然后返回。
例子:
for input, target in dataset:
def closure():
optimizer.zero_grad()
output = model(input)
loss = loss_fn(output, target)
loss.backward()
return loss
optimizer.step(closure)
补充:Pytorch optimizer.step() 和loss.backward()和scheduler.step()的关系与区别
首先需要明确optimzier优化器的作用, 形象地来说,优化器就是需要根据网络反向传播的梯度信息来更新网络的参数,以起到降低loss函数计算值的作用,这也是机器学习里面最一般的方 * 。
从优化器的作用出发,要使得优化器能够起作用,需要主要两个东西:
1. 优化器需要知道当前的网络或者别的什么模型的参数空间
这也就是为什么在训练文件中,正式开始训练之前需要将网络的参数放到优化器里面,比如使用pytorch的话总会出现类似如下的代码:
optimizer_G = Adam(model_G.parameters(), lr=train_c.lr_G) # lr 使用的是初始lr
optimizer_D = Adam(model_D.parameters(), lr=train_c.lr_D)
2. 需要知道反向传播的梯度信息
我们还是从代码入手,如下所示是Pytorch 中SGD优化算法的step()函数具体写法,具体SGD的写法放在参考部分。
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
for p in group['params']:
if p.grad is None:
continue
d_p = p.grad.data
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = d_p.clone()
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(1 - dampening, d_p)
if nesterov:
d_p = d_p.add(momentum, buf)
else:
d_p = buf
p.data.add_(-group['lr'], d_p)
return loss
从上面的代码可以看到step这个函数使用的是参数空间(param_groups)中的grad,也就是当前参数空间对应的梯度,这也就解释了为什么optimzier使用之前需要zero清零一下,因为如果不清零,那么使用的这个grad就得同上一个mini-batch有关,这不是我们需要的结果。
再回过头来看,我们知道optimizer更新参数空间需要基于反向梯度,因此,当调用optimizer.step()的时候应当是loss.backward()的时候,这也就是经常会碰到,如下情况
total_loss.backward()
optimizer_G.step()
loss.backward()在前,然后跟一个step。
那么为什么optimizer.step()需要放在每一个batch训练中,而不是epoch训练中,这是因为现在的mini-batch训练模式是假定每一个训练集就只有mini-batch这样大,因此实际上可以将每一次mini-batch看做是一次训练,一次训练更新一次参数空间,因而optimizer.step()放在这里。
scheduler.step()按照Pytorch的定义是用来更新优化器的学习率的,一般是按照epoch为单位进行更换,即多少个epoch后更换一次学习率,因而scheduler.step()放在epoch这个大循环下。
来源:https://blog.csdn.net/qq_40178291/article/details/99963586
猜你喜欢
- * 前,我在公司做设计,当时就已经做到技术总监,Photoshop是自学的,当时觉得全世界比我Photoshop强的人也不在多数。七年前,
- 高效的css写法中的一条就是使用简写。通过简写可以让你的CSS文件更小,更易读。而了解CSS属性简写也是前端开发工程师的基本功之一。今天我们
- 在我们平常使用Python进行数据处理与分析时,在import完一大堆库之后,就是对数据进行预览,查看数据是否出现了缺失值、重复值等异常情况
- 在使用Matlab肯定会碰到Matlab求解数组中的最大值以及它所在的位置的问题。博主开始用循环的方法找,既浪费时间又消耗资源,后面查找后才
- 参数数量及其作用该函数共有五个参数,分别是:被赋值的变量 ref要分配给变量的值 value、是否验证形状 validate_shape是否
- 毫无疑问,JavaScript 是一种非常灵活的脚本语言,有时候它像一只难以驯服的野马——你受益于它的灵活性的同时,也要时刻提防它变得失去控
- 前言进程之间通信与线程同步是一个历久弥新的话题,对编程稍有了解应该都知道,但是细说又说不清。一方面除了工作中可能用的比较少,另一方面就是这些
- 新建图像文件后选Channels面板,新建Alpha1通道; 做压
- 如下所示:两个函数:Basemap.drawparallels ##纬度 Basemap.drawmeridia
- python的hashlib库中提供的hexdigest返回长度32的字符串。直接通过digest返回的16字节,有不可打印字符。问题来了,
- 阅读上一篇教程:WEB2.0网页制作标准教程(10)自适应高度布局初步搭建起来,我开始填充里面的内容。首先是定义logo图片:样式表:#lo
- access中可以将文本中的数据轻松导入表中,mysql中用起来没那么方便,其实起来也很简单。首先将数据记录按行处理好用特定的字符分开如:“
- 主要功能在copyFiles()函数里实现,如下:def copyFiles(src, dst): sr
- 雪花算法是在一个项目体系中生成全局唯一ID标识的一种方式,偶然间看到了Python使用雪花算法不尽感叹真的是太便捷了。它生成的唯一ID的规则
- 1. 排名函数与PARTITION BY --所有数据 SELECT * FROM dbo.student AS a INNER JOIN
- 目前可实现:MD5算法、SHA256算法、先MD5后SHA256、先SHA256后MD5、两次MD5、两次SHA256、前8位MD5算法后8
- 一、套接字套接字是为特定网络协议(例如TCP/IP,ICMP/IP,UDP/IP等)套件对上的网络应用程序提供者提供当前可移植标准的对象。它
- 一、多线程同步由于CPython的python解释器在单线程模式下执行,所以导致python的多线程在很多的时候并不能很好地发挥多核cpu的
- 手头有 109 张头部 CT 的断层扫描图片,我打算用这些图片尝试头部的三维重建。基础工作之一,就是要把这些图片数据读出来,组织成一个三维的
- 内容适应形式学习了死猫的文章,我今天也来说说有关内容和容器的关系。看标题你也许觉得有些囧,它和上一篇《形式追随内容?》看起来相反,而且好像从