tensorflow自定义激活函数实例
作者:caorui_nk 发布时间:2023-04-18 09:11:51
标签:tensorflow,自定义,激活函数
前言:因为研究工作的需要,要更改激活函数以适应自己的网络模型,但是单纯的函数替换会训练导致不能收敛。这里还有些不清楚为什么,希望有人可以给出解释。查了一些博客,发现了解决之道。下面将解决过程贴出来供大家指正。
1.背景
之前听某位老师提到说tensorflow可以在不给梯度函数的基础上做梯度下降,所以尝试了替换。我的例子时将ReLU改为平方。即原来的激活函数是 现在换成
单纯替换激活函数并不能较好的效果,在我的实验中,迭代到一定批次,准确率就会下降,最终降为10%左右保持稳定。而事实上,这中间最好的训练精度为92%。资源有限,问了对神经网络颇有研究的同学,说是激活函数的问题,然而某篇很厉害的论文中提到其精度在99%,着实有意思。之后开始研究自己些梯度函数以完成训练。
2.大概流程
首先要确定梯度函数,之后将其处理为tf能接受的类型。
2.1定义自己的激活函数
def square(x):
return pow(x, 2)
2.2 定义该激活函数的一次梯度函数
def square_grad(x):
return 2 * x
2.3 让numpy数组每一个元素都能应用该函数(全局)
square_np = np.vectorize(square)
square_grad_np = np.vectorize(square_grad)
2.4 转为tf可用的32位float型,numpy默认是64位(全局)
square_np_32 = lambda x: square_np(x).astype(np.float32)
square_grad_np_32 = lambda x: square_grad_np(x).astype(np.float32)
2.5 定义tf版的梯度函数
def square_grad_tf(x, name=None):
with ops.name_scope(name, "square_grad_tf", [x]) as name:
y = tf.py_func(square_grad_np_32, [x], [tf.float32], name=name, stateful=False)
return y[0]
2.6 定义函数
def my_py_func(func, inp, Tout, stateful=False, name=None, my_grad_func=None):
# need to generate a unique name to avoid duplicates:
random_name = "PyFuncGrad" + str(np.random.randint(0, 1E+8))
tf.RegisterGradient(random_name)(my_grad_func)
g = tf.get_default_graph()
with g.gradient_override_map({"PyFunc": random_name, "PyFuncStateless": random_name}):
return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
2.7 定义梯度,该函数依靠上一个函数my_py_func计算并传播
def _square_grad(op, pred_grad):
x = op.inputs[0]
cur_grad = square_grad(x)
next_grad = pred_grad * cur_grad
return next_grad
2.8 定义tf版的square函数
def square_tf(x, name=None):
with ops.name_scope(name, "square_tf", [x]) as name:
y = my_py_func(square_np_32,
[x],
[tf.float32],
stateful=False,
name=name,
my_grad_func=_square_grad)
return y[0]
3.使用
跟用其他激活函数一样,直接用就行了。input_data:输入数据。
h = square_tf(input_data)
over. 学艺不精,多多指教!
来源:https://blog.csdn.net/caorui_nk/article/details/82898200


猜你喜欢
- 有的时候需要将两组数据,比如特征和标签放在一起随机打乱, 但是又想记录这种打乱的顺序,那么该怎么做呢?下面是一个很好的方法:b = [1,
- 前两天有一位网友问我一个关于Javascript中++操作符的问题,他的代码大致是这样的ADS.addEvent(window,'c
- 今天写了一个放迅雷焦点广告的效果,还请大家多多指正,先附上效果图一张:相关文章:迅雷首页新闻图片轮播效果js源码首先是JS代码部分,之前一定
- Background高斯噪声,顾名思义是指服从高斯分布(正态分布)的一类噪声。有的时候我们需要向标准数据中加入合适的高斯噪声让数据更加符合实
- 那么,现在如果给出一个权限编号,要去检索出用后这个权限的用户集合,就会需要在逗号分隔的多个权限编号中去匹配给出的这个权限编号。如果使用lik
- 我们都知道有很多的非常著名的注册服务器,例如: Consul、ZooKeeper、etcd,甚至借助于redis完成服务注册发现。但是本篇文
- 本文实例讲述了Python实现多条件筛选目标数据功能。分享给大家供大家参考,具体如下:python中提供了一些数据过滤功能,可以使用内建函数
- <!--#include file="config.asp" -->&nbs
- 我就废话不多说了,大家还是直接看代码吧~package mainimport ("fmt""reflect&q
- 我在按照 Byte of python一步步的学习Python, 在学到‘解决方案'的时候,原文的实例 “backup_ver1.p
- TypeScriptTypeScript是一种由微软开发的自由和开源的编程语言。它是JavaScript的一个超集,而且本质上向这个语言添加
- 阅读上一篇:FrontPage2002简明教程四:网页超级链接 一、三种添加CSS的方式 在FrontPage 2002里可以通过三种方式给
- validator自定义验证及易错点validator自定义验证element中Form 组件提供了表单验证的功能,只需要通过 rules
- 文件名称:ByVal.aspByRef.asp具体代码:<%Sub TestMain()Dim A : A=5Call TestBy(
- 创建列表sample_list = ['a',1,('a','b')]Python 列表操作
- 这篇文章主要介绍了python处理RSTP视频流过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的
- 一、临时表实现分步处理1.概述当需要的结果需要经过多次处理后才能最终得到我们需要的结果时,就可以使用临时表,这里临时表就起到了一个中间处理的
- 我们在使用GPU资源进行训练的时候,可能会发生资源耗尽的情况,那么在在这种情况,我们需要对GPU的资源进行合理的安排,具体使用办法如下:框架
- OpenCV的全称是:Open Source Computer Vision Library。OpenCV是一个基于(开源)发行的跨平台计算
- 一. 日志传送概述SQL Server使用日志传送,可以自动将主服务器的事务日志备份发送到一个或多个辅助数据库上。事务日志备份分别应用于每个