PyTorch中model.zero_grad()和optimizer.zero_grad()用法
作者:血雨腥风霜 发布时间:2023-09-15 20:13:57
废话不多说,直接上代码吧~
model.zero_grad()
optimizer.zero_grad()
首先,这两种方式都是把模型中参数的梯度设为0
当optimizer = optim.Optimizer(net.parameters())时,二者等效,其中Optimizer可以是Adam、SGD等优化器
def zero_grad(self):
"""Sets gradients of all model parameters to zero."""
for p in self.parameters():
if p.grad is not None:
p.grad.data.zero_()
补充知识:Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解
引言
一般训练神经网络,总是逃不开optimizer.zero_grad之后是loss(后面有的时候还会写forward,看你网络怎么写了)之后是是net.backward之后是optimizer.step的这个过程。
real_a, real_b = batch[0].to(device), batch[1].to(device)
fake_b = net_g(real_a)
optimizer_d.zero_grad()
# 判别器对虚假数据进行训练
fake_ab = torch.cat((real_a, fake_b), 1)
pred_fake = net_d.forward(fake_ab.detach())
loss_d_fake = criterionGAN(pred_fake, False)
# 判别器对真实数据进行训练
real_ab = torch.cat((real_a, real_b), 1)
pred_real = net_d.forward(real_ab)
loss_d_real = criterionGAN(pred_real, True)
# 判别器损失
loss_d = (loss_d_fake + loss_d_real) * 0.5
loss_d.backward()
optimizer_d.step()
上面这是一段cGAN的判别器训练过程。标题中所涉及到的这些方法,其实整个神经网络的参数更新过程(特别是反向传播),具体是怎么操作的,我们一起来探讨一下。
参数更新和反向传播
上图为一个简单的梯度下降示意图。比如以SGD为例,是算一个batch计算一次梯度,然后进行一次梯度更新。这里梯度值就是对应偏导数的计算结果。显然,我们进行下一次batch梯度计算的时候,前一个batch的梯度计算结果,没有保留的必要了。所以在下一次梯度更新的时候,先使用optimizer.zero_grad把梯度信息设置为0。
我们使用loss来定义损失函数,是要确定优化的目标是什么,然后以目标为头,才可以进行链式法则和反向传播。
调用loss.backward方法时候,Pytorch的autograd就会自动沿着计算图反向传播,计算每一个叶子节点的梯度(如果某一个变量是由用户创建的,则它为叶子节点)。使用该方法,可以计算链式法则求导之后计算的结果值。
optimizer.step用来更新参数,就是图片中下半部分的w和b的参数更新操作。
来源:https://blog.csdn.net/weixin_41466947/article/details/89203231


猜你喜欢
- 相信使用过MFC编程的朋友对CString这个类的印象应该非常深刻吧?的确,MFC中的CString类使用起来真的非常的方便好用。但是如果离
- SQL注入语句有时候会使用替换查询技术,就是让原有的查询语句查不到结果出错,而让自己构造的查询语句执行,并把执行结果代替原有查询语句查询结果
- 计时器用来定时执行任务,分享一段代码:package mainimport "time"import "fmt
- 前做PPT要用到折线图,嫌弃EXCEL自带的看上去不好看,就用python写了一个画折线图的程序。import matplotlib.pyp
- logging日志模块:是用来记录日志的模块,一般记录用户在软件中的操作使用方法:模板直接拿来用,手动修改# logging的配置信息(模板
- 当感觉mysql性能出现问题时,通常会先看下当前mysql的执行状态,使用 show processlist 来查看,例如:其中state状
- <script> Function.prototype.createInstance = function(){ var T =
- 微博上讨论MySQL在删除大表engine=innodb(30G+)时,如何减少MySQL hang的时间,现做一下简单总结: 当buffe
- 关于PHP调用Python数据传输问题这是以前大学时做项目出现的问题,现在把它挪上来,希望给遇到问题的未来大佬给出一些小的思路,请大佬们不要
- 数据安全是任何数据服务解决方案中的一个关键要求,而Windows Server 2008和SQL Server 2008结合起来,通过一个基
- 接触pytorch一天,发现pytorch上手的确比TensorFlow更快。可以更方便地实现用预训练的网络提特征。以下是提取一张jpg图像
- 质数又称素数。指在一个大于1的自然数中,除了1和此整数自身外,不能被其他自然数整除的数。素数在数论中有着很重要的地位。比1大但不是素数的数称
- 下面是一份在 HTML 4 Strict 和 XHTML 1.0 Strict 下必须遵守的标签嵌套规则,比如你不能在 <a>
- 异常的本质导引问题在实际工作中,我们遇到的问题都不是完美的,比如:你写某个模块,用户输入不一定符合你的要求:你的程序要打开某个文件,这个文件
- linecache, 可以用它方便地获取某一文件某一行的内容。而且它也被 traceback 模块用来获取相关源码信息来展示。用法很简单:&
- Xxencode编码,也是一个二进制字符转换为普通打印字符方法。跟UUencode编码原理方法很相似,唯独不同的是可打印字符不同。通个UUe
- 1.决定大小写是否敏感的参数在 MySQL 中,数据库与 data 目录中的目录相对应。数据库中的每个表都对应于数据库目录中的至少一个文件(
- 前言本文旨在记录使用Flask框架过程中与前端Vue对接过程中,存在WebSocket总是连接失败导致前端取不到数据的问题。以及在使用Web
- 本文实例讲述了微信小程序学习笔记之本地数据缓存功能。分享给大家供大家参考,具体如下:前面介绍了微信小程序获取位置信息操作。这里再来介绍一下微
- 与Channel区别Channel能够很好的帮助我们控制并发,但是在开发习惯上与显示的表达不太相同,所以在Go语言中可以利用sync包中的W