PyTorch 如何检查模型梯度是否可导
作者:烟雨风渡 发布时间:2021-01-21 14:38:31
标签:PyTorch,检查,梯度
一、PyTorch 检查模型梯度是否可导
当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优化,在PyTorch框架下,我们可以使用torch.autograd.gradcheck函数来实现这一功能。
首先看一下官方文档中关于该函数的介绍:
可以看到官方文档中介绍了该函数基于何种方法,以及其参数列表,下面给出几个例子介绍其使用方法,注意:
Tensor需要是双精度浮点型且设置requires_grad = True
第一个例子:检查某一操作是否可导
from torch.autograd import gradcheck
import torch
import torch.nn as nn
inputs = torch.randn((10, 5), requires_grad=True, dtype=torch.double)
linear = nn.Linear(5, 3)
linear = linear.double()
test = gradcheck(lambda x: linear(x), inputs)
print("Are the gradients correct: ", test)
输出为:
Are the gradients correct: True
第二个例子:检查某一网络模型是否可导
from torch.autograd import gradcheck
import torch
import torch.nn as nn
# 定义神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = nn.Sequential(
nn.Linear(15, 30),
nn.ReLU(),
nn.Linear(30, 15),
nn.ReLU(),
nn.Linear(15, 1),
nn.Sigmoid()
)
def forward(self, x):
y = self.net(x)
return y
net = Net()
net = net.double()
inputs = torch.randn((10, 15), requires_grad=True, dtype=torch.double)
test = gradcheck(net, inputs)
print("Are the gradients correct: ", test)
输出为:
Are the gradients correct: True
二、Pytorch求导
1.标量对矩阵求导
验证:
>>>import torch
>>>a = torch.tensor([[1],[2],[3.],[4]]) # 4*1列向量
>>>X = torch.tensor([[1,2,3],[5,6,7],[8,9,10],[5,4,3.]],requires_grad=True) #4*3矩阵,注意,值必须要是float类型
>>>b = torch.tensor([[2],[3],[4.]]) #3*1列向量
>>>f = a.view(1,-1).mm(X).mm(b) # f = a^T.dot(X).dot(b)
>>>f.backward()
>>>X.grad #df/dX = a.dot(b^T)
tensor([[ 2., 3., 4.],
[ 4., 6., 8.],
[ 6., 9., 12.],
[ 8., 12., 16.]])
>>>a.grad b.grad # a和b的requires_grad都为默认(默认为False),所以求导时,没有梯度
(None, None)
>>>a.mm(b.view(1,-1)) # a.dot(b^T)
tensor([[ 2., 3., 4.],
[ 4., 6., 8.],
[ 6., 9., 12.],
[ 8., 12., 16.]])
2.矩阵对矩阵求导
验证:
>>>A = torch.tensor([[1,2],[3,4.]]) #2*2矩阵
>>>X = torch.tensor([[1,2,3],[4,5.,6]],requires_grad=True) # 2*3矩阵
>>>F = A.mm(X)
>>>F
tensor([[ 9., 12., 15.],
[19., 26., 33.]], grad_fn=<MmBackward>)
>>>F.backgrad(torch.ones_like(F)) # 注意括号里要加上这句
>>>X.grad
tensor([[4., 4., 4.],
[6., 6., 6.]])
注意:
requires_grad为True的数组必须是float类型
进行backgrad的必须是标量,如果是向量,必须在后面括号里加上torch.ones_like(X)
来源:https://blog.csdn.net/tszupup/article/details/112916388


猜你喜欢
- 线程池的概念是什么?在面向对象编程中,创建和销毁对象是很费时间的,因为创建一个对象要获取内存资源或者其它更多资源。在Java中更是 如此,虚
- 废话不多说,实现js登录验证码的功能需要下面两步,具体实现过程如下所示:1.jsvar code="" ; //在全局
- 一、.NET Framework Data Provider for SQL Server类型:.NET Framework类库使用:Sys
- 第一种方式阿里云、百度云服务器可用!!!yum install python3第二种方式1.下载python3.6.5的压缩包wget ht
- 用 Python 做一件很平常的事情: 打开文件, 逐行读入, 最后关掉文件; 进一步的需求是, 这也许是程序中一个可选的功能, 如果有任何
- 很实用的过滤重复数据的asp代码,函数如下:<%'**************************************
- Application-settings我们在创建tornado.web.Application的对象时,传入了第一个参数&mdas
- 通过logging模块,重写一个logging2模块,独立开启线程,将待写的日志信息异步放入队列,做到日志输出不影响主流程性能,环境pyth
- 点击进入Lombok官网下载Lombok jar包使用Lombok可能需要注意的地方(1)、当你的IDE是Idea时,要注意你的Idea是支
- 数据库对象表时存储和操作数据的逻辑结构,而数据库对象存储过程和函数,则是用来实现将一组关于表操作的sql语句当作一个整体来执行。在数据库系统
- 本文实例讲述了JavaScript简单计算人的年龄的方法。分享给大家供大家参考,具体如下:注意Date()类型转换,否则会出现NaN的错误b
- 在设计主键的时候往往需要考虑以下几点: 1.无意义性:此处无意义是从用户的角度来定义的。这种无意义在一定程度上也会减少数据库的信息冗余。常常
- 本文实例讲述了Python打开文件、文件读写操作、with方式、文件常用函数。分享给大家供大家参考,具体如下:打开文件:在python3中,
- 当我们在安装scrapy的过程中出现了Twisted错误,当我们有继续安装Twisted的时候,又继续报错,通过一系列的查询和了解,终于发现
- itchat模块官方参考文档:https://itchat.readthedocs.io/zh/latest/安装pip install i
- MYSQL官方提供了Installer方式安装MYSQL服务以及其他组件,使的Windows下安装,卸载,配置MYSQL变得特别简单。1.
- 1. 连接数据库要连接到数据库首先要导入驱动程序。例如import _ "github.com/go-sql-driver/mys
- HTML5之中一个很酷的新特性就是WebSockets,它可以让我们无需AJAX请求即可与服务器端对话。今天彬Go将让大家通过Php环境的服
- 当然有其它工具可以做这件事,但如果客户不允许你在服务器乱装东西时这个脚本就会有用了。 DECLARE @tbImportTables tab
- 由于下学期报了一个Python的入门课程所以寒假一直在自己摸索,毕竟到时候不能挂科,也是水水学分最近心血来潮打算试试爬一下百度翻译肝了一天终