python神经网络使用Keras构建RNN训练
作者:Bubbliiiing 发布时间:2021-07-19 21:12:15
标签:python,神经网络,Keras,RNN,训练
Keras中构建RNN的重要函数
1、SimpleRNN
SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。
from keras.layers import SimpleRNN
在实际使用时,需要用到几个参数。
model.add(
SimpleRNN(
batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
output_dim = CELL_SIZE,
)
)
其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。
2、model.train_on_batch
与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
具体训练过程如下:
for i in range(500):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
## acc
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
x = X_test[1].reshape(1,28,28)
全部代码
这是一个RNN神经网络的例子,用于识别手写体。
import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
model = Sequential()
# conv1
model.add(
SimpleRNN(
batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
output_dim = CELL_SIZE,
)
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
## tarin
for i in range(500):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
## acc
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
实验结果为:
10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698
来源:https://blog.csdn.net/weixin_44791964/article/details/101609556
0
投稿
猜你喜欢
- 代码如下: EXEC sp_rename '表名.[原列名]', '新列名', 'column
- 本文给出了几个表单常用的js验证函数,有检查、\等特殊字符的,有检查是否含有空格,检查是否为Email 地址,也有检查是否是小数或负数的,检
- 本文介绍了数据库索引,及其优、缺点。针对MySQL索引的特点、应用进行了详细的描述。分析了如何避免MySQL无法使用,如何使用EXPLAIN
- 在运用xmlhttp组件编写程序中,会碰到 "msxml3.dll 错误 ‘800c0005’&nb
- Turtle库是Python语言中一个很流行的绘制图像的函数库,想象一个小乌龟,在一个横轴为x、纵轴为y的坐标系原点,(0,0)
- Python 中的 timeit 模块可以用来测试一段代码的执行耗时,如一个变量赋值语句的执行时间,一个函数的运行时间等。timeit 模块
- 文档介绍利用python写“猜数字”,“猜词语”,“谁是卧底”这三个游戏,从而快速掌握python编程的入门知识,包括python语法/列表
- 看了下网上有很多关于模拟登录淘宝,但是基本都是使用scrapy、pyppeteer、selenium等库来模拟登录,但是目前我们还没有讲到这
- 1. composer 安装 PDF组件composer require setasign/fpdicomposer require set
- 异步操作数据的方式有两种常见的方式:XMLHttpRequest 和 iframe. 孰优孰劣在此我们不争论,只是想举一个例子说明在获取网片
- 1.首先主题选择不要落俗!现在许多的个人主页就象“大锅饭”。题材包罗万象,内容雷同无味。人人都是“软件速递”“音乐宝库”“主页教程”等等。让
- 我想让一片文章,每到3000字就分到下一条插入到数据库,求高手 <%Dim Content Conte
- 本文实例讲述了php指定长度分割字符串str_split函数用法。分享给大家供大家参考,具体如下:示例1:$str = 'abcde
- python实现阶乘-基础版本什么是阶乘呢?在数学运算中n!表示n的阶乘,用数学公式表示为:n!=1*2*3*....*(n-1)*n下面提
- 作者:AngelGavin 出处:CSDN一般问题什么是 XML?可扩展标记语言 (XML) 是 Web 上的数据通用语言。它使
- 定义列表和其他类型的列表稍有不同,它由两部分组成:名称和定义。DT 指定名称,为内联元素。DD 指定定义,为块级元素。标准属性id, cla
- 为index.php文件设置只读属性后,木马就没权限给你文件末尾追加广告了。下面我们看具体的代码,设置index.php只读:<?ph
- 如何在php中判断一个网页请求是ajax请求还是普通请求?你可以通过传递参数的方法来实现,例如使用如下网址请求:/path/to/pkphp
- 有时候,规划师(或需求、交互)把内容呈现的框架草图搭建好后,就直接“丢”给了设计师,让设计师在画好的框架里去美化内容,出来后的效果,往往达不
- replace方法的语法是:stringObj.replace(rgExp, replaceText) 其中stringObj是字符串(st