Tensorflow与RNN、双向LSTM等的踩坑记录及解决
作者:Orion Nebula 发布时间:2021-04-29 21:25:55
1、tensorflow(不定长)文本序列读取与解析
tensorflow读取csv时需要指定各列的数据类型。
但是对于RNN这种接受序列输入的模型来说,一条序列的长度是不固定。这时如果使用csv存储序列数据,应当首先将特征序列拼接成一列。
例如两条数据序列,第一项是标签,之后是特征序列
[0, 1.1, 1.2, 2.3] 转换成 [0, '1.1_1.2_2.3']
[1, 1.0, 2.5, 1.6, 3.2, 4.5] 转换成 [1, '1.0_2.5_1.6_3.2_4.5']
这样每条数据都只包含固定两列了。
读取方式是指定第二列为字符串类型,再将字符串按照'_'分割并转换为数字。
关键的几行代码示例如下:
def readMyFileFormat(fileNameQueue):
reader = tf.TextLineReader()
key, value = reader.read(fileNameQueue)
record_defaults = [["Null"], [-1], ["Null"], ["Null"], [-1]]
phone1, seqlen, ts_diff_strseq, t_cod_strseq, userlabel = tf.decode_csv(value, record_defaults=record_defaults)
ts_diff_str = tf.string_split([ts_diff_strseq], delimiter='_')
t_cod_str = tf.string_split([t_cod_strseq], delimiter='_')
# 每个字符串转数字
Str2Float = lambda string: tf.string_to_number(string, tf.float32)
Str2Int = lambda string: tf.string_to_number(string, tf.int32)
ts_diff_seq = tf.map_fn(Str2Float, ts_diff_str.values, dtype = tf.float32) # 一定要加上dtype,且必须与fn的输出类型一致
t_cod_seq = tf.map_fn(Str2Int, t_cod_str.values, dtype = tf.int32)
2、时序建模的序列预测、序列拟合、标签预测,及输入数据格式
序列预测、拟合的“标签”都是序列本身,区别是未来时刻或者是当前时刻,当前时刻的拟合任务类似于antoencoder的reconstruction
标签预测常见于语言学建模,有单词级标签的分词与整句标签的情感分析,前者需要对每一个单词输入都要输出其分词标识,后者是取最后若干输出级联前馈神经网络分类器
keras的输入-输出对:需要将序列拆分成多个片段
序列形式:
按时间列表:static_bidirectional_rnn
多维数组:bidirectional_dynamic_rnn与stack_bidirectional_dynamic_rnn 变长双向rnn的正确使用姿势
3、多任务设置及相应的输出向量划分
对于标签预测任务,按需取输出即可
对于序列预测、拟合:
双向lstm:通常用于拟合。但如果需要捕捉动态信息,尽管需要序列完整输入,则仍可以加上正向预测与反向预测
单向lstm:拟合与预测
4、zero padding
后一般需要通过tf.boolean_mask()隔离这些零的影响,函数输入包括数据矩阵和补零位置的指示矩阵。
5、get_shape()方法
与 tf.shape() 类型区别,前者得到一个list,后者得到一个tensor
6、双向LSTM的信息瓶颈的解决
如果在时间步的最后输出,则可能会导致开始的一些字符被遗忘门给遗忘。
所以这里就对每个时间步的输出做出了处理,
主要处理有:
1、拼接:把所有的输出拼接在一起。
2、Average
3、Pooling
来源:https://zhuanlan.zhihu.com/p/36743184


猜你喜欢
- PHP 过滤器PHP 过滤器用于验证和过滤来自非安全来源的数据,比如用户的输入。什么是 PHP 过滤器PHP 过滤器用于验证和过滤来自非安全
- 本文为大家分享了Python3实现发送QQ邮件功能:html,供大家参考,具体内容如下之前已经成功发送了qq邮件。下面贴出html格式的qq
- Qt是一种基于C++的跨平台图形用户界面应用程序开发框架。如何跨平台?上到服务器上位机,下到嵌入式GUI,上天入地无所不能。Qt最早是由19
- 本文实例讲述了Python使用sax模块解析XML文件。分享给大家供大家参考,具体如下:XML样例:<?xml version=&qu
- Python字典是另一种可变容器模型,且可存储任意类型对象,如字符串、数字、元组等其他容器模型。一、创建字典字典由键和对应值成对组成。字典也
- 什么是SQL 指令植入式攻击?在设计或者维护Web网站时,你也许担心它们会受到某些卑鄙用户的恶意攻击。的确,如今的Web网站开发者们针对其站
- 前言;Python基础知识+结构+数据类型Python基础学习列表+元组+字典+集合Python基础学习函数+模块+类今天给大家分享的是第四
- 管理SQL Server内在的帐户和密码时,我们很容易认为这一切都相当的安全。但实际上并非如此。在这里,我们列出了一些对于SQL Serve
- $str = '中华人民共和国123456789abcdefg'; echo preg_match("/^[u4e
- 关于DHT协议DHT协议作为BT协议的一个辅助,是非常好玩的。它主要是为了在BT正式下载时得到种子或者BT资源。传统的网络,需要一台中央服务
- 一、常见反爬机制及其破解方式封禁IP,使用cookie等前面文章已经讲过现在主要将下面的:~ 验证码 —> 文字验证码 —> O
- 几天前,想把上个月校园招聘的餐旅费报销一下。结果在公司内网的报销系统折腾了三个半小时才搞定。看看自己报销的金额:802块。觉得挺无奈,花了三
- 何为样本分布不均:样本分布不均衡就是指样本差异非常大,例如共1000条数据样本的数据集中,其中占有10条样本分类,其特征无论如何你和也无法实
- 完整代码<!doctype html><html lang="en"><head>
- 刚开始,根据我的想法,这个很简单嘛,上sql语句delete from zqzrdp where tel in (select min(dp
- 本文实例讲述了ThinkPHP中url隐藏入口文件后接收alipay传值的方法。分享给大家供大家参考。具体方法如下:现在公司项目的需求变化多
- 表还是total_sales添加一项表:SQL语句:SELECT * from( SELECT a1.N
- OS库提供通用的,基本的操作系统交互功能。-OS库是Python标准库,包含几百个函数-常用路径操作,进程管理,环境参数等几类-路径操作:
- 看到这张照片,我们一眼能够看到天宏(图中这位UED俊男)的眼睛。我们能从他的表情里读出一些他的性格。一张好的摄影作品,最重要的一点,就是这个
- XML是一项热门的技术。它之所以能够引起人们的兴趣,一个主要的原因在于它十分的简单,人们可以很容易地理解和使用它。每一个程序员都能轻易地看懂