pytorch中使用LSTM详解
作者:qyhyzard 发布时间:2021-01-08 04:27:10
LSMT层
可以在troch.nn
模块中找到LSTM类
lstm = torch.nn.LSTM(*paramsters)
1、__init__方法
首先对nn.LSTM
类进行实例化,需要传入的参数如下图所示:
一般我们关注这4个:
input_size
表示输入的每个token的维度,也可以理解为一个word的embedding的维度。hidden_size
表示隐藏层也就是记忆单元C的维度,也可以理解为要将一个word的embedding维度转变成另一个大小的维度。除了C,在LSTM中输出的H的维度与C的维度是一致的。num_layers
表示有多少层LSTM,加深网络的深度,这个参数对LSTM的输出的维度是有影响的(后文会提到)。bidirectional
表示是否需要双向LSTM,这个参数也会对后面的输出有影响。
2、forward方法的输入
将数据input传入forward方法进行前向传播时有3个参数可以输入,见下图:
这里要注意的是
input
参数各个维度的意义,一般来说如果不在实例化时制定batch_first=True
,那么input
的第一个维度是输入句子的长度seq_len,第二个维度是批量的大小,第三个维度是输入句子的embedding维度也就是input_size,这个参数要与__init__
方法中的第一个参数对应。另外记忆细胞中的两个参数
h_0
和c_0
可以选择自己初始化传入也可以不传,系统默认是都初始化为0。传入的话注意维度[bidirectional * num_layers, batch_size, hidden_size]。
3、forward方法的输出
forward方法的输出如下图所示:
一般采用如下形式:
out,(h_n, c_n) = lstm(x)
out
表示在最后一层上,每一个时间步的输出,也就是句子有多长,这个out的输出就有多长;其维度为[seq_len, batch_size, hidden_size * bidirectional]。因为如果的双向LSTM,最后一层的输出会把正向的和反向的进行拼接,故需要hidden_size * bidirectional。h_n
表示的是每一层(双向算两层)在最后一个时间步上的输出;其维度为[bidirectional * num_layers, batch_size, hidden_size]
假设是双向的LSTM,且是3层LSTM,双向每个方向算一层,两个方向的组合起来叫一层LSTM,故共会有6层(3个正向,3个反向)。所以h_n是每层的输出,bidirectional * num_layers = 6。c_n
表示的是每一层(双向算两层)在最后一个时间步上的记忆单元,意义不同,但是其余均与 h_n
一样。
LSTMCell
可以在troch.nn
模块中找到LSTMCell类
lstm = torch.nn.LSTMCell(*paramsters)
它的__init__
方法的参数设置与LSTM类似,但是没有num_layers
参数,因为这就是一个细胞单元,谈不上多少层和是否双向。forward
的输入和输出与LSTM均有所不同:
其相比LSTM,输入没有了时间步的概念,因为只有一个Cell单元;输出 也没有out
参数,因为就一个Cell,out
就是h_1
,h_1
和c_1
也因为只有一个Cell单元,其没有层数上的意义,故只是一个Cell的输出的维度[batch_size, hidden_size].
代码演示如下:
rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size)
input = torch.randn(2, 3, 10) # (time_steps, batch, input_size)
hx = torch.randn(3, 20) # (batch, hidden_size)
cx = torch.randn(3, 20)
output = []
# 从输入的第一个维度也就是seq_len上遍历,每循环一次,输入一个单词
for i in range(input.size()[0]):
# 更新细胞记忆单元
hx, cx = rnn(input[i], (hx, cx))
# 将每个word作为输入的输出存起来,相当于LSTM中的out
output.append(hx)
output = torch.stack(output, dim=0)
来源:https://blog.csdn.net/qq_42961603/article/details/119638341
猜你喜欢
- 因为要写个东西用到,所以百度了一下,居然有朋友乱写,而且比较多,都没有认真测试过,只对字符可以,但是对数字就不可以,而且通用性很差,需要修改
- 我就废话不多说了,大家还是直接看代码吧~one = tf.ones_like(label)zero = tf.zeros_like(labe
- 问题你想将几个小的字符串合并为一个大的字符串解决方案如果你想要合并的字符串是在一个序列或者 iterable 中,那么最快的方式就是使用 j
- Example.asp<%@LANGUAGE="VBSCRIPT" CODEPAGE="65001&qu
- 新闻系统,相册系统可以用用哦,简单实用,有兴趣的可以自己扩充!^_^相册截图:<?xml version="1.0"
- 快速排序由于排序效率在同为O(N*logN)的几种排序方法中效率较高,因此经常被采用。该方法的基本思想是:1.先从数列中取出一个数作为基准数
- 杭州最美的季节里,淘宝无障碍访问改善小组有幸邀请到盲人在线站长——争渡读屏团队成员——杨永全同学和我们一起面对面交流网站无障碍访问方面的问题
- Python是一种高级编程语言,它在众多编程语言中,拥有极高的人气和使用率。Python中的多进程和进程池是其强大的功能之一,可以让我们更加
- 使用python + shell 编写,是一个简易solaris系统巡检程序#!/usr/bin/python -u#-*- coding:
- 前言前几天逛github发现了一个有趣的并发库-conc,其目标是:更难出现goroutine泄漏处理panic更友好并发代码可读性高从简介
- 引言本篇是以python的视角介绍相关的函数还有自我使用中的一些问题,本想在这篇之前总结一下opencv编译的全过程,但遇到了太多坑,暂时不
- 互联网是一个飞速发展的行业,任何的止步不前都会导致被淘汰,只是时间早晚的问题,所以一个公司的学习与创新能力是非常重要的,特别是对于一个年轻的
- logging日志模块:是用来记录日志的模块,一般记录用户在软件中的操作使用方法:模板直接拿来用,手动修改# logging的配置信息(模板
- 下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。import tensorf
- 在Web上使用菜单可以极大地节约页面的空间,同时也比较的符合用户从Windows上继承下来的UI操作体验。在以往的Web页菜单设计中,我们普
- 为了防止网络上日益猖獗的垃圾广告和灌水评论,大多数网站在信息发布的时候要求输入验证码。图片、文字、字母甚至还有计算题。验证码图片里的信息东颠
- pytest的setup与teardown1)pytest提供了两套互相独立的setup 与 teardown和一对相对自由的setup与t
- 若对于同一数据库实例中的两个数据库进行同步则直接对数据库表创建Trigger。SQL Server 2005的联机帮助:Trigger on
- 在windows+iis服务器上运行asp程序可能会出现数据库无法更新的情况,具体错误信息可能为: 1、Microsoft JET Data
- 一、Python图像处理PIL库1.1 转换图像格式# PIL(Python Imaging Library)from PIL import