Python+SimpleRNN实现股票预测详解
作者:别团等shy哥发育 发布时间:2022-12-04 13:08:26
原理请查看前面几篇文章。
1、数据源
SH600519.csv 是用 tushare 模块下载的 SH600519 贵州茅台的日 k 线数据,本次例子中只用它的 C 列数据(如图 所示):
用连续 60 天的开盘价,预测第 61 天的开盘价。
2、代码实现
按照六步法: import 相关模块->读取贵州茅台日 k 线数据到变量 maotai,把变量 maotai 中前 2126 天数据中的开盘价作为训练数据,把变量 maotai 中后 300 天数据中的开盘价作为测试数据;然后对开盘价进行归一化,使送入神经网络的数据分布在 0 到 1 之间;
接下来建立空列表分别用于接收训练集输入特征、训练集标签、测试集输入特征、测试集标签;
继续构造数据。用 for 循环遍历整个训练数据,每连续60 天数据作为输入特征 x_train,第 61 天数据作为对应的标签 y_train ,一共生成 2066 组训练数据,然后打乱训练数据的顺序并转变为 array 格式继而转变为 RNN 输入要求的维度;
同理,利用 for 循环遍历整个测试数据,一共生成 240组测试数据,测试集不需要打乱顺序,但需转变为 array 格式继而转变为 RNN 输入要求的维度。
用 sequntial 搭建神经网络:
第一层循环计算层记忆体设定 80 个,每个时间步推送 h t h_t ht给下一层,使用 0.2 的 Dropout;
第二层循环计算层设定记忆体有 100 个,仅最后的时间步推送 h t h_t ht给下一层,使用 0.2 的 Dropout;
由于输出值是第 61 天的开盘价只有一个数,所以全连接 Dense 是 1->compile 配置训练方法使用 adam 优化器,使用均方误差损失函数。在股票预测代码中,只需观测 loss,训练迭代打印的时候也只打印 loss,所以这里就无需给metrics赋值->设置断点续训,fit 执行训练过程->summary 打印出网络结构和参数统计。
进行 loss 可视化与参数报错操作
进行股票预测。用 predict 预测测试集数据,然后将预测值和真实值从归一化的数值变换到真实数值,最后用红色线画出真实值曲线 、用蓝色线画出预测值曲线。
为了评价模型优劣,给出了三个评判指标:均方误差、均方根误差和平均绝对误差,这些误差越小说明预测的数值与真实值越接近。
RNN 股票预测 loss 曲线:
RNN 股票预测曲线:
RNN 股票预测评价指标:
模型摘要:
3、完整代码
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dropout, Dense, SimpleRNN
import matplotlib.pyplot as plt
import os
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error
import math
# 读取股票文件
maotai = pd.read_csv('./SH600519.csv')
# 前(2426-300=2126)天的开盘价作为训练集,表格从0开始计数,2:3 是提取[2:3)列,前闭后开,故提取出C列开盘价
training_set = maotai.iloc[0:2426 - 300, 2:3].values
# 后300天的开盘价作为测试集
test_set = maotai.iloc[2426 - 300:, 2:3].values
# 归一化
sc = MinMaxScaler(feature_range=(0, 1)) # 定义归一化:归一化到(0,1)之间
training_set_scaled = sc.fit_transform(training_set) # 求得训练集的最大值,最小值这些训练集固有的属性,并在训练集上进行归一化
test_set = sc.transform(test_set) # 利用训练集的属性对测试集进行归一化
x_train = []
y_train = []
x_test = []
y_test = []
# 测试集:csv表格中前2426-300=2126天数据
# 利用for循环,遍历整个训练集,提取训练集中连续60天的开盘价作为输入特征x_train,第61天的数据作为标签,for循环共构建2426-300-60=2066组数据。
for i in range(60, len(training_set_scaled)):
x_train.append(training_set_scaled[i - 60:i, 0])
y_train.append(training_set_scaled[i, 0])
# 对训练集进行打乱
np.random.seed(7)
np.random.shuffle(x_train)
np.random.seed(7)
np.random.shuffle(y_train)
tf.random.set_seed(7)
# 将训练集由list格式变为array格式
x_train, y_train = np.array(x_train), np.array(y_train)
# 使x_train符合RNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]。
# 此处整个数据集送入,送入样本数为x_train.shape[0]即2066组数据;输入60个开盘价,预测出第61天的开盘价,循环核时间展开步数为60; 每个时间步送入的特征是某一天的开盘价,只有1个数据,故每个时间步输入特征个数为1
x_train = np.reshape(x_train, (x_train.shape[0], 60, 1))
# 测试集:csv表格中后300天数据
# 利用for循环,遍历整个测试集,提取测试集中连续60天的开盘价作为输入特征x_test,第61天的数据作为标签y_test,for循环共构建300-60=240组数据。
for i in range(60, len(test_set)):
x_test.append(test_set[i - 60:i, 0])
y_test.append(test_set[i, 0])
# 测试集变array并reshape为符合RNN输入要求:[送入样本数, 循环核时间展开步数, 每个时间步输入特征个数]
x_test, y_test = np.array(x_test), np.array(y_test)
x_test = np.reshape(x_test, (x_test.shape[0], 60, 1))
model = tf.keras.Sequential([
SimpleRNN(80, return_sequences=True),# 第一层循环计算层:记忆体设定80个,每个时间步推送ht给下一层
Dropout(0.2), #使用0.2的Dropout
SimpleRNN(100),# 第二层循环计算层,设定记忆体100个
Dropout(0.2), #
Dense(1) # 由于输出值是第61天的开盘价,只有一个数,所以Dense是1
])
model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
loss='mean_squared_error') # 损失函数用均方误差
# 该应用只观测loss数值,不观测准确率,所以删去metrics选项,一会在每个epoch迭代显示时只显示loss值
checkpoint_save_path = "./checkpoint/rnn_stock.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
print('-------------load the model-----------------')
model.load_weights(checkpoint_save_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True,
monitor='val_loss')
history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_data=(x_test, y_test), validation_freq=1,
callbacks=[cp_callback])
model.summary()
file = open('./weights.txt', 'w') # 参数提取
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.numpy()) + '\n')
file.close()
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
################## predict ######################
# 测试集输入模型进行预测
predicted_stock_price = model.predict(x_test)
# 对预测数据还原---从(0,1)反归一化到原始范围
predicted_stock_price = sc.inverse_transform(predicted_stock_price)
# 对真实数据还原---从(0,1)反归一化到原始范围
real_stock_price = sc.inverse_transform(test_set[60:])
# 画出真实数据和预测数据的对比曲线
plt.plot(real_stock_price, color='red', label='MaoTai Stock Price')
plt.plot(predicted_stock_price, color='blue', label='Predicted MaoTai Stock Price')
plt.title('MaoTai Stock Price Prediction')
plt.xlabel('Time')
plt.ylabel('MaoTai Stock Price')
plt.legend()
plt.show()
##########evaluate##############
# calculate MSE 均方误差 ---> E[(预测值-真实值)^2] (预测值减真实值求平方后求均值)
mse = mean_squared_error(predicted_stock_price, real_stock_price)
# calculate RMSE 均方根误差--->sqrt[MSE] (对均方误差开方)
rmse = math.sqrt(mean_squared_error(predicted_stock_price, real_stock_price))
# calculate MAE 平均绝对误差----->E[|预测值-真实值|](预测值减真实值求绝对值后求均值)
mae = mean_absolute_error(predicted_stock_price, real_stock_price)
print('均方误差: %.6f' % mse)
print('均方根误差: %.6f' % rmse)
print('平均绝对误差: %.6f' % mae)
来源:https://blog.csdn.net/qq_43753724/article/details/124874848
猜你喜欢
- 本文实例讲述了Python二叉搜索树与双向链表实现方法。分享给大家供大家参考,具体如下:# encoding=utf8''&
- editTable.js 提供编辑表格当前行、添加一行、删除当前行的操作,其中可以设置参数,如:operatePos 用于设置放置操作的列,
- 当一张的数据达到几百万时,你查询一次所花的时间会变多,如果有联合查询的话,我想有可能会死在那儿了。分表的目的就在于此,减小数据库的负担,缩短
- 在这里我想有必要再较系统说一下ADO的各种对象的方法、属性。毕竟ADO不仅应用在ASP中,VB,VC都可以用到。在这十天中我想主要提到的对象
- 栅格系统的形成1692年,新登基的法国国王路易十四感到法国的印刷水平强差人意,因此命令成立一个管理印刷的皇家特别委员会。他们的首要任务是设计
- Wingdings字体,Symbol字体<html> <head> <title>
- 项目信号处理和提取部分用到了matlab,需要应用到工程中方便研究。用具有万能粘合剂之称的“Python”。具体方法如下:1.python中
- 机器视觉从Google的无人驾驶汽车到可以识别假钞的自动售卖机,机器视觉一直都是一个应用广泛且具有深远的影响和雄伟的愿景的领域。这里我们将重
- 1. 返回列表和标量(Scalar)前面我们注意到Query对象可以返回可迭代的值(iterator value),然后我们可以通过for
- 如下所示:# coding = GBKa =[1,2,3,4,5]sum=0b = len(a)print("这个数组的长度为:&
- 使用python批量修改文本文件编码格式把文本文件的编码格式进行批量幻化,比如ascii, gb2312, utf8等,相互转化,字符集的大
- 在我们使用查询语句的时候,经常要返回前几条或者中间某几行数据,这个时候怎么办呢?不用担心, mysql已经为我们提供了这样一个功
- 前言:字体反爬是什么个意思?就是网站把自己的重要数据不直接的在源代码中呈现出来,而是通过相应字体的编码,与一个字体文件(一般后缀为ttf或w
- 本文实例讲述了kNN算法python实现和简单数字识别的方法。分享给大家供大家参考。具体如下:kNN算法算法优缺点:优点:精度高、对异常值不
- YUI3.2.0 的 transition 模块,通过使用 transition:end 事件实现在 transition 完成后执行其他操
- 要将xian80地理坐标系转换成投影坐标系:xian1980 = """GEOGCS["GCS_Xi
- JavaScript是运行在客户端的脚本,因此一般是不能够设置Session的,因为Session是运行在服务器端的。而cookie是运行在
- asp防止用户同时登陆的方法,实现这个功能可有两种方式:1.使用application用application对象:如果做的是大型社区,可能
- 在进行接口自动化测试时,有好多接口都基于登陆接口的响应值来关联进行操作的,在次之前试了很多方法,都没有成功,其实很简单用session来做。
- Rs.GetRows(N):N代表获取记录数量 Rs.GetRows(1):1表示只返回一行记录 Rs.GetRows(-1):-1表示默认