pytorch中nn.RNN()汇总
作者:orangerfun 发布时间:2022-08-31 03:11:34
nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=True, batch_first=False, dropout=0, bidirectional=False)
参数说明
input_size输入特征的维度, 一般rnn中输入的是词向量,那么 input_size 就等于一个词向量的维度
hidden_size隐藏层神经元个数,或者也叫输出的维度(因为rnn输出为各个时间步上的隐藏状态)
num_layers网络的层数
nonlinearity激活函数
bias是否使用偏置
batch_first输入数据的形式,默认是 False,就是这样形式,(seq(num_step), batch, input_dim),也就是将序列长度放在第一位,batch 放在第二位
dropout是否应用dropout, 默认不使用,如若使用将其设置成一个0-1的数字即可
birdirectional是否使用双向的 rnn,默认是 False
注意某些参数的默认值在标题中已注明
输入输出shape
input_shape = [时间步数, 批量大小, 特征维度] = [num_steps(seq_length), batch_size, input_dim]
在前向计算后会分别返回输出和隐藏状态h,其中输出指的是隐藏层在各个时间步上计算并输出的隐藏状态,它们通常作为后续输出层的输⼊。需要强调的是,该“输出”本身并不涉及输出层计算,形状为(时间步数, 批量大小, 隐藏单元个数);隐藏状态指的是隐藏层在最后时间步的隐藏状态:当隐藏层有多层时,每⼀层的隐藏状态都会记录在该变量中;对于像⻓短期记忆(LSTM),隐藏状态是⼀个元组(h, c),即hidden state和cell state(此处普通rnn只有一个值)隐藏状态h的形状为(层数, 批量大小,隐藏单元个数)
代码
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens, )
# 定义模型, 其中vocab_size = 1027, hidden_size = 256
num_steps = 35
batch_size = 2
state = None # 初始隐藏层状态可以不定义
X = torch.rand(num_steps, batch_size, vocab_size)
Y, state_new = rnn_layer(X, state)
print(Y.shape, len(state_new), state_new.shape)
输出
torch.Size([35, 2, 256]) 1 torch.Size([1, 2, 256])
具体计算过程
H t = i n p u t ∗ W x h + H t − 1 ∗ W h h + b i a s H_t = input * W_{xh} + H_{t-1} * W_{hh} + bias Ht=input∗Wxh+Ht−1∗Whh+bias[batch_size, input_dim] * [input_dim, num_hiddens] + [batch_size, num_hiddens] *[num_hiddens, num_hiddens] +bias
可以发现每个隐藏状态形状都是[batch_size, num_hiddens], 起始输出也是一样的
注意:上面为了方便假设num_step=1
GRU/LSTM等参数同上面RNN
来源:https://blog.csdn.net/orangerfun/article/details/103934290


猜你喜欢
- 本文实例讲述了Laravel框架集合用法。分享给大家供大家参考,具体如下:前言集合通过 Illuminate\Support\Collect
- 问:怎样解决MySQL 5.0.16的乱码问题?答:MySQL 5.0.16的乱码问题可以用下面的方法解决:1.设置phpMyAdminLa
- 微信应用号(微信公众平台小程序,「应用号」的新称呼)终于来了!开源中国社区的博卡君通宵吐血赶稿写出的微信公众平台应用号开发教程!大家赶紧来学
- 本文实例讲述了django框架自定义模板标签(template tag)操作。分享给大家供大家参考,具体如下:django 提供了丰富的模板
- 上个月安装的pycharm,由于当时急需要使用,就直接使用的pycharm试用版,没成想,今天早上一打开,直接给我来了个下马威,不能进入了,
- Blog Posts的提交让我们从简单的开始。首页上必须有一张用户提交新的post的表单。首先我们定义一个单域表单对象(fileapp/fo
- 这个分页使用的是0游标,也就是Rs.Open Sql,Conn,0,1。但是感觉也快不了多少,10万条数据的分页时间300多豪秒之间。风格A
- 如何存放或更新缓存?缓存数据来源是预知的,我们可以预先定义哪些 mutation 是缓存相关的。我们期望这个过程更自然一点,通过某种变化自动
- 假定业务:查看在职员工的薪资的第二名的员工信息创建数据库drop database if exists emps;create databa
- 引言在 Linux 服务器上,磁盘空间的使用情况是一个非常重要的指标。如果服务器上的磁盘空间不足,可能会导致服务器崩溃,影响网站的正常运行。
- 关闭正在运行的 MySQL :[root@www.woai.it ~]# service mysql stop运行[root@www.woa
- 前言本文主要给大家介绍了关于python3对JSON的一些操作,分享出来供大家参考学习,下面话不多说了,来一起看看详细的介绍吧。一、Dict
- 情境问题小王是一名法务专员,工作中会处理所在公司的侵权事件并向侵权方发送法务函。他会按照【法务函模板.docx】 Word 文件给【封号名单
- 本文实例讲述了Python同时向控制台和文件输出日志logging的方法。分享给大家供大家参考。具体如下:python提供了非常方便的日志模
- 很多网站注册时都会要求输入电子邮箱,其应用场景是比较广的,例如注册账号接收验证码、注册成功通知、登录通知、找回密码验证通知等。本文将介绍如何
- Python在读取文件内容时的路径问题,值得深究一下.我想讨论的重点还是在绝对路径上面.在这之前我们先看一下1:相对路径这张图演示了在相对路
- 问题背景两张表一张是用户表a(主键是int类型),一张是用户具体信息表b(用户表id字段是varchar类型)。因为要显示用户及用户信息,所
- 一、数据可视化1.pyecharts介绍官方网址:https://pyecharts.org/#/zh-cn/intro📣 概况:Echar
- 逆向最大匹配方法有正即有负,正向最大匹配算法大家可以参阅https://www.jb51.net/article/127404.htm逆向最
- 使用sql的计划任务可以处理一些特殊环境的数据,除了使用windows系统的计划任务来定时处理,不过要配合程序才行,有些事情可以直接使用sq