Pytorch 如何实现LSTM时间序列预测
作者:CodeInHand 发布时间:2023-06-26 01:04:24
开发环境说明:
Python 35
Pytorch 0.2
CPU/GPU均可
1、LSTM简介
人类在进行学习时,往往不总是零开始,学习物理你会有数学基础、学习英语你会有中文基础等等。
于是对于机器而言,神经网络的学习亦可不再从零开始,于是出现了Transfer Learning,就是把一个领域已训练好的网络用于初始化另一个领域的任务,例如会下棋的神经网络可以用于打德州扑克。
我们这讲的是另一种不从零开始学习的神经网络——循环神经网络(Recurrent Neural Network, RNN),它的每一次迭代都是基于上一次的学习结果,不断循环以得到对于整体序列的学习,区别于传统的MLP神经网络,这种神经网络模型存在环型结构,
具体下所示:
上图是RNN的基本单元,通过不断循环迭代展开模型如下所示,图中ht是神经网络的在t时刻的输出,xt是t时刻的输入数据。
这种循环结构对时间序列数据能够很好地建模,例如语音识别、语言建模、机器翻译等领域。
但是普通的RNN对于长期依赖问题效果比较差,当序列本身比较长时,由于神经网络模型的训练是采用backward进行,在梯度链式法则中容易出现梯度消失和梯度 * 的问题,需要进一步改进RNN的模型结构。
针对Simple RNN存在的问题,LSTM网络模型被提出,LSTM的核心是修改了增添了Cell State,即加入了LSTM CELL,通过输入门、输出门、遗忘门把上一时刻的hidden state和cell state传给下一个状态。
如下所示:
遗忘门:ft = sigma(Wf*[ht-1, xt] + bf)
输入门:it = sigma(Wi*[ht-1, xt] + bi)
cell state initial: C't = tanh(Wc*[ht-1, xt] +bc)
cell state: Ct = ft*Ct-1+ itC't
输出门:ot = sigma(Wo*[ht-1, xt] + bo)
模型输出:ht = ot*tanh(Ct)
LSTM有很多种变型结构,实际工程化过程中用的比较多的是peephole,就是计算每个门的时候增添了cell state的信息,有兴趣的童鞋可以专研专研。
上一部分简单地介绍了LSTM的模型结构,下边将具体介绍使用LSTM模型进行时间序列预测的具体过程。
2、数据准备
对于时间序列,本文选取正弦波序列,事先产生一定数量的序列数据,然后截取前部分作为训练数据训练LSTM模型,后部分作为真实值与模型预测结果进行比较。正弦波的产生过程如下:
SeriesGen(N)方法用于产生长度为N的正弦波数值序列;
trainDataGen(seq,k)用于产生训练或测试数据,返回数据结构为输入输出数据。seq为序列数据,k为LSTM模型循环的长度,使用1~k的数据预测2~k+1的数据。
3、模型构建
Pytorch的nn模块提供了LSTM方法,具体接口使用说明可以参见Pytorch的接口使用说明书。此处调用nn.LSTM构建LSTM神经网络,模型另增加了线性变化的全连接层Linear(),但并未加入激活函数。由于是单个数值的预测,这里input_size和output_size都为1.
4、训练和测试
(1)模型定义、损失函数定义
(2)训练与测试
(3)结果展示
比较模型预测序列结果与真实值之间的差距
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:{ourl}


猜你喜欢
- 前言:线性回归模型属于经典的统计学模型,该模型的应用场景是根据已知的变量(即自变量)来预测某个连续的数值变量(即因变量)。例如餐厅根据媒体的
- 本章是前一章的延续,我们使用RSA算法逐步实现加密,并详细讨论它.用于解密密文的函数是as跟随 :def decrypt(ciph
- 本文实现12306抢火车票/京东抢手机示例,具体如下:#12306秒抢Python代码from splinter.browser impor
- Redis通常被认为是一种持久化的存储器关键字-值型存储,可以用于几台机子之间的数据共享平台。连接数据库 注意:假设现有几台在同一局域网内的
- 做数据分析、科学计算等离不开工具、语言的使用,目前最流行的数据语言,无非是MATLAB,R语言,Python这三种语言,但今天小编简单总结了
- 旋转椭圆实例代码:import matplotlib.pyplot as pltimport numpy as npfrom matplot
- 登录百度AL开发平台在控制台选择语音合成创建应用填写应用信息在应用列表获取(Appid、API Key、Secret Key)6. 安装py
- 事实上,互联网用户浏览网页的习惯和顾客浏览商店中物品的习惯没有多大差别。用户打开一个新的页面,扫视一些文字,并点击第一个引起他兴趣的链接。在
- 刚开始使用webpack时,可能很多人都会有过这样的想法,在require文件时,能不能不写静态的字符串路径,而是使用一个更灵活的方式,比如
- 前言这里先说明一下,网上很多人说阿里规定500w数据就要分库分表。实际上,这个500w并不是定义死的,而是与MySQL的配置以及机器的硬件有
- 全文检索引擎入门灰常不幸的是,关系型数据库对全文检索的支持没有被标准化。不同的数据库通过它们自己的方式来实现全文检索,而且SQL
- 1.zip用法简介在python 3.x系列中,zip方法返回的为一个zip object可迭代对象。class zip(object):&
- 一、待搜索图二、测试集三、new_similarity_compare.py# -*- encoding=utf-8 -*-from ima
- 引言:最近邻插值Nearest Neighbour Interpolate算法是图像处理中普遍使用的图像尺寸缩放算法,由于其实现简单计算速度
- tensorflow版本1.4获取变量维度是一个使用频繁的操作,在tensorflow中获取变量维度主要用到的操作有以下三种:Tensor.
- 分享炫酷的前端页面随机二维码验证,供大家参考,具体内容如下直接上代码<%@ page contentType="text/h
- 这礼拜碰到一些问题,然后意识到基础知识一段时间没巩固的话,还是有遗忘的部分,还是需要温习,这里做份笔记,记录一下前续先简单描述下碰到的题目,
- 每位SQL Server开发员都有自己的首选操作方法。我的方法叫做分子查询。这些是由原子查询组合起来的查询,通过它们我可以处理一个表格。将原
- 说明:关于类的这部分,我参考了《Learning Python》一书的讲解。创建类创建类的方法比较简单,如下:class Person:&n
- 一、实验目标1、使用 K-means 模型进行聚类,尝试使用不同的类别个数 K,并分析聚类结果。2、按照 8:2 的比例随机将数据划分为训练