使用Keras实现Tensor的相乘和相加代码
作者:guofuzheng 发布时间:2021-08-04 14:10:57
前言
最近在写行为识别的代码,涉及到两个网络的融合,这个融合是有加权的网络结果的融合,所以需要对网络的结果进行加权(相乘)和融合(相加)。
最初的想法
最初的想法是用Keras.layers.Add和Keras.layers.Multiply来做,后来发现这样会报错。
rate_rgb = k.variable(np.ones((1024,),dtype='float32')*0.8)
rate_esti = k.variable(np.ones((1024,),dtype='float32')*0.2)
weight_gru1 = Multiply()([rate_rgb,gru1])
weight_gru2 = Multiply()([rate_esti,gru2])
last = Add()([weight_gru1,weight_gru2])
这么写会报错,如下
AttributeError: 'Variable' object has no attribute '_keras_history'
正确做法
后来在网上参考大神的博客,改为如下
weight_1 = Lambda(lambda x:x*0.8)
weight_2 = Lambda(lambda x:x*0.2)
weight_gru1 = weight_1(gru1)
weight_gru2 = weight_2(gru2)
last = Add()([weight_gru1,weight_gru2])
这样就没问题了。
补充知识:Keras天坑:想当然的对层的直接运算带来的问题
天坑
keras如何操作某一层的值(如让某一层的值取反加1等)?keras如何将某一层的神经元拆分以便进一步操作(如取输入的向量的第一个元素乘别的层)?keras如何重用某一层的值(如输入层和输出层乘积作为最终输出)?
这些问题都指向同一个答案,即使用Lambda层。
另外,如果想要更加灵活地操作层的话,推荐使用函数式模型写法,而不是序列式。
Keras当中,任何的操作都是以网络层为单位,操作的实现都是新添一层,不管是加减一个常数还是做乘法,或者是对两层的简单拼接。所以,将一层单独劈一半出来,是一件难事。强调,Keras的最小操作单位是Layer,每次操作的是整个batch。自然,在keras中,每个层都是对象,可以通过dir(Layer对象)来查看具有哪些属性。然而,Backend中Tensorflow的最小操作单位是Tensor,而你搞不清楚到底是Layer和Tensor时,盲目而想当然地进行层的操作,就会出问题。到底是什么?通过type和shape是看不出来的。
如果你只是想对流经该层的数据做个变换,而这个变换本身没有什么需要学习的参数,那么直接用Lambda Layer是最合适的了。
也就是说,对每一层的加减乘除都得用keras的函数,你不能简单使用形如 ‘new_layer' =1−= 1-=1−'layer'这样的表达方式来对层进行操作。
当遇到如下报错信息:
AttributeError: 'NoneType' object has no attribute '_inbound_nodes'
或
TypeError: 'Tensor' object is not callable
等等
这是就要考虑一下将程序中层的操作改成Lambda的方式表达。
使用Lambda编写自己的层
Lamda层怎么用?官方文档给了这样一个例子。
# add a x -> x^2 layer
model.add(Lambda(lambda x: x ** 2))
# add a layer that returns the concatenation
# of the positive part of the input and
# the opposite of the negative part
def antirectifier(x):
x -= K.mean(x, axis=1, keepdims=True)
x = K.l2_normalize(x, axis=1)
pos = K.relu(x)
neg = K.relu(-x)
return K.concatenate([pos, neg], axis=1)
def antirectifier_output_shape(input_shape):
shape = list(input_shape)
assert len(shape) == 2 # only valid for 2D tensors
shape[-1] *= 2
return tuple(shape)
model.add(Lambda(antirectifier,
output_shape=antirectifier_output_shape))
乍一看,有点懵逼,什么乱七八糟的。事实上,很简单,假设L0和L1是两层,你只要将你形如下面这样的表达:
L1 = F(L0);
改成
L1 = Lambda( lambda L0:F(L0) ) (L0)
即可。为了看得清楚,多加了几个空格。
事实上,无非就是将原来的变换,通过Lambda(lambda 输入:表达式)这样的方式,改成了Lambda型函数,再把输入传进去,放在尾巴上即可。
参考
https://keras-cn.readthedocs.io/en/latest/layers/core_layer/#lambda
(个人觉得这份文档某些地方比官方中文要完整许多)
keras许多简单操作,都需要新建一个层,使用Lambda可以很好完成需求。当你不知道有这个东西存在的时候,就会走不少弯路。
来源:https://blog.csdn.net/weixin_40289171/article/details/80416089


猜你喜欢
- create table [order] ( code varchar(50), createtime datetime ) --应用 us
- vue3无法使用jsx问题报错一:无法使用 JSX,除非提供了 "--jsx" 标志在Vite+Vue3.0中使用jsx
- 简介canvas 是HTML5 提供的一种新标签,它可以支持 JavaScript 在上面绘画,控制每一个像素,它经常被用来制作小游戏,接下
- 前几天网上找了一款 PC 端微信自动清理工具,用了一下,电脑释放了 30GB 的存储空间,而且不会删除文字的聊天记录,很好用,感觉很多人都用
- 许多数据科学家认为获取和清理数据的初始步骤占工作的 80%,花费大量时间来清理数据集并将它们归结为可以使用的形式。因此如果你是刚刚踏入这个领
- 背景最近项目联调的时候发现了分页查询的一个bug,分页查询总有数据查不出来或者重复查出。数据库一共14条记录。如果按照一页10条。那么第一页
- 1.问题及解决办法(1)问题:由于存储的时间戳是时间戳为GMT(格林尼治标准时间),以秒储存,但由于需要获取的是北京时间,存在时区问题。如何
- 当我们去设计数据库表结构,对操作数据库时(尤其是查表时的SQL语句),我们都需要注意数据操作的性能。这里,我们不会讲过多的SQL语句的优化,
- 最近在一个项目中遇到一个查询页面,其中一个查询条件是根据选择的年份、月以及周数显示选择的该周从几号到几号,这样一个需求。在网上搜
- 事件类型: 错误 事件来源: Service Control Manager 事件种类: 无 事件 ID: 7034 日期: 2012-11
- if (arr[i]){ &nb
- 目录1. threding模块创建线程对象2. threding模块创建多线程3. 多线程的参数传递4. 线程产生的资源竞争1. thred
- 1、代码如下:import numpy as npfrom keras.models import Sequentialfrom keras
- 摘要:神经网络的训练的主要流程包括图像输入神经网络, 得到模型的输出结果,计算模型的输出与真实值的损失, 计算损失值的梯度,最后用梯度下降算
- 级联样式表在13年前被引入,而且被广泛使用的CSS 2.1 标准在11年前被创建,显然我们现在已经与当年相差千里了。相当了不起的是期间网站开
- 经典字典使用函数dict:通过其他映射(比如其他字典)或者(键,值)这样的序列对建立字典。当然dict成为函数不是十分确切,它本质是一种类型
- 初学Python,这个问题搞了我好久,现在来分享下我的解决思路,希望可以帮到大家。先说下python引入模块的顺序:首先现在当前文件夹下查找
- 前几天,杨超越编程大赛火了,大家都在报名参加,而我也是其中的一员。在我们的项目中,我负责的是数据爬取这块,我主要是把对于杨超越 的
- #!/usr/bin/env python# -*- coding: utf-8 -*-from tkinter import *impor
- 头疼的挂马事件申请了个免费空间弄了个小站空间还可以二年多了挺稳定的只是从今年年初开始网页老莫名奇妙的被人挂马仔细检查了网站 不存在什么漏洞应