浅谈Tensorflow 动态双向RNN的输出问题
作者:Michelleweii 发布时间:2022-10-16 21:30:35
tf.nn.bidirectional_dynamic_rnn()
函数:
def bidirectional_dynamic_rnn(
cell_fw, # 前向RNN
cell_bw, # 后向RNN
inputs, # 输入
sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
initial_state_fw=None, # 前向的初始化状态(可选)
initial_state_bw=None, # 后向的初始化状态(可选)
dtype=None, # 初始化和输出的数据类型(可选)
parallel_iterations=None,
swap_memory=False,
time_major=False,
# 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`.
# 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`.
scope=None
)
其中,
outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设
time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。
output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。
output_state_fw和output_state_bw的类型为LSTMStateTuple。
LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。
返回值:
元组:(outputs, output_states)
这里还有最后的一个小问题,output_states是一个元组的元组,处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态
def encoding_layer(rnn_size,sequence_length,num_layers,rnn_inputs,keep_prob):
# rnn_size: rnn隐层节点数量
# sequence_length: 数据的序列长度
# num_layers:堆叠的rnn cell数量
# rnn_inputs: 输入tensor
# keep_prob:
'''Create the encoding layer'''
for layer in range(num_layers):
with tf.variable_scope('encode_{}'.format(layer)):
cell_fw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw,input_keep_prob=keep_prob)
cell_bw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw,input_keep_prob = keep_prob)
enc_output,enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,
rnn_inputs,sequence_length,dtype=tf.float32)
# join outputs since we are using a bidirectional RNN
enc_output = tf.concat(enc_output,2)
return enc_output,enc_state
tf.nn.dynamic_rnn()
tf.nn.dynamic_rnn的返回值有两个:outputs和state
为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。
例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,state,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,state是由(c,h)组成的tuple,均为[batch,128]。
outputs. outputs是一个tensor
如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size ](要求rnn输入与rnn输出形状保持一致)
如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size ]
state. state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。
这里有关于LSTM的结构问题:
来源:https://blog.csdn.net/weixin_31866177/article/details/81905226


猜你喜欢
- 这是一个自动化程度较高的程序,运行本程序后会从chrome中读取cookies用于登录人人影视签到,并且会自动添加一个windows 任务计
- 你是否对获得MySQL数据库与表的最基本命令的实际操作感到十分头疼?如果是这样子的话,以下的文章将会给你相应的解决方案,以下的文
- 本章节将为大家介绍Python循环语句的使用。Python中的循环语句有 for 和 while。Python循环语句的控制结构图如下所示:
- 本文实例讲述了PHP abstract 抽象类定义与用法。分享给大家供大家参考,具体如下:PHP抽象类应用要点:1.定义一些方法,子类必须完
- 如何在庞大的数据中高效的检索自己需要的东西?本篇内容介绍了Python做出一个大数据搜索引擎的原理和方法,以及中间进行数据分析的原理也给大家
- 一、包在我们的项目中,可能会有太多的模块但是我们不能把所有的模块这样放在这里,这样项目会乱七八糟。我们可以将所有相同类型的模块放在一个文件夹
- 一.生成器简介在python中,带yield的方法不再是普通方法,而是生成器,它的执行顺序不同与普通方法.普通方法的执行是从头到尾,最后re
- getpixel函数是用来获取图像中某一点的像素的RGB颜色值,getpixel的参数是一个坐标点。对于图象的不同的模式,getpixel函
- 简介python可以做很多事情,虽然它的强项在于进行向量运算和机器学习、深度学习等方面。但是在某些时候,我们仍然需要使用python对外提供
- 前言最近我在公司优化过几个慢查询接口的性能,总结了一些心得体会拿出来跟大家一起分享一下,希望对你会有所帮助。我们使用的数据库是Mysql8,
- 本文是OpenCV图像视觉入门之路的第11篇文章,本文详细的在图像形态学进行了图像处理,例如:腐蚀操作、膨胀操作、开闭运算、梯度运算、Top
- 本文实例为大家分享了python openCV实现摄像头获取人脸图片的具体代码,供大家参考,具体内容如下在机器学习中,训练模型需要大量图片,
- 一.基于纹理背景的图像分割该部分主要讲解基于图像纹理信息(颜色)、边界信息(反差)和背景信息的图像分割算法。在OpenCV中,GrabCut
- 程序测试是展现BUG存在的有效方式,但令人绝望的是它不足以展现其缺位。——艾兹格·迪杰斯特拉(Edsger W. Dijkstra)算法审查
- python链表的反转反转链表给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。输入:head = [1,2,3,4,5]输
- 1、转化成时间格式seconds =35400m, s = divmod(seconds, 60)h, m = divmod(m, 60)p
- 思路步骤:创建一个可以序列化的类去数据库取数据交给序列化的类处理把序列化的数据返回前端操作流程:# 安装模块pip install djan
- 我就废话不多说了,直接上代码吧!import cv2img = cv2.imread("1.jpg")b, g, r =
- 自定义用户认证系统Django 自带的用户认证系统已经可以满足大部分的情况,但是有时候我们需要某些特定的需求。Django 支持使用其他认证
- 本文实例讲述了Python机器学习算法库scikit-learn学习之决策树实现方法。分享给大家供大家参考,具体如下:决策树决策树(DTs)