Keras自定义实现带masking的meanpooling层方式
作者:蕉叉熵 发布时间:2021-06-23 03:29:47
Keras确实是一大神器,代码可以写得非常简洁,但是最近在写LSTM和DeepFM的时候,遇到了一个问题:样本的长度不一样。对不定长序列的一种预处理方法是,首先对数据进行padding补0,然后引入keras的Masking层,它能自动对0值进行过滤。
问题在于keras的某些层不支持Masking层处理过的输入数据,例如Flatten、AveragePooling1D等等,而其中meanpooling是我需要的一个运算。例如LSTM对每一个序列的输出长度都等于该序列的长度,那么均值运算就只应该除以序列长度,而不是padding后的最长长度。
例如下面这个 3x4 大小的张量,经过补零padding的。我希望做axis=1的meanpooling,则第一行应该是 (10+20)/2,第二行应该是 (10+20+30)/3,第三行应该是 (10+20+30+40)/4。
Keras如何自定义层
在 Keras2.0 版本中(如果你使用的是旧版本请更新),自定义一个层的方法参考这里。具体地,你只要实现三个方法即可。
build(input_shape) : 这是你定义层参数的地方。这个方法必须设self.built = True,可以通过调用super([Layer], self).build()完成。如果这个层没有需要训练的参数,可以不定义。
call(x) : 这里是编写层的功能逻辑的地方。你只需要关注传入call的第一个参数:输入张量,除非你希望你的层支持masking。
compute_output_shape(input_shape) : 如果你的层更改了输入张量的形状,你应该在这里定义形状变化的逻辑,这让Keras能够自动推断各层的形状。
下面是一个简单的例子:
from keras import backend as K
from keras.engine.topology import Layer
import numpy as np
class MyLayer(Layer):
def __init__(self, output_dim, **kwargs):
self.output_dim = output_dim
super(MyLayer, self).__init__(**kwargs)
def build(self, input_shape):
# Create a trainable weight variable for this layer.
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[1], self.output_dim),
initializer='uniform',
trainable=True)
super(MyLayer, self).build(input_shape) # Be sure to call this somewhere!
def call(self, x):
return K.dot(x, self.kernel)
def compute_output_shape(self, input_shape):
return (input_shape[0], self.output_dim)
Keras自定义层如何允许masking
观察了一些支持masking的层,发现他们对masking的支持体现在两方面。
在 __init__ 方法中设置 supports_masking=True。
实现一个compute_mask方法,用于将mask传到下一层。
部分层会在call中调用传入的mask。
自定义实现带masking的meanpooling
假设输入是3d的。首先,在__init__方法中设置self.supports_masking = True,然后在call中实现相应的计算。
from keras import backend as K
from keras.engine.topology import Layer
import tensorflow as tf
class MyMeanPool(Layer):
def __init__(self, axis, **kwargs):
self.supports_masking = True
self.axis = axis
super(MyMeanPool, self).__init__(**kwargs)
def compute_mask(self, input, input_mask=None):
# need not to pass the mask to next layers
return None
def call(self, x, mask=None):
if mask is not None:
mask = K.repeat(mask, x.shape[-1])
mask = tf.transpose(mask, [0,2,1])
mask = K.cast(mask, K.floatx())
x = x * mask
return K.sum(x, axis=self.axis) / K.sum(mask, axis=self.axis)
else:
return K.mean(x, axis=self.axis)
def compute_output_shape(self, input_shape):
output_shape = []
for i in range(len(input_shape)):
if i!=self.axis:
output_shape.append(input_shape[i])
return tuple(output_shape)
使用举例:
from keras.layers import Input, Masking
from keras.models import Model
from MyMeanPooling import MyMeanPool
data = [[[10,10],[0, 0 ],[0, 0 ],[0, 0 ]],
[[10,10],[20,20],[0, 0 ],[0, 0 ]],
[[10,10],[20,20],[30,30],[0, 0 ]],
[[10,10],[20,20],[30,30],[40,40]]]
A = Input(shape=[4,2]) # None * 4 * 2
mA = Masking()(A)
out = MyMeanPool(axis=1)(mA)
model = Model(inputs=[A], outputs=[out])
print model.summary()
print model.predict(data)
结果如下,每一行对应一个样本的结果,例如第一个样本只有第一个时刻有值,输出结果是[10. 10. ],是正确的。
[[10. 10.]
[15. 15.]
[20. 20.]
[25. 25.]]
在DeepFM中,每个样本都是由ID构成的,多值field往往会导致样本长度不一的情况,例如interest这样的field,同一个样本可能在该field中有多项取值,毕竟每个人的兴趣点不止一项。
采取padding的方法将每个field的特征补长到最长的长度,则数据尺寸是 [batch_size, max_timestep],经过Embedding为每个样本的每个特征ID配一个latent vector,数据尺寸将变为 [batch_size, max_timestep,latent_dim]。
我们希望每一个field的Embedding之后的尺寸为[batch_size, latent_dim],然后进行concat操作横向拼接,所以这里就可以使用自定义的MeanPool层了。希望能给大家一个参考,也希望大家多多支持脚本之家。
来源:https://blog.csdn.net/songbinxu/article/details/80148856


猜你喜欢
- 本来想着做一个将图片识别为文字的小功能,本想到Google上面第一页全是各种收费平台的广告。这些平台提供的基本都是让我们通过调用相关的三方接
- 本文实例讲述了Python生成8位随机字符串的方法。分享给大家供大家参考,具体如下:#!/usr/bin/env python# -*- c
- 1.安装背景最近想放弃windows编程环境,转到linux。原因就一个字:潮从格式化所有硬盘,到安装win10/ubuntu18.04双系
- 如何解决bootStrapValidator bootStrap-select验证不可用,只要三步:思路:把多选下拉框的选中值,赋给一个隐藏
- 列表解析 在需要改变列表而不是需要新建某列表时,可以使用列表解析。列表解析表达式为: [expr for iter_var in itera
- 安装流程:前期准备工作--->安装ORACLE软件--->安装升级补丁--->安装odbc创建数据库--->安装监听
- 一、缺失值的处理方法由于各种各样的原因,真实世界中的许多数据集都包含缺失数据,这些数据经常被编码成空格、nans或者是其他的占位符。但是这样
- 安装方法1)、apt-ge安装sudo apt-get install Flask-SQLAlchemy2)、下载安装包进行安装# 安装后可
- 前言:《flappy bird》是一款由来自越南的独立游戏开发者Dong Nguyen所开发的作品,游戏于2013年5月24日上线,并在20
- 这篇文章主要介绍了python with (as)语句实例详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,
- 一、简介抠图是用PS?用魔棒和快速选择工具?遇到复杂背景怎么办?最近发现一个神奇的工具——Remove Image Backgroundht
- 本文介绍了纯python进行矩阵的相乘运算的方法示例,分享给大家,具体如下:def matrixMultiply(A, B):
- 一、所用知识点:1. for循环与if判断的结合2. %s占位符的使用3. 辅助标志的使用(标志位)4. break的使用二、代码示例:
- 正则表达式处理花括号内容替换赋值@Test public void replaceStr() { &
- 解决golang编译提示dial tcp 172.217.160.113:443: connectex: A connection atte
- 1.在查询结果中显示列名: a.用as关键字:select name as '姓名' from students order
- 在一个项目中,制作呃echart图表的时候,遇到一个需求,需要从后端接口获取数据----售票员的姓名和业绩所以需要在订单表中,获取不同售票员
- 目录建表查看数据库文件:插入查询删除补充:Mysql自动按月表分区MySQL单表数据量,建议不要超过2000W行,否则会对性能有较大影响。最
- 现将几种主要情况进行小结: 一、如何输入NULL值 如果不输入null值,当时间为空时,会默认写入"1900-01-01"
- 看了一个月的文档和资料以后,终于让我参与到项目中来了,哈哈,痛快!虽然只是让我解决一个小问题,不过有活干就是好。在写代码的过程中遇到了一个小