PyTorch深度学习LSTM从input输入到Linear输出
作者:Cyril_KI 发布时间:2022-04-03 23:11:32
LSTM介绍
关于LSTM的具体原理,可以参考:
https://www.jb51.net/article/178582.htm
https://www.jb51.net/article/178423.htm
系列文章:
PyTorch搭建双向LSTM实现时间序列负荷预测
PyTorch搭建LSTM实现多变量多步长时序负荷预测
PyTorch搭建LSTM实现多变量时序负荷预测
PyTorch搭建LSTM实现时间序列负荷预测
LSTM参数
关于nn.LSTM的参数,官方文档给出的解释为:
总共有七个参数,其中只有前三个是必须的。由于大家普遍使用PyTorch的DataLoader来形成批量数据,因此batch_first也比较重要。LSTM的两个常见的应用场景为文本处理和时序预测,因此下面对每个参数我都会从这两个方面来进行具体解释。
input_size:在文本处理中,由于一个单词没法参与运算,因此我们得通过Word2Vec来对单词进行嵌入表示,将每一个单词表示成一个向量,此时input_size=embedding_size。
比如每个句子中有五个单词,每个单词用一个100维向量来表示,那么这里input_size=100;
在时间序列预测中,比如需要预测负荷,每一个负荷都是一个单独的值,都可以直接参与运算,因此并不需要将每一个负荷表示成一个向量,此时input_size=1。
但如果我们使用多变量进行预测,比如我们利用前24小时每一时刻的[负荷、风速、温度、压强、湿度、天气、节假日信息]来预测下一时刻的负荷,那么此时input_size=7。
hidden_size:隐藏层节点个数。可以随意设置。
num_layers:层数。nn.LSTMCell与nn.LSTM相比,num_layers默认为1。
batch_first:默认为False,意义见后文。
Inputs
关于LSTM的输入,官方文档给出的定义为:
可以看到,输入由两部分组成:input、(初始的隐状态h_0,初始的单元状态c_0)
其中input:
input(seq_len, batch_size, input_size)
seq_len:在文本处理中,如果一个句子有7个单词,则seq_len=7;在时间序列预测中,假设我们用前24个小时的负荷来预测下一时刻负荷,则seq_len=24。
batch_size:一次性输入LSTM中的样本个数。在文本处理中,可以一次性输入很多个句子;在时间序列预测中,也可以一次性输入很多条数据。
input_size:见前文。
(h_0, c_0):
h_0(num_directions * num_layers, batch_size, hidden_size)
c_0(num_directions * num_layers, batch_size, hidden_size)
h_0和c_0的shape一致。
num_directions:如果是双向LSTM,则num_directions=2;否则num_directions=1。
num_layers:见前文。
batch_size:见前文。
hidden_size:见前文。
Outputs
关于LSTM的输出,官方文档给出的定义为:
可以看到,输出也由两部分组成:otput、(隐状态h_n,单元状态c_n)
其中output的shape为:
output(seq_len, batch_size, num_directions * hidden_size)
h_n和c_n的shape保持不变,参数解释见前文。
batch_first
如果在初始化LSTM时令batch_first=True,那么input和output的shape将由:
input(seq_len, batch_size, input_size)
output(seq_len, batch_size, num_directions * hidden_size)
变为:
input(batch_size, seq_len, input_size)
output(batch_size, seq_len, num_directions * hidden_size)
即batch_size提前。
案例
简单搭建一个LSTM如下所示:
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, output_size, batch_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.output_size = output_size
self.num_directions = 1 # 单向LSTM
self.batch_size = batch_size
self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.linear = nn.Linear(self.hidden_size, self.output_size)
def forward(self, input_seq):
h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
seq_len = input_seq.shape[1] # (5, 30)
# input(batch_size, seq_len, input_size)
input_seq = input_seq.view(self.batch_size, seq_len, 1) # (5, 30, 1)
# output(batch_size, seq_len, num_directions * hidden_size)
output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)
pred = self.linear(output) # pred(150, 1)
pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)
pred = pred[:, -1, :] # (5, 1)
return pred
其中定义模型的代码为:
self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers, batch_first=True)
self.linear = nn.Linear(self.hidden_size, self.output_size)
我们加上具体的数字:
self.lstm = nn.LSTM(self.input_size=1, self.hidden_size=64, self.num_layers=5, batch_first=True)
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)
再看前向传播:
def forward(self, input_seq):
h_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
c_0 = torch.randn(self.num_directions * self.num_layers, self.batch_size, self.hidden_size).to(device)
seq_len = input_seq.shape[1] # (5, 30)
# input(batch_size, seq_len, input_size)
input_seq = input_seq.view(self.batch_size, seq_len, 1) # (5, 30, 1)
# output(batch_size, seq_len, num_directions * hidden_size)
output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)
pred = self.linear(output) # (150, 1)
pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)
pred = pred[:, -1, :] # (5, 1)
return pred
假设用前30个预测下一个,则seq_len=30,batch_size=5,由于设置了batch_first=True,因此,输入到LSTM中的input的shape应该为:
input(batch_size, seq_len, input_size) = input(5, 30, 1)
但实际上,经过DataLoader处理后的input_seq为:
input_seq(batch_size, seq_len) = input_seq(5, 30)
(5, 30)表示一共5条数据,每条数据的维度都为30。为了匹配LSTM的输入,我们需要对input_seq的shape进行变换:
input_seq = input_seq.view(self.batch_size, seq_len, 1) # (5, 30, 1)
然后将input_seq送入LSTM:
output, _ = self.lstm(input_seq, (h_0, c_0)) # output(5, 30, 64)
根据前文,output的shape为:
output(batch_size, seq_len, num_directions * hidden_size) = output(5, 30, 64)
全连接层的定义为:
self.linear = nn.Linear(self.hidden_size=64, self.output_size=1)
因此,我们需要将output的第二维度变换为64(150, 64):
output = output.contiguous().view(self.batch_size * seq_len, self.hidden_size) # (5 * 30, 64)
然后将output送入全连接层:
pred = self.linear(output) # pred(150, 1)
得到的预测值shape为(150, 1)。我们需要将其进行还原,变成(5, 30, 1):
pred = pred.view(self.batch_size, seq_len, -1) # (5, 30, 1)
在用DataLoader处理了数据后,得到的input_seq和label的shape分别为:
input_seq(batch_size, seq_len) = input_seq(5, 30)label(batch_size, output_size) = label(5, 1)
由于输出是输入右移,我们只需要取pred第二维度(time)中的最后一个数据:
pred = pred[:, -1, :] # (5, 1)
这样,我们就得到了预测值,然后与label求loss,然后再反向更新参数即可。
时间序列预测的一个真实案例请见:PyTorch搭建LSTM实现时间序列预测(负荷预测)
来源:https://blog.csdn.net/Cyril_KI/article/details/122557880
猜你喜欢
- 当我们的函数接收参数为任意个,或者不能确定参数个数时,我们,可以利用 * 来定义任意数目的参数,这个函数调用时,其所有不匹配的位置
- 中文文本中可能出现的标点符号来源比较复杂,通过匹配等手段对他们处理的时候需要格外小心,防止遗漏。以下为在下处理中文标点的时候采用的两种方法:
- 本文实例为大家分享了python实现通讯录管理系统的具体代码,供大家参考,具体内容如下=====欢迎使用通讯录管理系统=====1.添加2.
- strip_tags定义和用法strip_tags() 函数剥去字符串中的 HTML、XML 以及 PHP 的标签。注释:该函数始终会剥离
- 当我们写用例断言时,往往一个断言结果是不够的,所以需要加入多重断言,而多重断言,当断言中间出现断言结果False时,会中断后续的断言执行,会
- 这是一个很简单的纯CSS相册滑动浏览效果,仅用一个无序列表ul结合简单的CSS就可以实现。原文中介绍的纵向滑动相册的实现方法,但是相比之下个
- 像在下拉菜单中选择省、市这样的操作,我一直用ASP来创建生成列表函数,把它们保存在一个Include文件中,用的时候就加载。这样做确实有个不
- 知识点: 函数 replicate 以下代码是实现如下功能: 代码如下:declare @sql varchar(200), --需填充的字
- 进入PyCharm后,点击File→Open,然后在弹窗中选择需要导入项目的文件夹;打开了python项目后,需要配置该项目对应的pytho
- Selenium简介Selenium是一个用于测试网站的自动化测试工具,支持各种浏览器包括Chrome、Firefox、Safari等主流界
- python面向对象编程入门,我们需要不断学习进步"""抽象工厂模式的实现"""
- 装饰器这东西我看了一会儿才明白,在函数外面套了一层函数,感觉和java里的aop功能很像;写了2个装饰器日志的例子,第一个是不带参数的装饰器
- 1.查看当前电脑python版本python -V // 显示2.7.x2.用brew升级pythonbrew update p
- 一、_func 单下划线开头 --口头私有变量1.1、在模块中使用单下划线开头在Python中,通过单下划线_来实现模块级别的私有化,变量除
- 这是 COMSHARP CMS 团队翻译的2009年海外Web设计风潮的第二部分,着重讲解了反 Box 式布局,单页布局,多栏布局,巨型插图
- 前言在本文中,您将学习如何使用 OpenCV 进行人脸识别。文章分三部分介绍:第一,将首先执行人脸检测,使用深度学习从每个人脸中提取人脸量化
- 前言利用Python的ffmpy库提取视频中的音频。本文提供工具类代码。环境依赖需要安装ffmpy,安装指令:pip install ffm
- 自己用python写了一个签到脚本,经过测试已经可以成功打卡,于是研究了一下windows定时运行程序1. 创建定时任务1.1 计划任务打开
- re.search():匹配整个字符串,并返回第一个成功的匹配。如果匹配失败,则返回None pattern: 匹配的规则,str
- <?php session_start(); $_SESSION['username']="zhuzhao&