浅谈Tensorflow 动态双向RNN的输出问题
作者:Michelleweii 发布时间:2022-10-16 21:30:35
tf.nn.bidirectional_dynamic_rnn()
函数:
def bidirectional_dynamic_rnn(
cell_fw, # 前向RNN
cell_bw, # 后向RNN
inputs, # 输入
sequence_length=None,# 输入序列的实际长度(可选,默认为输入序列的最大长度)
initial_state_fw=None, # 前向的初始化状态(可选)
initial_state_bw=None, # 后向的初始化状态(可选)
dtype=None, # 初始化和输出的数据类型(可选)
parallel_iterations=None,
swap_memory=False,
time_major=False,
# 决定了输入输出tensor的格式:如果为true, 向量的形状必须为 `[max_time, batch_size, depth]`.
# 如果为false, tensor的形状必须为`[batch_size, max_time, depth]`.
scope=None
)
其中,
outputs为(output_fw, output_bw),是一个包含前向cell输出tensor和后向cell输出tensor组成的元组。假设
time_major=false,tensor的shape为[batch_size, max_time, depth]。实验中使用tf.concat(outputs, 2)将其拼接。
output_states为(output_state_fw, output_state_bw),包含了前向和后向最后的隐藏状态的组成的元组。
output_state_fw和output_state_bw的类型为LSTMStateTuple。
LSTMStateTuple由(c,h)组成,分别代表memory cell和hidden state。
返回值:
元组:(outputs, output_states)
这里还有最后的一个小问题,output_states是一个元组的元组,处理方法是用c_fw,h_fw = output_state_fw和c_bw,h_bw = output_state_bw,最后再分别将c和h状态concat起来,用tf.contrib.rnn.LSTMStateTuple()函数生成decoder端的初始状态
def encoding_layer(rnn_size,sequence_length,num_layers,rnn_inputs,keep_prob):
# rnn_size: rnn隐层节点数量
# sequence_length: 数据的序列长度
# num_layers:堆叠的rnn cell数量
# rnn_inputs: 输入tensor
# keep_prob:
'''Create the encoding layer'''
for layer in range(num_layers):
with tf.variable_scope('encode_{}'.format(layer)):
cell_fw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
cell_fw = tf.contrib.rnn.DropoutWrapper(cell_fw,input_keep_prob=keep_prob)
cell_bw = tf.contrib.rnn.LSTMCell(rnn_size,initializer=tf.random_uniform_initializer(-0.1,0.1,seed=2))
cell_bw = tf.contrib.rnn.DropoutWrapper(cell_bw,input_keep_prob = keep_prob)
enc_output,enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,cell_bw,
rnn_inputs,sequence_length,dtype=tf.float32)
# join outputs since we are using a bidirectional RNN
enc_output = tf.concat(enc_output,2)
return enc_output,enc_state
tf.nn.dynamic_rnn()
tf.nn.dynamic_rnn的返回值有两个:outputs和state
为了描述输出的形状,先介绍几个变量,batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,如果输入的三个句子,那max_time对应的就是最长句子的单词数量,cell.output_size其实就是rnn cell中神经元的个数。
例子来说明其用法,假设你的RNN的输入input是[2,20,128],其中2是batch_size,20是文本最大长度,128是embedding_size,可以看出,有两个example,我们假设第二个文本长度只有13,剩下的7个是使用0-padding方法填充的。dynamic返回的是两个参数:outputs,state,其中outputs是[2,20,128],也就是每一个迭代隐状态的输出,state是由(c,h)组成的tuple,均为[batch,128]。
outputs. outputs是一个tensor
如果time_major==True,outputs形状为 [max_time, batch_size, cell.output_size ](要求rnn输入与rnn输出形状保持一致)
如果time_major==False(默认),outputs形状为 [ batch_size, max_time, cell.output_size ]
state. state是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下state的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。
这里有关于LSTM的结构问题:
来源:https://blog.csdn.net/weixin_31866177/article/details/81905226
猜你喜欢
- 本文实例为大家分享了Python OpenCV实现视频分帧的具体代码,供大家参考,具体内容如下# coding=utf-8import os
- 1、not关键词可以反转一个布尔值。>>> not TrueFalse>>>>>> n
- 什么是LSTM1、LSTM的结构我们可以看出,在n时刻,LSTM的输入有三个:当前时刻网络的输入值Xt;上一时刻LSTM的输出值ht-1;上
- 今天在看见了一堆不错的非洲的web 2.0网站的Logo,于大家一起欣赏:非洲web2.0网站的logo大部分和平时看见的web2.0网站l
- 本文实例讲述了Python常用时间操作。分享给大家供大家参考,具体如下:我们先导入必须用到的一个module>>> imp
- 在进行大量数据训练神经网络的时候,可能需要批量读取数据。于是参考了这篇文章的代码,结果发现数据一直批量循环输出,不会在数据的末尾自动停止。然
- 以下内容是针对安装tensorflow-CPU版本的。tensorflow已经支持Python3.8版本的安装。可以查看自己的Python版
- 自己写的一个自动完成效果,暂时没有ajax数据源,用静态数据代替。仅供喜欢JavaScript的同学们参考,代码如下<!DOCTYPE
- 本文实例讲述了python使用any判断一个对象是否为空的方法。分享给大家供大家参考。具体实现代码如下:>>> eth =
- 1> 如何在浏览器地址栏前添加自定义的小图标?你是不是记得有时在浏览网易网站的首页时,在地址WWW.PUTAOJIAYUAN.COM前
- 高级语言不能直接被机器所理解执行,所以都需要一个翻译的阶段,解释型语言用到的是解释器,编译型语言用到的是编译器。编译型语言通常的执行过程是:
- 生成HTML方法主要步骤只有两个:一、获取要生成的html文件的内容二、将获取的html文件内容保存为html文件我在这里主要说明的只是第一
- 我们先看一下淘宝的页面:这么一个庞然大物,该怎么切图呢?显然按照给出的方法也可以完成这项任务,但是做为前端开发的我们是否应该给自己提出更高的
- 抽象工厂模式抽象工厂模式是一种创建型设计模式, 它能创建一系列相关的对象, 而无需指定其具体类。抽象工厂定义了用于创建不同产品的接口, 但将
- 本文实例讲述了PHP中使用addslashes函数转义的安全性原理分析。分享给大家供大家参考。具体分析如下:先来看一下ECshop中adds
- 对 current_datetime 的一次赋值操作:def current_datetime(request): now =
- AlexNet是2012年ImageNet比赛的冠军,虽然过去了很长时间,但是作为深度学习中的经典模型,AlexNet不但有助于我们理解其中
- 是否应该开启缓冲器? 通过脚本程序启动缓冲器 在ASP脚本的顶部包含Response.Buffer=True ,IIS就会将页面的内容缓存。
- 不正确地调用Windows应用程序接口可能会产生一些意想不到的副作用,以及潜在地对一个应用程序的代码及数据段的破坏。正确地使用一个空的32位
- 这篇是Nicholas讨论如果防止脚本失控的第二篇,主要讨论了如何重构嵌套循环、递归,以及那些在函数内部同时执行很多子操作的函数。基本的思想