Pytorch中retain_graph的坑及解决
作者:Longlongaaago 发布时间:2022-12-20 16:21:09
Pytorch中retain_graph的坑
在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用就是
在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;
############################
# (1) Update D network: maximize D(x)-1-D(G(z))
###########################
real_img = Variable(target)
if torch.cuda.is_available():
real_img = real_img.cuda()
z = Variable(data)
if torch.cuda.is_available():
z = z.cuda()
fake_img = netG(z)
netD.zero_grad()
real_out = netD(real_img).mean()
fake_out = netD(fake_img).mean()
d_loss = 1 - real_out + fake_out
d_loss.backward(retain_graph=True) #####
optimizerD.step()
############################
# (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
###########################
netG.zero_grad()
g_loss = generator_criterion(fake_out, fake_img, real_img)
g_loss.backward()
optimizerG.step()
fake_img = netG(z)
fake_out = netD(fake_img).mean()
g_loss = generator_criterion(fake_out, fake_img, real_img)
running_results['g_loss'] += g_loss.data[0] * batch_size
d_loss = 1 - real_out + fake_out
running_results['d_loss'] += d_loss.data[0] * batch_size
running_results['d_score'] += real_out.data[0] * batch_size
running_results['g_score'] += fake_out.data[0] * batch_size
也就是说,只要我们有一个loss,我们就可以先loss.backward(retain_graph=True) 让它先计算梯度,若下面还有其他损失,但是可能你想扩展代码,可能有些loss是不用的,所以先加了 if 等判别语句进行了干预,使用loss.backward(retain_graph=True)就可以单独的计算梯度,屡试不爽。
但是另外一个问题在于,如果你都这么用的话,显存会 * ,因为他保留了梯度,所以都没有及时释放掉,浪费资源。
而正确的做法应该是,在你最后一个loss 后面,一定要加上loss.backward()这样的形式,也就是让最后一个loss 释放掉之前所有暂时保存下来得梯度!!
Pytorch中有多次backward时需要retain_graph参数
Pytorch中的机制是每次调用loss.backward()时都会free掉计算图中所有缓存的buffers,当模型中可能有多次backward()时,因为前一次调用backward()时已经释放掉了buffer,所以下一次调用时会因为buffers不存在而报错
解决办法
loss.backward(retain_graph=True)
错误使用
optimizer.zero_grad()
清空过往梯度;loss1.backward(retain_graph=True)
反向传播,计算当前梯度;loss2.backward(retain_graph=True)
反向传播,计算当前梯度;optimizer.step()
根据梯度更新网络参数
因为每次调用bckward时都没有将buffers释放掉,所以会导致内存溢出,迭代越来越慢(因为梯度都保存了,没有free)
正确使用
optimizer.zero_grad()
清空过往梯度;loss1.backward(retain_graph=True)
反向传播,计算当前梯度;loss2.backward()
反向传播,计算当前梯度;optimizer.step()
根据梯度更新网络参数
最后一个 backward() 不要加 retain_graph 参数,这样每次更新完成后会释放占用的内存,也就不会出现越来越慢的情况了
来源:https://blog.csdn.net/Willen_/article/details/89394766
猜你喜欢
- 本文实例为大家分享了python实现文字版扫雷的具体代码,供大家参考,具体内容如下python版本:2.7游戏运行图:代码已经注释得很清楚,
- WEB标准,从我大二开始接触到毕业后的第一份工作“页面重构工程师”,从接触标准到蓝色理想标准区版主的四年多时间里,WEB标准已经成为我生活中
- 见下:<% FOR i = 1 TO 1000 n =
- 模糊数据库指能够处理模糊数据的数据库。一般的数据库都是以二直逻辑和精确的数据工具为基础的,不能表示许多模糊不清的事情。随着模糊数学理论体系的
- 前言在进行业务数据分析时,往往需要使用pandas计算环比、同比及增长率等指标,为了能够更加方便的进行的统计数据,整理方法如下。1.数据准备
- ansible 简介ansible 是什么?ansible是新出现的自动化运维工具,基于Python开发,集合了众多运维工具(puppet、
- 本文借鉴于张广河教授主编的《数据结构》,对其中的代码进行了完善。从某源点到其余各顶点的最短路径Dijkstra算法可用于求解图中某源点到其余
- MySQL是一个小型关系型数据库管理系统,开发者为瑞典MySQLAB公司,在2008年1月16号被Sun公司收购。MySQL被广泛地应用在I
- 图片人脸识别import cv2filepath = "img/xingye-1.png"img = cv2.imrea
- 1.ROOT_URLCONF = '总路由所在路径(比如untitled.urls)'<===默认情况是这样根路由的路
- <!--#include file="config.asp" -->&nbs
- python结构体数组在C语言中我们可以通过struct关键字定义结构类型,结构中的字段占据连续的内存空间,每个结构体占用的内存大小都相同,
- 先来了解一下收/发邮件有哪些协议:SMTP协议 SMTP(Simple Mail Transfer Protocol),即简单邮件传输协议。
- 在WEB2.0这个词未出现之前,是没有所谓的WEB1.0之说的,那时候的互联网也是没有时代之分的,能上的网站不多,值得上的网站更不多,很多的
- 前言Python 的一大优点就是丰富的类库,所以我们经常会用 pip 来安装各种库,所以对于Python开发用户来讲,PIP安装软件包是家常
- Oblog4.6 ACCESS版转换为UCenterHome1.5的全过程1、 说明:
- python提取特定时间段内的数据尝试一下:data['Date'] = pd.to_datetime(data['
- 函数: # 什么是函数:一系列python语句的组合,可以在程序中运行一次或者多次# 一般是完成具体的独立的功能# 为什么要使用函数# 代码
- 前言:近我使用 Go 语言完成了一个正式的 Web 应用,有一些方面的问题在使用 Go 开发 Web 应用过程中比较重要。过去,我将 Web
- 简介python 动态执行字符串代码片段(也可以是文件), 一般会用到exec,eval。execexec_stmt ::= "e