pytorch如何定义新的自动求导函数
作者:l8947943 发布时间:2021-02-10 20:14:49
标签:pytorch,自动求导,函数
pytorch定义新的自动求导函数
在pytorch中想自定义求导函数,通过实现torch.autograd.Function并重写forward和backward函数,来定义自己的自动求导运算。参考官网上的demo:传送门
直接上代码,定义一个ReLu来实现自动求导
import torch
class MyRelu(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
# 我们使用ctx上下文对象来缓存,以便在反向传播中使用,ctx存储时候只能存tensor
# 在正向传播中,我们接收一个上下文对象ctx和一个包含输入的张量input;
# 我们必须返回一个包含输出的张量,
# input.clamp(min = 0)表示讲输入中所有值范围规定到0到正无穷,如input=[-1,-2,3]则被转换成input=[0,0,3]
ctx.save_for_backward(input)
# 返回几个值,backward接受参数则包含ctx和这几个值
return input.clamp(min = 0)
@staticmethod
def backward(ctx, grad_output):
# 把ctx中存储的input张量读取出来
input, = ctx.saved_tensors
# grad_output存放反向传播过程中的梯度
grad_input = grad_output.clone()
# 这儿就是ReLu的规则,表示原始数据小于0,则relu为0,因此对应索引的梯度都置为0
grad_input[input < 0] = 0
return grad_input
进行输入数据并测试
dtype = torch.float
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 使用torch的generator定义随机数,注意产生的是cpu随机数还是gpu随机数
generator=torch.Generator(device).manual_seed(42)
# N是Batch, H is hidden dimension,
# D_in is input dimension;D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in, device=device, dtype=dtype,generator=generator)
y = torch.randn(N, D_out, device=device, dtype=dtype, generator=generator)
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True, generator=generator)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True, generator=generator)
learning_rate = 1e-6
for t in range(500):
relu = MyRelu.apply
# 使用函数传入参数运算
y_pred = relu(x.mm(w1)).mm(w2)
# 计算损失
loss = (y_pred - y).pow(2).sum()
if t % 100 == 99:
print(t, loss.item())
# 传播
loss.backward()
with torch.no_grad():
w1 -= learning_rate * w1.grad
w2 -= learning_rate * w2.grad
w1.grad.zero_()
w2.grad.zero_()
pytorch自动求导与逻辑回归
自动求导
retain_graph设为True,可以进行两次反向传播
逻辑回归
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(10)
#========生成数据=============
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums,2)
x0 = torch.normal(mean_value*n_data,1)+bias#类别0数据
y0 = torch.zeros(sample_nums)#类别0标签
x1 = torch.normal(-mean_value*n_data,1)+bias#类别1数据
y1 = torch.ones(sample_nums)#类别1标签
train_x = torch.cat((x0,x1),0)
train_y = torch.cat((y0,y1),0)
#==========选择模型===========
class LR(nn.Module):
def __init__(self):
super(LR,self).__init__()
self.features = nn.Linear(2,1)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
x = self.features(x)
x = self.sigmoid(x)
return x
lr_net = LR()#实例化逻辑回归模型
#==============选择损失函数===============
loss_fn = nn.BCELoss()
#==============选择优化器=================
lr = 0.01
optimizer = torch.optim.SGD(lr_net.parameters(),lr = lr,momentum=0.9)
#===============模型训练==================
for iteration in range(1000):
#前向传播
y_pred = lr_net(train_x)#模型的输出
#计算loss
loss = loss_fn(y_pred.squeeze(),train_y)
#反向传播
loss.backward()
#更新参数
optimizer.step()
#绘图
if iteration % 20 == 0:
mask = y_pred.ge(0.5).float().squeeze() #以0.5分类
correct = (mask==train_y).sum()#正确预测样本数
acc = correct.item()/train_y.size(0)#分类准确率
plt.scatter(x0.data.numpy()[:,0],x0.data.numpy()[:,1],c='r',label='class0')
plt.scatter(x1.data.numpy()[:,0],x1.data.numpy()[:,1],c='b',label='class1')
w0,w1 = lr_net.features.weight[0]
w0,w1 = float(w0.item()),float(w1.item())
plot_b = float(lr_net.features.bias[0].item())
plot_x = np.arange(-6,6,0.1)
plot_y = (-w0*plot_x-plot_b)/w1
plt.xlim(-5,7)
plt.ylim(-7,7)
plt.plot(plot_x,plot_y)
plt.text(-5,5,'Loss=%.4f'%loss.data.numpy(),fontdict={'size':20,'color':'red'})
plt.title('Iteration:{}\nw0:{:.2f} w1:{:.2f} b{:.2f} accuracy:{:2%}'.format(iteration,w0,w1,plot_b,acc))
plt.legend()
plt.show()
plt.pause(0.5)
if acc > 0.99:
break
来源:https://blog.csdn.net/l8947943/article/details/105633826


猜你喜欢
- 如何使用Iframe实现本页提交?例:chunfeng.html< html>< head>&n
- 本博客将为各位分享Python Helium库,其是在 Selenium库基础上封装的更加高级的 Web 自动化工具,它能够通过网页端可见的
- 本节在这里主要说的是URLError还有HTTPError,以及对它们的一些处理。1.URLError首先解释下URLError可能产生的原
- 1、合并列表(extend)跟元组一样,用加号(+)将两个列表加起来即可实现合并:In [1]: x=list(range(1, 13, 2
- pprint – 美观打印作用:美观打印数据结构pprint 包含一个“美观打印机”,用于生成数据结构的一个美观视图。格式化工具会生成数据结
- 利用python,可以实现填充网页表单,从而自动登录WEB门户。(注意:以下内容只针对python3)环境准备:(1)安装python (2
- Python中将列表转换成为数据框有两种情况:第一种是两个不同列表转换成一个数据框,第二种是一个包含不同子列表的列表转换成为数据框。第一种:
- Python中的[1:]意思是去掉列表中第一个元素(下标为0),去后面的元素进行操作,以一个示例题为例,用在遍历中统计个数:题:读入N名学生
- 两种方法,一种是为表空间增加数据文件: alter tablespace users add datafile '/opt/orac
- 本文转自微信公众号:"算法与编程之美"一、前言三步搭建MUI页面主框架法包括新建含mui的HTML文件、输入mheade
- 是否看见大站的广告都是放在内容中间实现文字环绕的呢,一般普通小站广告只能放在内容开头或者结尾,也许大站的cms系统带这个功能吧,我们小站常用
- 我们在为大家整Python程序员面试试题中,发现了一些被面试官问到的最多的一些问题,以下就是本篇内容:Python是个非常受欢迎的编程语言,
- 本文实例讲述了Python使用matplotlib绘图无法显示中文问题的解决方法。分享给大家供大家参考,具体如下:在python中,默认情况
- 下面给大家介绍下mysql 8.0.16 初次登录修改密码mysql数据库初始化后初次登录需要修改密码初次登录会碰到下面这个错误ql>
- BETWEEN 运算符用于 WHERE 表达式中,选取介于两个值之间的数据范围。BETWEEN 同 AND 一起搭配使用,语法如下:WHER
- 不加的话貌似只在ie6出现过问题。出现过:改变图片地址,结果图片不见了,加载样式,但样式文件没了。就像是中断了资源的下载一样,正确时解释是
- 在我上一篇文章,我搭了一个框架,模拟了Flask网站上“@app.route(‘/')”第一条例子的行为。如果你错过了那篇“这不是魔
- 直接看下面例子my_ld = [lambda x:x*i for i in range(3)]my_list = [ld(2) for ld
- 问题记录一下出现的问题, 数据翻倍这是复现问题的代码data() { return { space: "
- 本文实例讲述了Python中的装饰器用法。分享给大家供大家参考。具体分析如下:这里还是先由stackoverflow上面的一个问题引起吧,如