pytorch自定义不可导激活函数的操作
作者:Luna_Lovegood_001 发布时间:2022-07-05 10:09:13
pytorch自定义不可导激活函数
今天自定义不可导函数的时候遇到了一个大坑。
首先我需要自定义一个函数:sign_f
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
然后我需要把它封装为一个module 类型,就像 nn.Conv2d 模块 封装 f.conv2d 一样,于是
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
# 我需要的module
def __init__(self, *kargs, **kwargs):
super(sign_, self).__init__(*kargs, **kwargs)
def forward(self, inputs):
# 使用自定义函数
outs = sign_f(inputs)
return outs
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
结果报错
TypeError: backward() missing 2 required positional arguments: 'ctx' and 'grad_output'
我试了半天,发现自定义函数后面要加 apply ,详细见下面
import torch
from torch.autograd import Function
import torch.nn as nn
class sign_(nn.Module):
def __init__(self, *kargs, **kwargs):
super(sign_, self).__init__(*kargs, **kwargs)
self.r = sign_f.apply ### <-----注意此处
def forward(self, inputs):
outs = self.r(inputs)
return outs
class sign_f(Function):
@staticmethod
def forward(ctx, inputs):
output = inputs.new(inputs.size())
output[inputs >= 0.] = 1
output[inputs < 0.] = -1
ctx.save_for_backward(inputs)
return output
@staticmethod
def backward(ctx, grad_output):
input_, = ctx.saved_tensors
grad_output[input_>1.] = 0
grad_output[input_<-1.] = 0
return grad_output
问题解决了!
PyTorch自定义带学习参数的激活函数(如sigmoid)
有的时候我们需要给损失函数设一个超参数但是又不想设固定阈值想和网络一起自动学习,例如给Sigmoid一个参数alpha进行调节
函数如下:
import torch.nn as nn
import torch
class LearnableSigmoid(nn.Module):
def __init__(self, ):
super(LearnableSigmoid, self).__init__()
self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
self.weight.data.fill_(1.0)
def forward(self, input):
return 1/(1 + torch.exp(-self.weight*input))
验证和Sigmoid的一致性
class LearnableSigmoid(nn.Module):
def __init__(self, ):
super(LearnableSigmoid, self).__init__()
self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
self.weight.data.fill_(1.0)
def forward(self, input):
return 1/(1 + torch.exp(-self.weight*input))
Sigmoid = nn.Sigmoid()
LearnSigmoid = LearnableSigmoid()
input = torch.tensor([[0.5289, 0.1338, 0.3513],
[0.4379, 0.1828, 0.4629],
[0.4302, 0.1358, 0.4180]])
print(Sigmoid(input))
print(LearnSigmoid(input))
输出结果
tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]])
tensor([[0.6292, 0.5334, 0.5869],
[0.6078, 0.5456, 0.6137],
[0.6059, 0.5339, 0.6030]], grad_fn=<MulBackward0>)
验证权重是不是会更新
import torch.nn as nn
import torch
import torch.optim as optim
class LearnableSigmoid(nn.Module):
def __init__(self, ):
super(LearnableSigmoid, self).__init__()
self.weight = torch.nn.Parameter(torch.FloatTensor(1), requires_grad=True)
self.reset_parameters()
def reset_parameters(self):
self.weight.data.fill_(1.0)
def forward(self, input):
return 1/(1 + torch.exp(-self.weight*input))
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.LSigmoid = LearnableSigmoid()
def forward(self, x):
x = self.LSigmoid(x)
return x
net = Net()
print(list(net.parameters()))
optimizer = optim.SGD(net.parameters(), lr=0.01)
learning_rate=0.001
input_data=torch.randn(10,2)
target=torch.FloatTensor(10, 2).random_(8)
criterion = torch.nn.MSELoss(reduce=True, size_average=True)
for i in range(2):
optimizer.zero_grad()
output = net(input_data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
print(list(net.parameters()))
输出结果
tensor([1.], requires_grad=True)]
[Parameter containing:
tensor([0.9979], requires_grad=True)]
[Parameter containing:
tensor([0.9958], requires_grad=True)]
会更新~
来源:https://blog.csdn.net/qq_43110298/article/details/115032262


猜你喜欢
- 一、前言随着三胎政策的开放,人们对于生娃的讨论也逐渐热烈了起来,经常能够在各大社交媒体当中看到相关的话题,而随着时间慢慢地流逝,中国的首批“
- 本文实例讲述了Python实现的质因式分解算法。分享给大家供大家参考,具体如下:本来想实现一个其它的基本数学算法问题,但是发现在实现之前必须
- 如下所示:import collectionsclass Mydict(collections.UserDict):def __missin
- jQuery的选择器可谓异常强大,没有什么DOM里的任何数据能逃出它的掌心,这点是我非常喜欢的,以前获取NODE要用getElementBy
- 说明为水平排列的表单和内联表单设置可选的图标.示例<!DOCTYPE html><html lang="zh-C
- 最近开始使用Go/GoLand 在import 自定义包时出现各种状况,措手不及,大概在网上找了解决方法,几乎没说的清楚的(可能是我个人理解
- 自动等待及元素执行方法操作元素的一系列方法,只要调用了测试夹函数page,就能引出操作元素的方法:import pytestfrom pla
- 介绍本文将介绍基于OpenCV实现视频的循环播放。有以下三个步骤:首先设置一个frame的设置参数frame_counter,值为0在读帧时
- 只有pd模型文件, 打印所有节点from tensorflow.python.framework import tensor_utilfro
- 本文实例讲述了Python操作列表常用方法。分享给大家供大家参考,具体如下:使用for循环,遍历整个列表依次从列表中取出元素,存放到name
- 如果这个问题不解决,那么MySQL将无法实际处理中文。 出现这个问题的原因是因为MySQL在查询字符串时是大小写不敏感的,在编绎MySQL时
- 本文主要介绍了OpenCV 图像对比度,具有一定的参考价值,感兴趣的可以了解一下实现原理图像对比度指的是一幅图像中明暗区域最亮的白和最暗的黑
- 我这里只讲几点有关于MySQL数据库安装后遇到的个别问题 我之前安装过MYSQL好像不用手动启动服务,具体也忘记了,但我上回给公司安装的那个
- 背景众所周知,go语言可打包成目标平台二进制文件是其一大优势,如此go项目在服务器不需要配置go环境和依赖就可跑起来。操作需求:打包部署到c
- WebDriver简介selenium从2.0开始集成了webdriver的API,提供了更简单,更简洁的编程接口。selenium web
- 这篇文章主要介绍了Python Selenium参数配置方法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值
- <ScriptRUNAT=SERVERLanguage=VBScript>SubApplication_OnStar
- 1、确认框架中安装了第三方alibabacoud控件实现代码如下上传过程中遇到任务问题,可以进行留言<?php namespace A
- 目录一、socketserver实现并发二、验证客户端合法性一、socketserver实现并发tcp协议的socket是只能和一个客户端通
- 一、创建矩阵的方法import numpy as np# 1直接创建mat=np.mat("1 2 3;4 5 6;7 8 9&q