python循环神经网络RNN函数tf.nn.dynamic_rnn使用
作者:Bubbliiiing 发布时间:2022-08-28 15:42:24
学习前言
已经完成了RNN网络的构建,但是我们对于RNN网络还有许多疑问,特别是tf.nn.dynamic_rnn函数,其具体的应用方式我们并不熟悉,查询了一下资料,我心里的想法是这样的。
tf.nn.dynamic_rnn的定义
tf.nn.dynamic_rnn(
cell,
inputs,
sequence_length=None,
initial_state=None,
dtype=None,
parallel_iterations=None,
swap_memory=False,
time_major=False,
scope=None
)
cell:上文所定义的lstm_cell。
inputs:RNN输入。如果time_major==false(默认),则必须是如下shape的tensor:[batch_size,max_time,…]或此类元素的嵌套元组。如果time_major==true,则必须是如下形状的tensor:[max_time,batch_size,…]或此类元素的嵌套元组。
sequence_length:Int32/Int64矢量大小。用于在超过批处理元素的序列长度时复制通过状态和零输出。因此,它更多的是为了性能而不是正确性。
initial_state:上文所定义的_init_state。
dtype:数据类型。
parallel_iterations:并行运行的迭代次数。那些不具有任何时间依赖性并且可以并行运行的操作将是。这个参数用时间来交换空间。值>>1使用更多的内存,但花费的时间更少,而较小的值使用更少的内存,但计算需要更长的时间。
time_major:输入和输出tensor的形状格式。如果为True,这些张量的形状必须是[max_time,batch_size,depth]。如果为False,这些张量的形状必须是[batch_size,max_time,depth]。使用time_major=true会更有效率,因为它可以避免在RNN计算的开始和结束时进行换位。但是,大多数TensorFlow数据都是批处理主数据,因此默认情况下,此函数为False。
scope:创建的子图的可变作用域;默认为“RNN”。
其返回值为outputs,states。
outputs
:RNN的最后一层的输出,是一个tensor。如果为time_major== False,则它的shape为[batch_size,max_time,cell.output_size]。如果为time_major== True,则它的shape为[max_time,batch_size,cell.output_size]。
states
:是每一层的最后一个step的输出,是一个tensor。state是最终的状态,也就是序列中最后一个cell输出的状态。一般情况下states的形状为 [batch_size, cell.output_size],但当输入的cell为BasicLSTMCell时,states的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。
tf.nn.dynamic_rnn的使用举例
单层实验
我们首先使用单层的RNN进行实验。
使用的代码为:
import tensorflow as tf
import numpy as np
n_steps = 2 #两个step
n_inputs = 3 #每个input是三维
n_nerve = 4 #神经元个数
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_nerve)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
init = tf.global_variables_initializer()
X_batch = np.array([[[0, 1, 2], [1, 2, 3]],
[[3, 4, 5], [4, 5, 6]],
[[5, 6, 7], [6, 7, 8]],
[[7, 8, 9], [8, 9, 10]]])
with tf.Session() as sess:
sess.run(init)
outputs_val, states_val = sess.run([outputs, states], feed_dict={X: X_batch})
print("outputs:", outputs_val)
print("states:", states_val)
输出的log为:
outputs: [[[0.92146313 0.6069534 0.24989243 0.9305415 ]
[0.9234855 0.8470011 0.7865616 0.99935764]]
[[0.9772771 0.9713368 0.99483156 0.9999987 ]
[0.9753329 0.99538314 0.9988139 1. ]]
[[0.9901842 0.99558043 0.9998626 1. ]
[0.989398 0.9992842 0.9999691 1. ]]
[[0.99577546 0.9993256 0.99999636 1. ]
[0.9954579 0.9998903 0.99999917 1. ]]]
states: [[0.9234855 0.8470011 0.7865616 0.99935764]
[0.9753329 0.99538314 0.9988139 1. ]
[0.989398 0.9992842 0.9999691 1. ]
[0.9954579 0.9998903 0.99999917 1. ]]
Xin的shape是[batch_size = 4, max_time = 2, depth = 3]。
outputs的shape是[batch_size = 4, max_time = 2, cell.output_size = 4]。
states的shape是[batch_size = 4, cell.output_size = 4]
在time_major = False的时候:
Xin、outputs、states的第一维,都是batch_size,即用于训练的batch的大小。
Xin、outputs的第二维,都是max_time,在本文中对应着RNN的两个step。
outputs、states的最后一维指的是每一个RNN的Cell的输出,本文的RNN的Cell的n_nerve为4,所以cell.output_size = 4。Xin的最后一维指的是每一个输入样本的维度。
outputs对应的是RNN的最后一层的输出,states对应的是每一层的最后一个step的输出。在RNN的层数仅1层的时候,states的输出对应为outputs最后的step的输出。
多层实验
接下来我们使用两层的RNN进行实验。
使用的代码为:
import tensorflow as tf
import numpy as np
n_steps = 2 #两个step
n_inputs = 3 #每个input是三维
n_nerve = 4 #神经元个数
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
#定义多层
layers = [tf.nn.rnn_cell.BasicRNNCell(num_units=n_nerve) for i in range(2)]
multi_layer_cell = tf.contrib.rnn.MultiRNNCell(layers)
outputs, states = tf.nn.dynamic_rnn(multi_layer_cell, X, dtype=tf.float32)
init = tf.global_variables_initializer()
X_batch = np.array([[[0, 1, 2], [1, 2, 3]],
[[3, 4, 5], [4, 5, 6]],
[[5, 6, 7], [6, 7, 8]],
[[7, 8, 9], [8, 9, 10]]])
with tf.Session() as sess:
sess.run(init)
outputs_val, states_val = sess.run([outputs, states], feed_dict={X: X_batch})
print("outputs:", outputs_val)
print("states:", states_val)
输出的log为:
outputs: [[[-0.577939 -0.3657474 -0.21074213 0.8188577 ]
[-0.67090076 -0.47001836 -0.40080917 0.6026697 ]]
[[-0.72777444 -0.36500326 -0.7526911 0.86113644]
[-0.7928404 -0.6413429 -0.61007065 0.787065 ]]
[[-0.7537433 -0.35850585 -0.83090436 0.8573037 ]
[-0.82016116 -0.6559162 -0.7360482 0.7915131 ]]
[[-0.7597004 -0.35760364 -0.8450942 0.8567379 ]
[-0.8276395 -0.6573326 -0.7727142 0.7895221 ]]]
states: (array([[-0.71645427, -0.0585744 , 0.95318353, 0.8424729 ],
[-0.99845 , -0.5044571 , 0.9955299 , 0.9750488 ],
[-0.99992913, -0.8408632 , 0.99885863, 0.9932366 ],
[-0.99999577, -0.9672 , 0.9996866 , 0.99814796]],
dtype=float32),
array([[-0.67090076, -0.47001836, -0.40080917, 0.6026697 ],
[-0.7928404 , -0.6413429 , -0.61007065, 0.787065 ],
[-0.82016116, -0.6559162 , -0.7360482 , 0.7915131 ],
[-0.8276395 , -0.6573326 , -0.7727142 , 0.7895221 ]],
dtype=float32))
可以看出来outputs对应的是RNN的最后一层的输出,states对应的是每一层的最后一个step的输出,在完成了两层的定义后,outputs的shape并没有变化,而states的内容多了一层,分别对应RNN的两层输出。
state中最后一层输出对应着outputs最后一步的输出。
来源:https://blog.csdn.net/weixin_44791964/article/details/98480738
猜你喜欢
- 购物车是电子商务网站中不可缺少的组成部分,但目前大多数购物车只能作为一个顾客选中商品的展示,客户端无法将购物车里的内容提取出来满足自己事务处
- 一、代码注释介绍注释就是对代码的解释和说明,其目的是让人们能够更加轻松地了解代码。注释是编写程序时,写程序的人给一个语句、程序段、函数等的解
- 本文实例讲述了php逐行读取txt文件写入数组的方法。分享给大家供大家参考。具体如下:假设有user.txt文件如下:user01user0
- 之前总结过flask里的基础知识,现在来总结下flask里的前后端数据交互的知识,这里用的是Ajax一、 post方法1、post方法的位置
- 我们知道,在js中,当object作为参数传递到函数中进行处理后,实际上是修改了传入的对象本身(或者说是对象的引用),但很多时候我们并不希望
- 1、引言续上一篇《一行代码,导入Python所有库》不知道是不是都跟小鱼一样,把剩下的时间来学(撩)习(妹)。为了体现小鱼在懒上的造就,小鱼
- 在防止sql注入这些细节出现问题的一般是那些大意的程序员或者是新手程序员,他们由于没有对用户提交过来的数据进行一些必要的过滤,从而导致了给大
- 关联模型(多对多)多对多关系(抽象)例:一篇文章可能有多个关键词,一个关键词可能被多个文章使用。 关键词表:字段id主键字段keyword关
- 前言help(argparse)查看说明文档,“argparse - Command-line parsing libr
- 懒加载是一种编程范式,它推迟加载操作,直到不得不这样做。通常,当操作开销很大,需要耗费大量时间或空间时,惰性求值是首选实现。例如,在 Pyt
- 从4年之前什么都不知道,到现在对代码的一网情深,感谢无忧的兄弟姐妹的帮助,感谢无忧给我们提供了这么好的交流平台。现将最近几天捣鼓的asp封装
- 前言前言:想写这个代码的原因是因为实习的时候需要根据表格名创建对应的文件夹,如果只是很少个数文件夹的话,ctrl+shift+n还可以接受吧
- 今天学习了用python生成仿真数据的一些基本方法和技巧,写成博客和大家分享一下。 本篇博客主
- 1)利用eval可以将字典格式的字符串与字典户转》》》mstr = '{"name":"yct&quo
- 我们怎样才能了解用户需求呢?大家都知道可用性测试、调查问卷之类与用户进行沟通的途径,这些方法各有各的利弊,如果逐一分析的话,恐怕至少要分成三
- 正则表达式是处理字符串的强大工具。作为一个概念而言,正则表达式对于Python来说并不是独有的。但是,Python中的正则表达式在实际使用过
- 在我们使用查询语句的时候,经常要返回前几条或者中间某几行数据,这个时候怎么办呢?不用担心, mysql已经为我们提供了这样一个功
- 本文实例为大家分享了python文件写入write()的操作的具体代码,供大家参考,具体内容如下filename = 'pragra
- 安装好mysql后,在终端输入 mysql -u root -p 按回车,输入密码后提示access denied......ues pas
- 现在大家学习python掌握内容了解太多太多,但是最重要的不是掌握了解算法的使用,而是了解算法原理远比使用算法命令更重要,现在大家了解算法应