pytorch训练时的显存占用递增的问题解决
作者:来包番茄沙司 发布时间:2021-04-20 07:12:45
遇到的问题:
在pytorch训练过程中突然out of memory。
解决方法:
1. 测试的时候爆显存有可能是忘记设置no_grad
加入 with torch.no_grad()
model.eval()
with torch.no_grad():
for idx, (data, target) in enumerate(data_loader):
if args.gpu != -1:
data, target = data.to(args.device), target.to(args.device)
log_probs = net_g(data)
probs.append(log_probs)
# sum up batch loss
test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
# get the index of the max log-probability
y_pred = log_probs.data.max(1, keepdim=True)[1]
correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
2. loss.item()
写成loss_train = loss_train + loss.item(),不能直接写loss_train = loss_train + loss
3. 在代码中添加以下两行:
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
4. del操作后再加上torch.cuda.empty_cache()
单独使用del、torch.cuda.empty_cache()效果都不明显,因为empty_cache()不会释放还被占用的内存。
所以这里使用了del让对应数据成为“没标签”的垃圾,之后这些垃圾所占的空间就会被empty_cache()回收。
"""添加了最后两行,img和segm是图像和标签输入,很明显通过.cuda()已经是被存在在显存里了;
outputs是模型的输出,模型在显存里当然其输出也在显存里;loss是通过在显存里的segm和
outputs算出来的,其也在显存里。这4个对象都是一次性的,使用后应及时把其从显存中清除
(当然如果你显存够大也可以忽略)。"""
def train(model, data_loader, batch_size, optimizer):
model.train()
total_loss = 0
accumulated_steps = 32 // batch_size
optimizer.zero_grad()
for idx, (img, segm) in enumerate(tqdm(data_loader)):
img = img.cuda()
segm = segm.cuda()
outputs = model(img)
loss = criterion(outputs, segm)
(loss/accumulated_steps).backward()
if (idx + 1 ) % accumulated_steps == 0:
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item()
# delete caches
del img, segm, outputs, loss
torch.cuda.empty_cache()
补充:Pytorch显存不断增长问题的解决思路
思路很简单,就是在代码的运行阶段输出显存占用量,观察在哪一块存在显存剧烈增加或者显存异常变化的情况。
但是在这个过程中要分级确认问题点,也即如果存在三个文件main.py、train.py、model.py。
在此种思路下,应该先在main.py中确定问题点,然后,从main.py中进入到train.py中,再次输出显存占用量,确定问题点在哪。
随后,再从train.py中的问题点,进入到model.py中,再次确认。
如果还有更深层次的调用,可以继续追溯下去。
例如:
main.py
def train(model,epochs,data):
for e in range(epochs):
print("1:{}".format(torch.cuda.memory_allocated(0)))
train_epoch(model,data)
print("2:{}".format(torch.cuda.memory_allocated(0)))
eval(model,data)
print("3:{}".format(torch.cuda.memory_allocated(0)))
若1与2之间显存增加极为剧烈,说明问题出在train_epoch中,进一步进入到train.py中。
train.py
def train_epoch(model,data):
model.train()
optim=torch.optimizer()
for batch_data in data:
print("1:{}".format(torch.cuda.memory_allocated(0)))
output=model(batch_data)
print("2:{}".format(torch.cuda.memory_allocated(0)))
loss=loss(output,data.target)
print("3:{}".format(torch.cuda.memory_allocated(0)))
optim.zero_grad()
print("4:{}".format(torch.cuda.memory_allocated(0)))
loss.backward()
print("5:{}".format(torch.cuda.memory_allocated(0)))
utils.func(model)
print("6:{}".format(torch.cuda.memory_allocated(0)))
如果在1,2之间,5,6之间同时出现显存增加异常的情况。此时需要使用控制变量法,例如我们先让5,6之间的代码失效,然后运行,观察是否仍然存在显存 * 。如果没有,说明问题就出在5,6之间下一级的代码中。进入到下一级代码,进行调试:
utils.py
def func(model):
print("1:{}".format(torch.cuda.memory_allocated(0)))
a=f1(model)
print("2:{}".format(torch.cuda.memory_allocated(0)))
b=f2(a)
print("3:{}".format(torch.cuda.memory_allocated(0)))
c=f3(b)
print("4:{}".format(torch.cuda.memory_allocated(0)))
d=f4(c)
print("5:{}".format(torch.cuda.memory_allocated(0)))
此时我们再展示另一种调试思路,先注释第5行之后的代码,观察显存是否存在先训 * ,如果没有,则注释掉第7行之后的,直至确定哪一行的代码出现导致了显存 * 。假设第9行起作用后,代码出现显存 * ,说明问题出在第九行,显存 * 的问题锁定。
参考链接:
http://www.zzvips.com/article/196059.html
https://blog.csdn.net/fish_like_apple/article/details/101448551
来源:https://blog.csdn.net/weixin_45928096/article/details/128691564


猜你喜欢
- 由于图片水印的种类有很多,今天我们先讲最简单的一种。即上图中的①类水印,这种水印存在白色背景上的文档里,水印是灰色,需要保留的文字是黑色。这
- 目录:分析和设计组件编码实现和算法用 Ant 构建组件测试 JavaScript 组件话说上期我们讨论了队列管理组件的设计,并且给它取了个响
- 线上有个需求,格式化,从一堆s1,s100-s199中找出连续的服并且格式化显示出来,如:神魔:S106-109,s123,s125御剑:
- 错误:ImportError: libcublas.so.9.0: cannot open shared object file: No s
- 循环是我们经常用到的一个概念,比如,循环计算数字叠加、循环输出文字内容等。循环是运行重复内容的一个最简单的方法,简化了代码流程,增加了时效性
- sql server端口,我们可以通过\"服务器端网络试用工具\"和\"客户端实用工具\"来设定,设
- 一 NULL 为什么这么经常用(1) java的nullnull是一个让人头疼的问题,比如java中的NullPointerExceptio
- import React, { Component } from 'react';import { Table, Input
- 之前折磨了很久,想在Mysql命令行下导出数据库,但就是每天提示不那个错误,后来才知道其实mysqldump不是mysql命令,因此不能在M
- 今天学习CI框架过程中遇到个问题: A PHP Error was encountered Severity: Notice Message
- 项目环境:python3.6一、项目结构二、数据集准备数据集准备分为两步:获取图片.提取人脸.1、获取图片首先可以利用爬虫,从百度图片上批量
- 这是一个通过js实现的支付后的页面,点击支付会跳出一个弹窗,提示你是否要确定支付,确定后进入付后界面,该页面有着10秒倒计时,计时结束后便会
- 很久都没写 Flask 代码相关了,想想也真是惭愧,然并卵,这次还是不写 Flask 相关,不服你来打我啊(就这么贱,有本事咬我啊这次我来写
- 我想要的结果无非是去掉URL路径中的index.php首先是配置.htaccess<IfModule mod_rewrite.c>
- python具体强大的库文件,很多功能都有相应的库文件,所以很有必要进行学习一下,其中有一个ftp相应的库文件ftplib,我们只需要其中的
- MySQL查询字段为空或者为null判断为nullselect * from table where column is nul
- 定义和用法strftime() 函数根据区域设置格式化本地时间/日期。语法strftime(format,timestamp)参数 描述 f
- python2.7安装opencv-python很慢且总是失败当直接使用pip安装opencv-python时,且总是报错,找了好久,发现是
- 第一章:基本的圆角框第二章:透明圆角化背景图片第三章:圆角化图片 第四章:CSS圆角框组件 V1.0在上面的案例中,我只给出最为原始的圆角框
- 这两天在用python的bottle框架开发后台管理系统,接口约定使用RESTful风格请求,前端使用jquery ajax与接口进行交互,