PyTorch训练LSTM时loss.backward()报错的解决方案
作者:Ricky_Yan 发布时间:2022-01-10 00:04:09
训练用PyTorch编写的LSTM或RNN时,在loss.backward()上报错:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
千万别改成loss.backward(retain_graph=True),会导致显卡内存随着训练一直增加直到OOM:
RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 10.73 GiB total capacity; 9.79 GiB already allocated; 13.62 MiB free; 162.76 MiB cached)
正确做法:
LSRM / RNN模块初始化时定义好hidden,每次forward都要加上self.hidden = self.init_hidden():
Class LSTMClassifier(nn.Module):
def __init__(self, embedding_dim, hidden_dim):
# 此次省略其它代码
self.rnn_cell = nn.LSTM(embedding_dim, hidden_dim)
self.hidden = self.init_hidden()
# 此次省略其它代码
def init_hidden(self):
# 开始时刻, 没有隐状态
# 关于维度设置的详情,请参考 Pytorch 文档
# 各个维度的含义是 (Seguence, minibatch_size, hidden_dim)
return (torch.zeros(1, 1, self.hidden_dim),
torch.zeros(1, 1, self.hidden_dim))
def forward(self, x):
# 此次省略其它代码
self.hidden = self.init_hidden() # 就是加上这句!!!!
out, self.hidden = self.rnn_cell(x, self.hidden)
# 此次省略其它代码
return out
或者其它模块每次调用这个模块时,其它模块的forward()都对这个LSTM模块init_hidden()一下。
如定义一个模型LSTM_Model():
Class LSTM_Model(nn.Module):
def __init__(self, embedding_dim, hidden_dim):
# 此次省略其它代码
self.rnn = LSTMClassifier(embedding_dim, hidden_dim)
# 此次省略其它代码
def forward(self, x):
# 此次省略其它代码
self.rnn.hidden = self.rnn.init_hidden() # 就是加上这句!!!!
out = self.rnn(x)
# 此次省略其它代码
return out
这是因为:
根据 官方tutorial,在 loss 反向传播的时候,pytorch 试图把 hidden state 也反向传播,但是在新的一轮 batch 的时候 hidden state 已经被内存释放了,所以需要每个 batch 重新 init (clean out hidden state), 或者 detach,从而切断反向传播。
补充:pytorch:在执行loss.backward()时out of memory报错
在自己编写SurfNet网络的过程中,出现了这个问题,查阅资料后,将得到的解决方法汇总如下
可试用的方法:
1、reduce batch size, all the way down to 1
2、remove everything to CPU leaving only the network on the GPU
3、remove validation code, and only executing the training code
4、reduce the size of the network (I reduced it significantly: details below)
5、I tried scaling the magnitude of the loss that is backpropagating as well to a much smaller value
在训练时,在每一个step后面加上:
torch.cuda.empty_cache()
在每一个验证时的step之后加上代码:
with torch.no_grad()
不要在循环训练中累积历史记录
total_loss = 0
for i in range(10000):
optimizer.zero_grad()
output = model(input)
loss = criterion(output)
loss.backward()
optimizer.step()
total_loss += loss
total_loss在循环中进行了累计,因为loss是一个具有autograd历史的可微变量。你可以通过编写total_loss += float(loss)来解决这个问题。
本人遇到这个问题的原因是,自己构建的模型输入到全连接层中的特征图拉伸为1维向量时太大导致的,加入pool层或者其他方法将最后的卷积层输出的特征图尺寸减小即可。
来源:https://blog.csdn.net/qq_31375855/article/details/107568057


猜你喜欢
- Hadoop 命令行最常用指令篇:1.ls (list directory)Usage:hadoop fs -ls [R]Option: -
- Django crontab定时任务安装pip install django-crontab配置在settings.py中 INSTALLE
- Prometheus是什么Prometheus是一套开源监控系统和告警为一体,由go语言(golang)开发,是监控+报警+时间序列数据库的
- 楔子上一篇文章我们探讨了 GIL 的原理,以及如何释放 GIL 实现并行,做法是将函数声明为 nogil,然后使用 with nogil 上
- read()方法读取文件size个字节大小。如果读取命中获得EOF大小字节之前,那么它只能读取可用的字节。语法以下是read()
- import httplibimport osimport timedef check_http(i):
- 如下所示:<!doctype html><html><head><meta charset=&qu
- 一、准备工程文件1.创建工程leeoo2.在工程根目录下创建setup.py文件3.在工程根目录下创建同名package二、编辑setup.
- 简介:1.霍夫变换(Hough Transform) 霍夫变换是图像处理中从图像中识别几何形状的基本方法之一,应用很广泛,也有很多改进算法。
- CentOS6.9安装Mysql5.7,供大家参考,具体内容如下一、上传安装包二、建立用户以及mysql的目录1、建立一个mysql的组输入
- 大家知道,mailto是网页设计制作中的一个非常实用的html标签,许多拥有个人网页的朋友都喜欢在网站的醒目位置处写上自己的电子邮件地址,这
- 现在网上有很多python2写的爬虫抓取网页图片的实例,但不适用新手(新手都使用python3环境,不兼容python2),所以我用Pyth
- 在数据库中,UNION和UNION ALL关键字都是将两个结果集合并为一个,但这两者从使用和效率上来说都有所不同。MySQL中的UNIONU
- 如何做一个树状展开视图来显示自己的记录结构?在SQL中,如何做一个可收起和展开树状结构图?就是资源管理器左栏的那种效果。这要用到Data s
- python类class定义及其初始化定义类,功能,属性一般类名首字母大写class Calculator:#名字和价格是属性
- 1.若有疑问立即检测 在出错时若能对原始代码做简单检测可以省去很多头痛问题。W3C对于XHTML与CSS 都有检测工具可用,请见 http:
- 本文实例讲述了python解析xml文件操作的实现方法。分享给大家供大家参考。具体方法如下:xml文件内容如下:<?xml versi
- 一,extract方法的使用extract函数主要是对于数据进行提取。场景一般对于DataFrame中的一列中的数据进行提取的场合比较多。例
- Web_THBC 为表示层也就是页面(.aspx) BLL_THBC 为业务逻辑层 DAL_THBC 为数据库交互层 (向数据库执行SQL语
- 在HTML中,我们设置border=”1″ 时,表格边框实际大小是2px,那如果我们要做成1px的细线表格要怎么办?以前在做1px的表格的时