浅谈pytorch中为什么要用 zero_grad() 将梯度清零
作者:小小鼠标0 发布时间:2022-10-02 11:24:18
pytorch中为什么要用 zero_grad() 将梯度清零
调用backward()函数之前都要将梯度清零,因为如果梯度不清零,pytorch中会将上次计算的梯度和本次计算的梯度累加。
这样逻辑的好处是,当我们的硬件限制不能使用更大的bachsize时,使用多次计算较小的bachsize的梯度平均值来代替,更方便,坏处当然是每次都要清零梯度。
optimizer.zero_grad()
output = net(input)
loss = loss_f(output, target)
loss.backward()
补充:Pytorch 为什么每一轮batch需要设置optimizer.zero_grad
CSDN上有人写过原因,但是其实写得繁琐了。
根据pytorch中的backward()函数的计算,当网络参量进行反馈时,梯度是被积累的而不是被替换掉;但是在每一个batch时毫无疑问并不需要将两个batch的梯度混合起来累积,因此这里就需要每个batch设置一遍zero_grad 了。
其实这里还可以补充的一点是,如果不是每一个batch就清除掉原有的梯度,而是比如说两个batch再清除掉梯度,这是一种变相提高batch_size的方法,对于计算机硬件不行,但是batch_size可能需要设高的领域比较适合,比如目标检测模型的训练。
关于这一点可以参考这里
关于backward()的计算可以参考这里
补充:pytorch 踩坑笔记之w.grad.data.zero_()
在使用pytorch实现多项线性回归中,在grad更新时,每一次运算后都需要将上一次的梯度记录清空,运用如下方法:
w.grad.data.zero_()
b.grad.data.zero_()
但是,运行程序就会报如下错误:
报错,grad没有data这个属性,
原因是,在系统将w的grad值初始化为none,第一次求梯度计算是在none值上进行报错,自然会没有data属性
修改方法:添加一个判断语句,从第二次循环开始执行求导运算
for i in range(100):
y_pred = multi_linear(x_train)
loss = getloss(y_pred,y_train)
if i != 0:
w.grad.data.zero_()
b.grad.data.zero_()
loss.backward()
w.data = w.data - 0.001 * w.grad.data
b.data = b.data - 0.001 * b.grad.data
来源:https://blog.csdn.net/u011959041/article/details/102760868
猜你喜欢
- function geturl($url) { $ch = curl_init(); $timeout = 5; curl_setopt($
- 本文实例讲述了PHP基于phpqrcode类生成二维码的方法。分享给大家供大家参考,具体如下:使用PHP语言生成二维码,还是挺有难度的,当然
- 引文: 长期以来,多媒体信息在计算机中都是以文件形式存放,由操作系统管理的,但是随着计算机网络,分布式计算的发展,对多媒体信息进行高效的管理
- 1、为什么要掌握进程间通信python的多线程代码效率由于受制于GIL,不能利用多核CPU来加速,而多进程方式可以绕过GIL, 发挥多CPU
- 安装conda activate ps pip install visdom激活ps的环境,在指定的ps环境中安装visdom开启pytho
- 这篇文章主要介绍了Python assert关键字原理及实例解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值
- Python提供了一个内联模块buildin,该模块定义了一些软件开发中经常用到的函数,利用这些函数可以实现数据类型的转换、数据的计算、序列
- python实现rsa加密实例详解一 代码import rsakey = rsa.newkeys(3000)#生成随机秘钥privateKe
- 1、设置数据库模式为简单模式:打开SQL企业管理器,在控制台根目录中依次点开Microsoft SQL Server-->SQL Se
- 学校让我们在放假期间自觉Python,对于Python我是小白的不能再小白了。一切从头开始,找学习资料,看视频教程光看书看视频也不行还要自己
- 大家都知道,数据库的安全性是很重要的,它直接影响到数据库的广泛应用。用户可以采用任意一种方法来保护数据库应用程序,也可以将几种方法结合起来使
- 这篇文章主要介绍了python正则表达式匹配IP代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的
- 1. 安装依赖将PyTorch模型转换为ONNX格式可以使它在其他框架中使用,如TensorFlow、Caffe2和MXNet首先安装以下必
- 使用celery在django项目中实现异步发送短信在项目的目录下创建celery_tasks用于保存celery异步任务。在celery_
- 前言博主学习python有个几年了,对于python的掌握越来越深,很多时候,希望自己能掌握python越来越多的知识,但是,也意识很多时候
- 爬虫简介 根据百度百科定义:网络爬虫(又被称为网页蜘蛛,网络机器人,在FOAF社
- url标记为变量通过把 URL 的一部分标记为 <variable_name> 就可以在 URL 中添加变量。标记的 部分会作为
- 可用性研究表明,当响应时间超过一秒钟时,用户便能够有所察觉。虽然在反馈系统中,当用户需要等待时,更好的解决方案的是应该采用确定性的进度条。但
- xlsxwriter 简介用于以 Excel 2007+ XLSX 文件格式编写文件,相较之下 PhpSpreadsheet 支持更多的格式
- 一、初识正则表达式正则表达式 是一个特殊的字符序列,一个字符串是否与我们所设定的这样的字符序列,相匹配快速检索文本、实现替换文本的操作jso