tensorflow自定义激活函数实例
作者:caorui_nk 发布时间:2023-04-18 09:11:51
标签:tensorflow,自定义,激活函数
前言:因为研究工作的需要,要更改激活函数以适应自己的网络模型,但是单纯的函数替换会训练导致不能收敛。这里还有些不清楚为什么,希望有人可以给出解释。查了一些博客,发现了解决之道。下面将解决过程贴出来供大家指正。
1.背景
之前听某位老师提到说tensorflow可以在不给梯度函数的基础上做梯度下降,所以尝试了替换。我的例子时将ReLU改为平方。即原来的激活函数是 现在换成
单纯替换激活函数并不能较好的效果,在我的实验中,迭代到一定批次,准确率就会下降,最终降为10%左右保持稳定。而事实上,这中间最好的训练精度为92%。资源有限,问了对神经网络颇有研究的同学,说是激活函数的问题,然而某篇很厉害的论文中提到其精度在99%,着实有意思。之后开始研究自己些梯度函数以完成训练。
2.大概流程
首先要确定梯度函数,之后将其处理为tf能接受的类型。
2.1定义自己的激活函数
def square(x):
return pow(x, 2)
2.2 定义该激活函数的一次梯度函数
def square_grad(x):
return 2 * x
2.3 让numpy数组每一个元素都能应用该函数(全局)
square_np = np.vectorize(square)
square_grad_np = np.vectorize(square_grad)
2.4 转为tf可用的32位float型,numpy默认是64位(全局)
square_np_32 = lambda x: square_np(x).astype(np.float32)
square_grad_np_32 = lambda x: square_grad_np(x).astype(np.float32)
2.5 定义tf版的梯度函数
def square_grad_tf(x, name=None):
with ops.name_scope(name, "square_grad_tf", [x]) as name:
y = tf.py_func(square_grad_np_32, [x], [tf.float32], name=name, stateful=False)
return y[0]
2.6 定义函数
def my_py_func(func, inp, Tout, stateful=False, name=None, my_grad_func=None):
# need to generate a unique name to avoid duplicates:
random_name = "PyFuncGrad" + str(np.random.randint(0, 1E+8))
tf.RegisterGradient(random_name)(my_grad_func)
g = tf.get_default_graph()
with g.gradient_override_map({"PyFunc": random_name, "PyFuncStateless": random_name}):
return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
2.7 定义梯度,该函数依靠上一个函数my_py_func计算并传播
def _square_grad(op, pred_grad):
x = op.inputs[0]
cur_grad = square_grad(x)
next_grad = pred_grad * cur_grad
return next_grad
2.8 定义tf版的square函数
def square_tf(x, name=None):
with ops.name_scope(name, "square_tf", [x]) as name:
y = my_py_func(square_np_32,
[x],
[tf.float32],
stateful=False,
name=name,
my_grad_func=_square_grad)
return y[0]
3.使用
跟用其他激活函数一样,直接用就行了。input_data:输入数据。
h = square_tf(input_data)
over. 学艺不精,多多指教!
来源:https://blog.csdn.net/caorui_nk/article/details/82898200
0
投稿
猜你喜欢
- 1.什么是局部视图局部视图是在其他视图中呈现的视图。通过执行局部视图生成的HTML输出呈现在调用视图中。与视图一样,局部视图使用 .csht
- 前言:在motplotlib的学习过程中,我们使用最多的就是numpy模块。numpy 模块被称为 matplotlib 模块绘制图表伴侣。
- 很早前就遇到这个空值的属性,它既出现在 html 文档中,也出现在 xml 中,一直都回避,放之任之,反正也不影响文档的正确性。隐隐约约过了
- 本文实例为大家分享了opencv+python实现图像矫正的具体代码,供大家参考,具体内容如下需求:将斜着拍摄的文本图像进行矫正python
- 简介进行按钮进行界面的跳转,我这里面我介绍两种,一种是没有使用Qtdesigner的代码,另一种是使用Qtdesigner的代码代码1imp
- 在使用Celery统计每日访问数量的时候,发现一个任务会同时执行两次,发现同一时间内(1s内)竟然同时发送了两次任务,也就是同时产生了两个w
- 本文实例讲述了Python 操作 PostgreSQL 数据库。分享给大家供大家参考,具体如下:我使用的是 Python 3.7.0Post
- 导语《我的世界》是一款自由度极高的游戏,每个新存档的开启,就像是作为造物主的玩家在虚拟空间开辟了一个全新的宇宙。方块连接世界,云游大好河山。
- 学习前言已经完成了RNN网络的构建,但是我们对于RNN网络还有许多疑问,特别是tf.nn.dynamic_rnn函数,其具体的应用方式我们并
- 由于工作对人的眼球和精神都会带来一定的疲劳,所以在界面设计中,希望用户能够准确的关注重要的信息,而不因为用户的长期使用而流失信息。最近在看《
- 研究编码,得知GB2312编码与区位码的关系,尝试之后,得此程序。搜索,似乎没人写,故发此地。1.简述(1)GB2312标准的定义,其实就是
- 以这两个域名为例:http://www.knowsky.com/http://code.knowsky.com/这两个域名都是绑在同一个空间
- 本文实例讲述了python中Flask框架的简单用法。分享给大家供大家参考。具体如下:使用Flask框架的简单入门范例代码,如果你正学习Fl
- 一、TCP1、tcp服务器创建#创建服务器from socket import *from time import ctime #导入cti
- 示例代码: BulkStockBll bll = new BulkStockBll(); DataSet ds = bll.GetBulkS
- 在业务稳定性要求比较高的情况下,运维为能及时发现问题,有时需要对应用程序的日志进行实时分析,当符合某个条件时就立刻报警,而不是被动等待出问题
- 本文实例讲述了CI操作cookie的方法。分享给大家供大家参考,具体如下:CI 操作cookie 有三种方法,2中Ci自带的,其
- 今天发现一个使用python写的管理cisco设备的小框架tratto,可以用来批量执行命令。下载后主要有3个文件:Systems.py 定
- 在当前的Web设计中,jQuery被越来越多地应用在Web开发中,之所以jQuery收到如此程度的欢迎,除了其本身具备的优秀易读易操作的代码
- 数据库连接:<% set conn=server.createobject("adodb.connection&q