pytorch中retain_graph==True的作用说明
作者:撒旦即可 发布时间:2021-08-03 09:15:26
pytorch retain_graph==True的作用说明
总的来说进行一次backward之后,各个节点的值会清除,这样进行第二次backward会报错,如果加上retain_graph==True后,可以再来一次backward。
retain_graph参数的作用
官方定义:
retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.
大意是如果设置为False,计算图中的中间变量在计算完后就会被释放。
但是在平时的使用中这个参数默认都为False从而提高效率,和creat_graph的值一样。
具体看一个例子理解
假设一个我们有一个输入x,y = x **2, z = y*4,然后我们有两个输出,一个output_1 = z.mean(),另一个output_2 = z.sum()。
然后我们对两个output执行backward。
import torch
x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
y = x ** 2
z = y * 4
print(x)
print(y)
print(z)
loss1 = z.mean()
loss2 = z.sum()
print(loss1,loss2)
loss1.backward() # 这个代码执行正常,但是执行完中间变量都free了,所以下一个出现了问题
print(loss1,loss2)
loss2.backward() # 这时会引发错误
程序正常执行到第12行,所有的变量正常保存。
但是在第13行报错:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
分析:计算节点数值保存了,但是计算图x-y-z结构被释放了,而计算loss2的backward仍然试图利用x-y-z的结构,因此会报错。
因此需要retain_graph参数为True去保留中间参数从而两个loss的backward()不会相互影响。
正确的代码应当把第11行以及之后改成
1 # 假如你需要执行两次backward,先执行第一个的backward,再执行第二个backward
2 loss1.backward(retain_graph=True)# 这里参数表明保留backward后的中间参数。
3 loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
4 #如果是在训练网络optimizer.step() # 更新参数
create_graph参数比较简单,参考官方定义:
create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.
Pytorch retain_graph=True错误信息
(Pytorch:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time)
具有多个loss值
retain_graph设置True,一般多用于两次backward
# 假如有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True) # 这样计算图就不会立即释放
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数
retain_graph设置True后一定要知道释放,否则显卡会占用越来越多,代码速度也会跑的越来越慢。
有的时候我明明仅有一个模型的也会出现这种错误
第一种是输入的原因。
// Example
x = torch.randn((100,1), requires_grad = True)
y = 1 + 2 * x + 0.3 * torch.randn(100,1)
x_train, y_train = x[:70], y[:70]
x_val, y_val = x[70:], y[70:]
for epoch in range(n_epochs):
...
prediction = model(x_train)
loss.backward()
...
在多次循环的过程中,input的梯度没有清除,而且我们也不需要计算输入的梯度,因此将x的require_grad设置为False就可以解决问题。
第二种是我在训练LSTM时候发现的。
class LSTMpred(nn.Module):
def __init__(self, input_size, hidden_dim):
self.hidden = self.init_hidden()
...
def init_hidden(self): #这里我们是需要个隐层参数的
return (torch.zeros(1, 1, self.hidden_dim, requires_grad=True),
torch.zeros(1, 1, self.hidden_dim, requires_grad=True))
def forward(self, seq):
...
这里面的self.hidden我们在每一次训练的时候都要重新初始化隐层参数:
for epoch in range(Epoch):
...
model.hidden = model.init_hidden()
modout = model(seq)
...
3. 我的看法
其实,想想这几种情况都是一回事,都是网络在反向传播中不允许多个backward(),也就是梯度下降反馈的时候,有多个循环过程中共用了同一个需要计算梯度的变量,在前一个循环清除梯度后,后面一个循环过程就会在这个变量上栽跟头(个人想法)。
来源:https://blog.csdn.net/qq_39861441/article/details/104129368
猜你喜欢
- 保持良好的代码风格是每个Coder必学的课程,同样在HTML设计的时候也要特别注意代码的规范性,虽然说不规范的代码不会直接造成严重的后果,但
- 给内存和cpu使用量设置限制在linux系统中,使用Python对内存和cpu使用量设置限制需要通过resource模块来完成。resour
- 作为一个标准的程序猿,为程序编写说明文档是一步必不可少的工作,如何才能写的又好又快呢,下面我们就来详细探讨下吧。今天将告诉大家一个简单平时只
- 目录1)连接请求的变量1、max_connections2、back_log3、wait_timeout和interative_timeou
- 需要画框取消注释rectangleimport cv2import os,sys,shutilimport numpy as np# Ope
- 一、进程介绍进程:正在执行的程序,由程序、数据和进程控制块组成,是正在执行的程序,程序的一次执行过程,是资源调度的基本单位。程序:没有执行的
- 前言docopt 是一个开源的库,代码地址:https://github.com/docopt/docopt。它在 README 中就已经做
- PyAutoGUI是一个纯Python的GUI自动化工具,其目的是可以用程序自动控制鼠标和键盘操作,利用它可以实现自动化任务本章介绍了许多不
- 前言:👉对于新手来说,库的安装是遇到的第一个挑战,我也入了很多坑,所以想出一期安装库的步骤,由于博主水平限制,博客难免会有错误和不准之处,我
- 前几天在一本书上看到一篇可以利用字典破解zip文件密码的文章,觉得比较有意思于是研究了一番,在这里分享一下原理主要是利用python里自带的
- 函数:string.join()Python中有join()和os.path.join()两个函数,具体作用如下:join(): 连接字符串
- Python中的模块(.py文件)在创建之初会自动加载一些内建变量,__name__就是其中之一。Python模块中通常会定义很多变量和函数
- sort包简介官方文档Golang的sort包用来排序,二分查找等操作。本文主要介绍sort包里常用的函数,通过实例代码来快速学会使用sor
- 1、流程控制流程控制在编程语言中是最伟大的发明了,因为有了它,你可以通过很简单的流程描述来表达很复杂的逻辑。流程控制包含分三大类:条件判断,
- 一、分块查找算法分块查找是二分法查找和顺序查找的改进方法,分块查找要求索引表是有序的,对块内结点没有排序要求,块内结点可以是有序的也可以是无
- 1、update delete insert 这种语句都需要commit或者直接在连接数据库的时候加上autocommit=Trueimpo
- 本文实例讲述了Django框架基础模板标签与filter使用方法。分享给大家供大家参考,具体如下:一、基本的模板语言1、变量{{ }}1.1
- Hpack 是啥Hpack 是 HTTP2 的头部压缩算法。在 HTTP1 中,每次传输都会有大量的 Header 携带,我们可以拿一个实际
- 方法一:定义一个函数,参数为所要生成随机字符串的长度。通过random.randint(a, b)方法得到随机数字,具体函数如下:def g
- 在进行CSS网页布局的时候,我们经遇到刷新要保留表单里内容的时候,习惯的做法使用cookie,但是那样做实在是很麻烦,css中的behavi