浅谈Pytorch中的自动求导函数backward()所需参数的含义
作者:站着刷论文 发布时间:2021-04-29 13:38:04
正常来说backward( )函数是要传入参数的,一直没弄明白backward需要传入的参数具体含义,但是没关系,生命在与折腾,咱们来折腾一下,嘿嘿。
对标量自动求导
首先,如果out.backward()中的out是一个标量的话(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有一个输出)那么此时我的backward函数是不需要输入任何参数的。
import torch
from torch.autograd import Variable
a = Variable(torch.Tensor([2,3]),requires_grad=True)
b = a + 3
c = b * 3
out = c.mean()
out.backward()
print('input:')
print(a.data)
print('output:')
print(out.data.item())
print('input gradients are:')
print(a.grad)
运行结果:
不难看出,我们构建了这样的一个函数:
所以其求导也很容易看出:
这是对其进行标量自动求导的结果.
对向量自动求导
如果out.backward()中的out是一个向量(或者理解成1xN的矩阵)的话,我们对向量进行自动求导,看看会发生什么?
先构建这样的一个模型(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有两个输出):
import torch
from torch.autograd import Variable
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2
b[0,1] = a[0,1] ** 3
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)
模型也很简单,不难看出out求导出来的雅克比应该是:
因为a1 = 2,a2 = 4,所以上面的矩阵应该是:
运行的结果:
嗯,的确是8和96,但是仔细想一想,和咱们想要的雅克比矩阵的形式也不一样啊。难道是backward自动把0给省略了?
咱们继续试试,这次在上一个模型的基础上进行小修改,如下:
import torch
from torch.autograd import Variable
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1]
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)
可以看出这个模型的雅克比应该是:
运行一下:
等等,什么鬼?正常来说不应该是
么?我是谁?我再哪?为什么就给我2个数,而且是 8 + 2 = 10 ,96 + 2 = 98 。难道都是加的 2 ?想一想,刚才咱们backward中传的参数是 [ [ 1 , 1 ] ],难道安装这个关系对应求和了?咱们换个参数来试一试,程序中只更改传入的参数为[ [ 1 , 2 ] ]:
import torch
from torch.autograd import Variable
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1]
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,2.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)
嗯,这回可以理解了,我们传入的参数,是对原来模型正常求导出来的雅克比矩阵进行线性操作,可以把我们传进的参数(设为arg)看成一个列向量,那么我们得到的结果就是:
在这个题目中,我们得到的实际是:
看起来一切完美的解释了,但是就在我刚刚打字的一刻,我意识到官方文档中说k.backward()传入的参数应该和k具有相同的维度,所以如果按上述去解释是解释不通的。哪里出问题了呢?
仔细看了一下,原来是这样的:在对雅克比矩阵进行线性操作的时候,应该把我们传进的参数(设为arg)看成一个行向量(不是列向量),那么我们得到的结果就是:
也就是:
这回我们就解释的通了。
现在我们来输出一下雅克比矩阵吧,为了不引起歧义,我们让雅克比矩阵的每个数值都不一样(一开始分析错了就是因为雅克比矩阵中有相同的数据),所以模型小改动如下:
import torch
from torch.autograd import Variable
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1]
b[0,1] = a[0,1] ** 3 + a[0,0] * 2
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1,0]]),retain_graph=True)
A_temp = copy.deepcopy(a.grad)
a.grad.zero_()
out.backward(torch.FloatTensor([[0,1]]))
B_temp = a.grad
print('jacobian matrix is:')
print(torch.cat( (A_temp,B_temp),0 ))
如果没问题的话咱们的雅克比矩阵应该是 [ [ 8 , 2 ] , [ 4 , 96 ] ]
来源:https://blog.csdn.net/f156207495/article/details/88727860


猜你喜欢
- TensorFlow中的log共有INFO、WARN、ERROR、FATAL 4种级别。有以下几种设置方式。1. 通过设置环境变量控制log
- Python 运算符通常用于对值和变量执行操作。这些是用于逻辑和算术运算的标准符号。在本文中,我们将研究不同类型的 Python 运算符。&
- 一.目标浏览网页的时候,看见哪个元素,就能截取哪个元素当图片,不管那个元素有多长 二.所用工具和第三方库python ,PIL,s
- 在很多应用程序开发中,需要记录某些数据表的历史记录或修改痕迹,以便日后出现数据错误时进行数据排查。这种业务需求,我们可以通过数据库的触发器来
- springboot配置文件抽离,便于服务器读取对应配置文件,避免项目频繁更改配置文件,影响项目的调试与发布1.创建统一配置中心项目coni
- 在没步入正轨之前,先给大家介绍JavaScript 特殊字符你可以在 JavaScript 中使用反斜杠来向文本字符串添加特殊字符。插入特殊
- isnull在数据库查询中的应用,特别是再语句连接的时候需要用到 比如连接时候,某个字段没有值但是又要左连接到其他表上 就会显示空, isn
- 本文实例讲述了Python设计模式之状态模式原理与用法。分享给大家供大家参考,具体如下:状态模式(State Pattern):当一个对象的
- 日常小程序经常需要分页查询的功能,本篇我们讲解一下低代码中如何实现分页查询的功能。要自己开发分页功能,可以先参考官方的方法分页查询我们一般是
- reduce()函数也是Python内置的一个高阶函数。reduce()格式:reduce (func, seq[, init()])red
- SQL中的单记录函数1.ASCII返回与指定的字符对应的十进制数;SQL> select ascii('A') A,a
- VSCode卸载后进行重新安装,发现新安装的还有原来的一些配置,卸载的不彻底,有时候也容易出问题,可按照如下方法卸载干净:1.进入控制面板卸
- 针对很普遍的每个元素的操作会遍历每个元素进行操作。这里给出了几种写法,列表每个元素自增等数学操作同理;示例:整形列表ilist加1个数、元素
- itchat模块官方参考文档:https://itchat.readthedocs.io/zh/latest/安装pip install i
- 今天要处理通知书上的日期,写的一个处理程序,效率可能不是最优的,不过实现功能绝对没问题。注:月份和天要分>10,=10,<10三
- cmp()方法返回两个数的差的符号: -1 如果 x < y, 0 如果 x == y, 或者 1 如果 x > y
- PyQt5 QtChart-散点图QScatterSeries类将数据以散点图显示import sysimport randomfrom P
- 1、数据驱动介绍:@ddt.ddt(类装饰器,申明当前类使用ddt框架)@ddt.data(函数装饰器,用于给测试用例传递数据),支持传py
- IEEE Spectrum 根据以下数据来源,对各大编程语言的使用普及率进行了统计。1)谷歌搜索结果2)谷歌趋势分析3)推特 (这是什么东西
- 本文实例讲述了Python多进程机制。分享给大家供大家参考。具体如下:在以前只是接触过PYTHON的多线程机制,今天搜了一下多进程,相关文章