浅谈Pytorch中的自动求导函数backward()所需参数的含义
作者:站着刷论文 发布时间:2021-04-29 13:38:04
正常来说backward( )函数是要传入参数的,一直没弄明白backward需要传入的参数具体含义,但是没关系,生命在与折腾,咱们来折腾一下,嘿嘿。
对标量自动求导
首先,如果out.backward()中的out是一个标量的话(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有一个输出)那么此时我的backward函数是不需要输入任何参数的。
import torch
from torch.autograd import Variable
a = Variable(torch.Tensor([2,3]),requires_grad=True)
b = a + 3
c = b * 3
out = c.mean()
out.backward()
print('input:')
print(a.data)
print('output:')
print(out.data.item())
print('input gradients are:')
print(a.grad)
运行结果:
不难看出,我们构建了这样的一个函数:
所以其求导也很容易看出:
这是对其进行标量自动求导的结果.
对向量自动求导
如果out.backward()中的out是一个向量(或者理解成1xN的矩阵)的话,我们对向量进行自动求导,看看会发生什么?
先构建这样的一个模型(相当于一个神经网络有一个样本,这个样本有两个属性,神经网络有两个输出):
import torch
from torch.autograd import Variable
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2
b[0,1] = a[0,1] ** 3
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)
模型也很简单,不难看出out求导出来的雅克比应该是:
因为a1 = 2,a2 = 4,所以上面的矩阵应该是:
运行的结果:
嗯,的确是8和96,但是仔细想一想,和咱们想要的雅克比矩阵的形式也不一样啊。难道是backward自动把0给省略了?
咱们继续试试,这次在上一个模型的基础上进行小修改,如下:
import torch
from torch.autograd import Variable
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1]
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,1.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)
可以看出这个模型的雅克比应该是:
运行一下:
等等,什么鬼?正常来说不应该是
么?我是谁?我再哪?为什么就给我2个数,而且是 8 + 2 = 10 ,96 + 2 = 98 。难道都是加的 2 ?想一想,刚才咱们backward中传的参数是 [ [ 1 , 1 ] ],难道安装这个关系对应求和了?咱们换个参数来试一试,程序中只更改传入的参数为[ [ 1 , 2 ] ]:
import torch
from torch.autograd import Variable
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1]
b[0,1] = a[0,1] ** 3 + a[0,0]
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1.,2.]]))
print('input:')
print(a.data)
print('output:')
print(out.data)
print('input gradients are:')
print(a.grad)
嗯,这回可以理解了,我们传入的参数,是对原来模型正常求导出来的雅克比矩阵进行线性操作,可以把我们传进的参数(设为arg)看成一个列向量,那么我们得到的结果就是:
在这个题目中,我们得到的实际是:
看起来一切完美的解释了,但是就在我刚刚打字的一刻,我意识到官方文档中说k.backward()传入的参数应该和k具有相同的维度,所以如果按上述去解释是解释不通的。哪里出问题了呢?
仔细看了一下,原来是这样的:在对雅克比矩阵进行线性操作的时候,应该把我们传进的参数(设为arg)看成一个行向量(不是列向量),那么我们得到的结果就是:
也就是:
这回我们就解释的通了。
现在我们来输出一下雅克比矩阵吧,为了不引起歧义,我们让雅克比矩阵的每个数值都不一样(一开始分析错了就是因为雅克比矩阵中有相同的数据),所以模型小改动如下:
import torch
from torch.autograd import Variable
a = Variable(torch.Tensor([[2.,4.]]),requires_grad=True)
b = torch.zeros(1,2)
b[0,0] = a[0,0] ** 2 + a[0,1]
b[0,1] = a[0,1] ** 3 + a[0,0] * 2
out = 2 * b
#其参数要传入和out维度一样的矩阵
out.backward(torch.FloatTensor([[1,0]]),retain_graph=True)
A_temp = copy.deepcopy(a.grad)
a.grad.zero_()
out.backward(torch.FloatTensor([[0,1]]))
B_temp = a.grad
print('jacobian matrix is:')
print(torch.cat( (A_temp,B_temp),0 ))
如果没问题的话咱们的雅克比矩阵应该是 [ [ 8 , 2 ] , [ 4 , 96 ] ]
来源:https://blog.csdn.net/f156207495/article/details/88727860
猜你喜欢
- <?php $monthoneday=date("Ym")."01"; $oneweekday
- 报错selenium.common.exceptions.WebDriverException: Message: Element is n
- 前言在我们的日常开发中, 常用的中间件有很多, 今天来讲一下怎么集成限流中间件, 它可以很好地用限制并发访问数来保护系统服务, 避免系统服务
- with 用法理解Overviewwith 与with之后的object一起,起到了抛出异常和单独生成一个空间让代码在空间里运行的效果。实验
- 1. 简介Python 读写文件的二进制数据需要使用到struct模块,进行C/C++与Python数据格式的转换。2. struct模块介
- 现在因为已经安装了2.6的Python,以及支持2.6的Eric4,就不想再重新安装2.5来继续配置Apache下mod_python了。后
- 一、必备技能1、logging模块的使用(1)5个日志等级/以及5个输出日志的内置函数(2)日志收集器、日志输出渠道的概念(3)如何自定义日
- 1、确认框架中安装了第三方alibabacoud控件实现代码如下上传过程中遇到任务问题,可以进行留言<?php namespace A
- 描述max() 方法返回给定参数的最大值,参数可以为序列。语法以下是 max() 方法的语法:max( x, y, z, .... )参数x
- 本文实例为大家分享了python matlibplot绘制3D图形的具体代码,供大家参考,具体内容如下1、散点图使用scatterfrom
- 前言对于PHP大家一定不陌生,但你知道PHP在CTF中是如何考察的吗,本文给大家带来的是通过PHP特性来进行CTF比赛中解题出题的知识,会介
- 1.过滤器的使用1.过滤器和测试器在Python中,如果需要对某个变量进行处理,我们可以通过函数来实现。在模板中,我们则是通过过滤器来实现的
- 什么是死锁,在Go的协程里面死锁通常就是永久阻塞了,你拿着我的东西,要我先给你然后再给我,我拿着你的东西又让你先给我,不然就不给你。我俩都这
- Pycharm默认可以识别py脚本中的SQL语句,本身很不错,但当SQL拼接时就显示的代码特别难看,找了好久,终于知道怎么关闭SQL识别功能
- 最近开始学习Python开发,“工欲善其事必先利其器”,Python程序都是用什么工具开发出来的呢。
- 本文实例讲述了Flask框架学习笔记之模板操作。分享给大家供大家参考,具体如下:flask的模板引擎是Jinja2。引入模板的好处是增加程序
- PyQt5 MDI(多文档窗口)QMidArea简介一种同时显示多个窗口的方法是,创建多个独立的窗口,这些独立的窗口被称为SDI(Singl
- 本文实例讲述了python 并发下载器实现方法。分享给大家供大家参考,具体如下:并发下载器并发下载原理from gevent import
- 前言matplotlib实际上是一套面向对象的绘图库,它所绘制的图表中的每个绘图元素,例如线条Line2D、文字Text、刻度等在内存中都有
- 前言pytest.mark.skip可以标记无法在某些平台上运行的测试功能,或者您希望失败的测试功能希望满足某些条件才执行某些测试用例,否则