弄清Pytorch显存的分配机制
作者:颀周 发布时间:2023-05-22 22:12:44
标签:Pytorch,显存,分配
对于显存不充足的炼丹研究者来说,弄清楚Pytorch显存的分配机制是很有必要的。下面直接通过实验来推出Pytorch显存的分配过程。
实验实验代码如下:
import torch
from torch import cuda
x = torch.zeros([3,1024,1024,256],requires_grad=True,device='cuda')
print("1", cuda.memory_allocated()/1024**2)
y = 5 * x
print("2", cuda.memory_allocated()/1024**2)
torch.mean(y).backward()
print("3", cuda.memory_allocated()/1024**2)
print(cuda.memory_summary())
输出如下:
代码首先分配3GB的显存创建变量x,然后计算y,再用y进行反向传播。可以看到,创建x后与计算y后分别占显存3GB与6GB,这是合理的。另外,后面通过backward(),计算出x.grad,占存与x一致,所以最终一共占有显存9GB,这也是合理的。但是,输出显示了显存的峰值为12GB,这多出的3GB是怎么来的呢?首先画出计算图:
下面通过列表的形式来模拟Pytorch在运算时分配显存的过程:
如上所示,由于需要保存反向传播以前所有前向传播的中间变量,所以有了12GB的峰值占存。
我们可以不存储计算图中的非叶子结点,达到节省显存的目的,即可以把上面的代码中的y=5*x与mean(y)写成一步:
import torch
from torch import cuda
x = torch.zeros([3,1024,1024,256],requires_grad=True,device='cuda')
print("1", cuda.memory_allocated()/1024**2)
torch.mean(5*x).backward()
print("2", cuda.memory_allocated()/1024**2)
print(cuda.memory_summary())
占显存量减少了3GB:
来源:https://www.cnblogs.com/qizhou/p/14110086.html


猜你喜欢
- 用for循环实现1~n求和的方法def main(): sum = 0 n = int(input('n=&
- 本文实例讲述了Python实现二叉树及遍历方法。分享给大家供大家参考,具体如下:介绍:树是数据结构中非常重要的一种,主要的用途是用来提高查找
- 背景在一次进行SQl查询时,我试着对where条件中vachar类型的字段去掉单引号查询,这个时候发现这条本应该很快的语句竟然很慢。这个va
- 代码如下:<% '隐藏并修改文件的最后修改时间的aspshell '原理:通过FSO可以修改文件的
- 可怜我的C盘本来只有8.XG,所以不得不卸载掉它。卸载掉本身没啥问题,只是昨晚突然发现 Sql Server 2008 R2 Managem
- 在Windows平台下,如果想运行爬虫的话,就需要在cmd中输入:scrapy crawl spider_name这时,爬虫就能启动,并在控
- GROUP BY 语句用于结合合计函数,根据一个或多个列对结果集进行分组。1、概述“Group By”从字面意义上理解就是根据“By”指定的
- 参考Tensorflow Machine Leanrning Cookbooktf.ConfigProto()主要的作用是配置tf.Sess
- 网页的布局也许是大家最不放在眼里的地方,其实布局地位如同文字的排版一样,随便可布,布即随便。但是看过我上篇《网页设计技巧系列 之 文本排版》
- python 利用pywifi模块实现连接网络破解wifi密码实时监控网络,具体内容如下:import pywififrom pywifi
- 因为项目开发中遇到需要向后台传本周的开始和结束时间,以及上一周的起止时间,就琢磨了半天,总算写出来一套,写篇文章是为了方便自己记忆,也是分享
- 表 table1 id RegName PostionSN PersonSN 1 山东齐鲁制药 223 2 2 山东齐鲁制药 224 2 3
- 前言本文分析了 mysqld 进程关闭的过程,以及如何安全、缓和地关闭 MySQL 实例,对这个过程不甚清楚的同学可以参考下。关闭过程1、发
- 在GIS中,栅格属性里有关于栅格自身的信息,背景(nodata value)对于识别一张图像的边界像元尤为重要,我们目的只要把每行每列中的第
- 使用python批量修改文本文件编码格式把文本文件的编码格式进行批量幻化,比如ascii, gb2312, utf8等,相互转化,字符集的大
- 我就废话不多说了,大家还是直接看代码吧!a1 = raw_input("please input a number")a
- 本文实例讲述了php基于curl实现随机ip地址抓取内容的方法。分享给大家供大家参考,具体如下:使用php curl 我们可以模仿用户行为,
- 1. 笛卡尔乘积表1有m行数据,表2有n行数据,查询结果有m*n行数据。2. 分类(1)按年代分类sql92标准:仅支持内连接sql99标准
- 一、数据库备份种类按照数据库大小备份,有四种类型,分别应用于不同场合,下面简要介绍一下:1.1完全备份这是大多数人常用的方式,它可以备份整个
- 如下所示:Description:将一个矩阵(二维数组)按对角线向右进行打印。(搜了一下发现好像是美团某次面试要求半小时手撕的题)Examp