在pytorch中对非叶节点的变量计算梯度实例
作者:FesianXu 发布时间:2021-08-26 10:13:53
标签:pytorch,节点,变量,梯度
在pytorch中一般只对叶节点进行梯度计算,也就是下图中的d,e节点,而对非叶节点,也即是c,b节点则没有显式地去保留其中间计算过程中的梯度(因为一般来说只有叶节点才需要去更新),这样可以节省很大部分的显存,但是在调试过程中,有时候我们需要对中间变量梯度进行监控,以确保网络的有效性,这个时候我们需要打印出非叶节点的梯度,为了实现这个目的,我们可以通过两种手段进行。
注册hook函数
Tensor.register_hook[2] 可以注册一个反向梯度传导时的hook函数,这个hook函数将会在每次计算 关于该张量 的时候 被调用,经常用于调试的时候打印出非叶节点梯度。当然,通过这个手段,你也可以自定义某一层的梯度更新方法。[3] 具体到这里的打印非叶节点的梯度,代码如:
def hook_y(grad):
print(grad)
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3
y.register_hook(hook_y)
out = z.mean()
out.backward()
输出如:
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
retain_grad()
Tensor.retain_grad()显式地保存非叶节点的梯度,当然代价就是会增加显存的消耗,而用hook函数的方法则是在反向计算时直接打印,因此不会增加显存消耗,但是使用起来retain_grad()要比hook函数方便一些。代码如:
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
y.retain_grad()
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)
输出如:
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
来源:https://blog.csdn.net/LoseInVain/article/details/99172594


猜你喜欢
- MySQL 当记录不存在时插入(insert if not exists) 在 MySQL 中,插入(insert)一条记录很简单,但是一些
- 在装这两个的时候出现一些问题,最后总算成功了,记录一下过程环境:win10 64位系统,python3.7.8 ,pip18下载地址:这两个
- 本文实例讲述了Python使用sklearn实现的各种回归算法。分享给大家供大家参考,具体如下:使用sklearn做各种回归基本回归:线性、
- 一个Javascript 的类库,用于table内容排序。使用很方便,不用每次都去调用数据库了。特别适合多表查询的排序。加上<tbod
- 类似如下: select A.key,B.key,C.key from A,B,C where trim(A.key)=trim(B.fk)
- 1.el-input无法输入的问题原因1、el-input组件没有绑定双向响应式数据(v-model)解决方案:在data中定义一个变量,然
- 前言近期在刷新生产环境数据库的时候,需要更新表中的字段,如果对每条数据结果都执行一次update语句,占用的数据库资源就会很多,而且速度慢。
- 唉,可怜呀,用了这么久的SQL今天头一次用到外连接,效果不错,方法如下: 使用外联接 仅当至少有一个同属于两表的行符合联接条件时,内联接才返
- Google中秋的logo出来了,酷似一美男站在月亮上,结果被网友弄出一撒尿版来。中国网民好智慧啊~原logo: 撒尿版logo:
- 当数据量猛增的时候,大家都会选择库表散列等等方式去优化数据读写速度。笔者做了一个简单的尝试,1亿条数据,分100张表。具体实现过程如下。首先
- 昨天在W3C看到,6月10日发布了新的 HTML 5 草案(Working Draft)。粗略的读了一下它提供的 新版本说明文档 ,作了一点
- Microsoft SQL Server Management Studio是SQL SERVER的客户端工具,相信大家都知道。我不知道大伙
- 使用MySQL,目前你可以在三种基本数据库表格式间选择。当你创建一张表时,你可以告诉MySQL它应该对于表使用哪个表类型。MySQL将总是创
- FlashPaper 是Macromedia推出的一款电子文档类工具,通过使用本程序,你可以将需要的文档通过简单的设置转换为SWF格式的Fl
- tkinter禁用(只读)下拉列表Comboboxtkinter将下拉列表框Combobox控件的状态设置为只读,也就是不可编辑状态:# 定
- 网页过渡是指当浏览者进入或离开网页时,页面呈现的不同的刷新效果,比如卷动、百叶窗等。这样你的网页看起来
- 描述的意思是HTML为中心的前端开发也差不多是web标准的意思。1.HTML是基础2.CSS依靠选择符提供视觉;3.Javascript 依
- 如何创建列表,或生成列表。这里介绍在python的基础知识里创建或转变或生成列表的一些方法。零个,一个或一系列数据用逗号隔开,放在方括号[
- php实现上传图片保存到数据库的方法。分享给大家供大家参考。具体分析如下:php 上传图片,一般都使用move_uploaded_file方
- 在Oracle 8i版本之前,使用internal用户来执行数据库的启动和关闭以及create database等操作;从8i版本以后,Or