Pytorch中retain_graph的坑及解决
作者:Longlongaaago 发布时间:2022-12-20 16:21:09
Pytorch中retain_graph的坑
在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用就是
在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;
############################
# (1) Update D network: maximize D(x)-1-D(G(z))
###########################
real_img = Variable(target)
if torch.cuda.is_available():
real_img = real_img.cuda()
z = Variable(data)
if torch.cuda.is_available():
z = z.cuda()
fake_img = netG(z)
netD.zero_grad()
real_out = netD(real_img).mean()
fake_out = netD(fake_img).mean()
d_loss = 1 - real_out + fake_out
d_loss.backward(retain_graph=True) #####
optimizerD.step()
############################
# (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
###########################
netG.zero_grad()
g_loss = generator_criterion(fake_out, fake_img, real_img)
g_loss.backward()
optimizerG.step()
fake_img = netG(z)
fake_out = netD(fake_img).mean()
g_loss = generator_criterion(fake_out, fake_img, real_img)
running_results['g_loss'] += g_loss.data[0] * batch_size
d_loss = 1 - real_out + fake_out
running_results['d_loss'] += d_loss.data[0] * batch_size
running_results['d_score'] += real_out.data[0] * batch_size
running_results['g_score'] += fake_out.data[0] * batch_size
也就是说,只要我们有一个loss,我们就可以先loss.backward(retain_graph=True) 让它先计算梯度,若下面还有其他损失,但是可能你想扩展代码,可能有些loss是不用的,所以先加了 if 等判别语句进行了干预,使用loss.backward(retain_graph=True)就可以单独的计算梯度,屡试不爽。
但是另外一个问题在于,如果你都这么用的话,显存会 * ,因为他保留了梯度,所以都没有及时释放掉,浪费资源。
而正确的做法应该是,在你最后一个loss 后面,一定要加上loss.backward()这样的形式,也就是让最后一个loss 释放掉之前所有暂时保存下来得梯度!!
Pytorch中有多次backward时需要retain_graph参数
Pytorch中的机制是每次调用loss.backward()时都会free掉计算图中所有缓存的buffers,当模型中可能有多次backward()时,因为前一次调用backward()时已经释放掉了buffer,所以下一次调用时会因为buffers不存在而报错
解决办法
loss.backward(retain_graph=True)
错误使用
optimizer.zero_grad()
清空过往梯度;loss1.backward(retain_graph=True)
反向传播,计算当前梯度;loss2.backward(retain_graph=True)
反向传播,计算当前梯度;optimizer.step()
根据梯度更新网络参数
因为每次调用bckward时都没有将buffers释放掉,所以会导致内存溢出,迭代越来越慢(因为梯度都保存了,没有free)
正确使用
optimizer.zero_grad()
清空过往梯度;loss1.backward(retain_graph=True)
反向传播,计算当前梯度;loss2.backward()
反向传播,计算当前梯度;optimizer.step()
根据梯度更新网络参数
最后一个 backward() 不要加 retain_graph 参数,这样每次更新完成后会释放占用的内存,也就不会出现越来越慢的情况了
来源:https://blog.csdn.net/Willen_/article/details/89394766


猜你喜欢
- 编写 models.py 文件from django.db import models# Create your models here.c
- 本文介绍了计算多个订单的核销金额的全部过程,运行数据库环境:SQL SERVER 2005,下面跟大家分享一下。下图是一张订单明细表,现有金
- 在ASP中,除了ADODB、Scripting 等一些常用组件外,我们还可以用微软的ActiveX方法来轻松捕获哟: <%u
- 如下所示:'''Created on 2018-4-20例子:每天凌晨3点执行func方法''
- 下面是调用方式:Example script - pymssql module (DB API 2.0) Example script -
- 本文实例讲述了使用Flask-Cache缓存实现给Flask提速的方法。分享给大家供大家参考,具体如下:Django里面可以很方便的应用缓存
- asp编程手工定义参数的方法: Dim con As ADODB.Connection
- 一、功能需求1.根据输入内容进行模糊查询,选择地址后在地图上插上标记,并更新经纬度坐标显示2.在地图点击后,根据回传的左边更新地址信息和坐标
- Updates(2019.8.14 19:53)吃饭前用这个方法实战了一下,吃完回来一看好像不太行:跑完一组参数之后,到跑下一组参数时好像没
- 目录一,python介绍二.python的安装程序三、变量python基础部分学习一,python介绍python的创始人为吉多·范罗苏姆(
- 这篇文章主要介绍了如何通过python实现全排列,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以
- 首先给出展示结果,大体就是检测工业板子是否出现。采取检测的方法比较简单,用的OpenCV的模板检测。大体思路opencv读取视频将视频分割为
- 上一次写的《Bootstrap编写一个兼容主流浏览器的受众巨幕式风格页面》(点击打开链接)部分老一辈的需求可能对这种后现代的风格并不满意,没
- python爬取数据保存为Json格式代码如下:#encoding:'utf-8'import urllib.request
- 存储和读取ASCII码形式的byte数据Python可以存byte数据到txt,但不要用str的方式直接存,转成数字列表储存,这样方便读取L
- 下面介绍在Linux上利用python获取本机ip的方法.经过网上调查, 发现大致有两种方法, 一种是调用shell脚本,另一种是利用pyt
- 最近在重构公司以前产品的前端代码,摈弃了以前的session-cookie鉴权方式,采用token鉴权,忙里偷闲觉得有必要对几种常见的鉴权方
- sql="select * from admin where users='"&users&&q
- 打包下载Pain.php <?php class Pain { public $var=array(); public $tpl=ar
- 代码: <input type="text" value="fisker" onclick=&