pytorch中nn.RNN()汇总
作者:orangerfun 发布时间:2022-08-31 03:11:34
nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, batch_first=False, dropout=0, bidirectional=False)
参数说明
input_size输入特征的维度, 一般rnn中输入的是词向量,那么 input_size 就等于一个词向量的维度
hidden_size隐藏层神经元个数,或者也叫输出的维度(因为rnn输出为各个时间步上的隐藏状态)
num_layers网络的层数
nonlinearity激活函数
bias是否使用偏置
batch_first输入数据的形式,默认是 False,就是这样形式,(seq(num_step), batch, input_dim),也就是将序列长度放在第一位,batch 放在第二位
dropout是否应用dropout, 默认不使用,如若使用将其设置成一个0-1的数字即可
birdirectional是否使用双向的 rnn,默认是 False
注意某些参数的默认值在标题中已注明
输入输出shape
input_shape = [时间步数, 批量大小, 特征维度] = [num_steps(seq_length), batch_size, input_dim]
在前向计算后会分别返回输出和隐藏状态h,其中输出指的是隐藏层在各个时间步上计算并输出的隐藏状态,它们通常作为后续输出层的输⼊。需要强调的是,该“输出”本身并不涉及输出层计算,形状为(时间步数, 批量大小, 隐藏单元个数);隐藏状态指的是隐藏层在最后时间步的隐藏状态:当隐藏层有多层时,每⼀层的隐藏状态都会记录在该变量中;对于像⻓短期记忆(LSTM),隐藏状态是⼀个元组(h, c),即hidden state和cell state(此处普通rnn只有一个值)隐藏状态h的形状为(层数, 批量大小,隐藏单元个数)
代码
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens, )
# 定义模型, 其中vocab_size = 1027, hidden_size = 256
num_steps = 35
batch_size = 2
state = None # 初始隐藏层状态可以不定义
X = torch.rand(num_steps, batch_size, vocab_size)
Y, state_new = rnn_layer(X, state)
print(Y.shape, len(state_new), state_new.shape)
输出
torch.Size([35, 2, 256]) 1 torch.Size([1, 2, 256])
具体计算过程
H t = i n p u t ∗ W x h + H t − 1 ∗ W h h + b i a s H_t = input * W_{xh} + H_{t-1} * W_{hh} + bias Ht=input∗Wxh+Ht−1∗Whh+bias[batch_size, input_dim] * [input_dim, num_hiddens] + [batch_size, num_hiddens] *[num_hiddens, num_hiddens] +bias
可以发现每个隐藏状态形状都是[batch_size, num_hiddens], 起始输出也是一样的
注意:上面为了方便假设num_step=1
GRU/LSTM等参数同上面RNN
来源:https://blog.csdn.net/orangerfun/article/details/103934290
猜你喜欢
- 简单的合并,本例是横向合并,纵向合并可以自行调整。import xlrd import xlwtimport shutil from xlu
- 很神奇的一个晚上,居然在以前老同事的群里跟同事讨论起CSS的东西来了,不过很意外的还是有收获。在IE中常常会碰到如果将容器定位后,出现容器内
- 本文讨论 MySQL 的备份和恢复机制,以及如何维护数据表,包括最主要的两种表类型:MyISAM 和 Innodb,文中设计的 MySQL
- 我为一大型网站做了一个论坛,也顺利通过了测试。由于是第一次做这方面的数据库,我不知道比其它网站上数据库差距有多大,是不是够优化。能推荐或介绍
- 一、基本概述目前电脑上已经下载了MongoDB数据库、navicat for mongodb作为mongoDB的可视化工具,形如navica
- 本文实例为大家分享了python绘制汉诺塔的具体代码,供大家参考,具体内容如下源码:import turtleclass Stack: &n
- [前言:]ASP.NET是微软提供的最新的开发基于Web的应用程序的技术。它提供了大量的比传统ASP脚本技术的好处,包括:1)通过把UI表现
- 目录表示时间的方式1. 调用语法:2. time概述3. 时间获取4. 时间格式化(将时间以合理的方式展示出来)5. 程序计时应用6. 示例
- pycharm设置Console控制台输出自动换行解决方法File --> Settings… --> E
- 秉承着一切皆对象的理念,我们再次回头来看函数(function)。函数也是一个对象,具有属性(可以使用dir()查询)。作为对象,它还可以赋
- 锟拷码和口字码说到乱码问题就不得不提到锟斤拷,这算是非常常见的一种乱码形式,那么它到底是经过何种错误操作产生的呢?下面我们一步步探究。看一个
- 一、特效预览处理前处理后细节放大后二、程序原理将图片所在的 256 的灰度映射到相应的字符上面也就是 RGB 值转成相应的字符然
- Python中的正则表达式要用到re模块,下面先介绍一下正则表达式需要用到的特殊字符和说明常用的RegEx基础语法语法说明\d匹配一个数字字
- 最近在学习tensorflow框架,在ubuntu下用到python的一个ide --spyder,以下是常用快捷键Ctrl+1:注释/撤销
- 到现在为止,我们通过前面几篇博文的描述和分析,已经可以自动实现棋子、棋盘位置的准确判断,计算一下两个中心点之间的距离,并绘制在图形上,效果如
- 一、项目效果学校宿舍今天搬家,累麻了,突然发现展示处理的也很粗糙,就这样吧嘿嘿~~~二、核心流程1、openCV读取视频流、在每一帧图片上画
- 1.图像处理库import cv2 as cvfrom PIL import *常用的图像处理技术有图像读取,写入,绘图,图像色彩空间转换,
- 废话不多说了,直接上代码吧:#Copyright (c)2017, 东北大学软件学院学生# All rightsreserved#文件名称:
- 近来想要做一做人脸识别相关的内容,主要是想集成一个系统,看到opencv已经集成了三种性能较好的算法,但是还是想自己动手试一下,毕竟算法都比
- Python中的模块(.py文件)在创建之初会自动加载一些内建变量,__name__就是其中之一。Python模块中通常会定义很多变量和函数