pytorch自定义不可导激活函数的操作
作者:Luna_Lovegood_001 发布时间:2022-07-05 10:09:13
pytorch自定义不可导激活函数
今天自定义不可导函数的时候遇到了一个大坑。
首先我需要自定义一个函数:sign_f
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
然后我需要把它封装为一个module 类型,就像 nn.Conv2d 模块 封装 f.conv2d 一样,于是
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
# 我需要的module
def __init__(self, *kargs, **kwargs):
super(sign_, self).__init__(*kargs, **kwargs)
def forward(self, inputs):
# 使用自定义函数
outs = sign_f(inputs)
return outs
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
结果报错
TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'
我试了半天,发现自定义函数后面要加 apply ,详细见下面
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
def __init__(self, *kargs, **kwargs):
super(sign_, self).__init__(*kargs, **kwargs)
self.r = sign_f.apply ### <-----注意此处
def forward(self, inputs):
outs = self.r(inputs)
return outs
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
问题解决了!
PyTorch自定义带学习参数的激活函数(如sigmoid)
有的时候我们需要给损失函数设一个超参数但是又不想设固定阈值想和网络一起自动学习,例如给Sigmoid一个参数alpha进行调节
函数如下:
import torch.nn as nn
import torch
class LearnableSigmoid(nn.Module):
def __init__(self, ):
super(LearnableSigmoid, self).__init__()
self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
self.weight.data.fill_(1.0)
def forward(self, input):
return 1/(1 + torch.exp(-self.weight*input))
验证和Sigmoid的一致性
class LearnableSigmoid(nn.Module):
def __init__(self, ):
super(LearnableSigmoid, self).__init__()
self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
self.weight.data.fill_(1.0)
def forward(self, input):
return 1/(1 + torch.exp(-self.weight*input))
Sigmoid = nn.Sigmoid()
LearnSigmoid = LearnableSigmoid()
input = torch.tensor([[0.5289, 0.1338, 0.3513],
[0.4379, 0.1828, 0.4629],
[0.4302, 0.1358, 0.4180]])
print(Sigmoid(input))
print(LearnSigmoid(input))
输出结果
tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]])
tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]], grad_fn=<MulBackward0>)
验证权重是不是会更新
import torch.nn as nn
import torch
import torch.optim as optim
class LearnableSigmoid(nn.Module):
def __init__(self, ):
super(LearnableSigmoid, self).__init__()
self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
self.weight.data.fill_(1.0)
def forward(self, input):
return 1/(1 + torch.exp(-self.weight*input))
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.LSigmoid = LearnableSigmoid()
def forward(self, x):
x = self.LSigmoid(x)
return x
net = Net()
print(list(net.parameters()))
optimizer = optim.SGD(net.parameters(), lr=0.01)
learning_rate=0.001
input_data=torch.randn(10,2)
target=torch.FloatTensor(10, 2).random_(8)
criterion = torch.nn.MSELoss(reduce=True, size_average=True)
for i in range(2):
optimizer.zero_grad()
output = net(input_data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(list(net.parameters()))
输出结果
tensor([1.], requires_grad=True)]
[Parameter containing:
tensor([0.9979], requires_grad=True)]
[Parameter containing:
tensor([0.9958], requires_grad=True)]
会更新~
来源:https://blog.csdn.net/qq_43110298/article/details/115032262
猜你喜欢
- 使用:foldercleanup.py -d 10 -k c:\test\keepfile.txt c:\test表示对c:\test目录只
- 这是一个很简单的纯CSS相册滑动浏览效果,仅用一个无序列表ul结合简单的CSS就可以实现。原文中介绍的纵向滑动相册的实现方法,但是相比之下个
- 我使用“ Web 2.0设计”来形容目前占主导优势的网页设计风格, 很多人用这个词来形容:网络经济的复苏网站和用户之间更高水平的交互或一种社
- 数据库系统的安全性包括很多方面。由于很多情况下,数据库服务器容许客户机从网络上连接,因此客户机连接的安全对MySQL数据库安全有很重要的影响
- 本文较为详细的讲述了Python中常用的模块,分享给大家便于大家查阅参考之用。具体如下:1.内置模块(不用import就可以直接使用)常用内
- 一、工厂模式(Factory Pattern)工厂模式(Factory Pattern),提供了一种实例化(创建)对象的最佳方式。在工厂模式
- 相关概念并发:指一个时间段内,有几个程序在同一个cpu上运行,但是任意时刻只有一个程序在cpu上运行。比如说在一秒内cpu切换了100个进程
- Go-ethereum 解析ethersjs中产生的签名信息在签名验证的过程中,我们判断签名正确的前提是,签名解析后的公钥,和发起这次动作的
- tensorflow里面给出了一个函数用来读取图像,不过得到的结果是最原始的图像,是咩有经过解码的图像,这个函数为tf.gfile.Fast
- 本文将介绍如何基于Python Django实现验证码登录功能。验证码登录是一种常见的身份验证方式,它可以有效防止恶意攻击和机器人登录。本文
- 在进制学习时候,细心的小伙伴不免都发现unicher函数的存在,没错能够经常看到的,也就是关于进制的转化,那肯定有小伙伴要开心起来了,因为进
- <div id="d1"></div> <script > fu
- 当我们使用传统的 mysql_connect 、mysql_query方法来连接查询数据库时,如果过滤不严,就有SQL注入风险,导致网站被攻
- 只有mdf文件的数据库附加失败的修复 附加时报如下错误: 服务器: 消息 1813,级别 16,状态 2,行 1 未能打开新数据库 '
- 我们可以用鼠标把Dreamweaver的层在页面内拖动,但要全屏拖动就困难了,下面是一种实现的方法:制作步骤:一、准备图片,取名/file/
- 一、分支结构为了限定用户正规操作,也为了更好的控制程序的逻辑,必须在适当时引入条件结构。Python 条件语句是通过一条或多条语句的执行结果
- 在 MySQL 下,在进行中文模糊检索时,经常会返回一些与之不相关的记录,如查找 "%a%" 时,返回的可能有中文字符,
- List>>> [chr(i) for i in range(97,123)]['a', 'b
- 众所周知,FileSystemObject(fso)组件的强大功能及破坏性是它屡屡被免费主页提供商(那些支持ASP)的禁用的原因,我整理了一
- replace方法的语法是:stringObj.replace(rgExp, replaceText) 其中stringObj是字符串(st