python神经网络使用Keras构建RNN训练
作者:Bubbliiiing 发布时间:2021-07-19 21:12:15
标签:python,神经网络,Keras,RNN,训练
Keras中构建RNN的重要函数
1、SimpleRNN
SimpleRNN用于在Keras中构建普通的简单RNN层,在使用前需要import。
from keras.layers import SimpleRNN
在实际使用时,需要用到几个参数。
model.add(
SimpleRNN(
batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
output_dim = CELL_SIZE,
)
)
其中,batch_input_shape代表RNN输入数据的shape,shape的内容分别是每一次训练使用的BATCH,TIME_STEPS表示这个RNN按顺序输入的时间点的数量,INPUT_SIZE表示每一个时间点的输入数据大小。
CELL_SIZE代表训练每一个时间点的神经元数量。
2、model.train_on_batch
与之前的训练CNN网络和普通分类网络不同,RNN网络在建立时就规定了batch_input_shape,所以训练的时候也需要一定量一定量的传入训练数据。
model.train_on_batch在使用前需要对数据进行处理。获取指定BATCH大小的训练集。
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
具体训练过程如下:
for i in range(500):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
## acc
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
x = X_test[1].reshape(1,28,28)
全部代码
这是一个RNN神经网络的例子,用于识别手写体。
import numpy as np
from keras.models import Sequential
from keras.layers import SimpleRNN,Activation,Dense ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
model = Sequential()
# conv1
model.add(
SimpleRNN(
batch_input_shape = (BATCH_SIZE,TIME_STEPS,INPUT_SIZE),
output_dim = CELL_SIZE,
)
)
model.add(Dense(OUTPUT_SIZE))
model.add(Activation("softmax"))
adam = Adam(LR)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
## tarin
for i in range(500):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
## acc
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
实验结果为:
10000/10000 [==============================] - 1s 147us/step
accuracy: 0.09329999938607215
…………………………
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9395000022649765
10000/10000 [==============================] - 1s 109us/step
accuracy: 0.9422999995946885
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9534000000357628
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9566000008583069
10000/10000 [==============================] - 1s 113us/step
accuracy: 0.950799999833107
10000/10000 [==============================] - 1s 116us/step
10000/10000 [==============================] - 1s 112us/step
accuracy: 0.9474999988079071
10000/10000 [==============================] - 1s 111us/step
accuracy: 0.9515000003576278
10000/10000 [==============================] - 1s 114us/step
accuracy: 0.9288999977707862
10000/10000 [==============================] - 1s 115us/step
accuracy: 0.9487999993562698
来源:https://blog.csdn.net/weixin_44791964/article/details/101609556
0
投稿
猜你喜欢
- 一、new做了哪些事先看看new的使用场景:// 1、创建一个构造函数function Vehicle(name, price) { &nb
- 你在使用pandas处理DataFrame中是否遇到过如下这类问题?我们需要删除某一列所有元素中含有固定字符元素所在的行,比如下面的例子:&
- 1.configparser介绍configparser是python自带的配置参数解析器。可以用于解析.config文件中的配置参数。in
- 浏览器的开发者在很早的时候就已经意识到, HTTP's 的无状态会对Web开发者带来很大的问题,于是(cookies)应运而生。 c
- Python生成随机验证码,需要使用PIL模块,具体内容如下安装:pip3 install pillow基本使用1. 创建图片from PI
- 前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应
- 今天来说一下如何判断字典中是否存在某个key,一般有两种通用做法,下面为大家来分别讲解一下:第一种方法:使用自带函数实现。在python的字
- 本文实例讲述了Python基于xlrd模块操作Excel的方法。分享给大家供大家参考,具体如下:一、使用xlrd读取excel1、xlrd的
- 创建测试表-- ------------------------------ Table structure for check_test-
- 大家好,欢迎大家来到算法数据结构专题,今天我们和大家聊一个非常常用的算法,叫做LRU。LRU的英文全称是Least Recently Use
- 我们在设计网站的时候,有的时候需要根据页面元素的属性来制作不同的样式,比如,对于不同的链接类型,显示不同的链接图标。CSS的选择器是个很有用
- 概述在使用keras中的keras.backend.batch_dot和tf.matmul实现功能其实是一样的智能矩阵乘法,比如A,B,C,
- 前言:.net6LTS版本发布已经有若干天了。此处做一个关于使用.net6开发精简版webapi(minimalapi)的入门教程,以及VS
- 背景借助django-admin,可以快速得到CRUD界面,但若需要创建多选标签字段时,需要对表单进行调整示例model.py一个tag(标
- 一、数据集下载加州高速公路PEMS数据集这里绘制PEMS04中的交通流量数据。该数据集中包含旧金山2018年1月1日至2月28日的29条道路
- 不知道大家有没有一种感觉,每次当使用numpy数组的时候坐标轴总是傻傻分不清楚,然后就会十分的困惑,每次运算都需要去尝试好久才能得出想要的结
- 说明字符串驻留是一种仅保存一份相同且不可变字符串的方法。不同的值被存放在字符串驻留池中,发生驻留之后, 许多变量可能指向内存中的相同字符串对
- 本文介绍基于Python语言gdal模块,实现多波段HDF栅格图像文件的读取、处理与像元值可视化(直方图绘制)等操作。另外,基于gdal等模
- 近期做个小项目需要用到python读取图片,自己整理了一下两种读取图片的方式,其中一种用到了TensorFlow,(TensorFlow是基
- 本文实例讲述了python使用opencv实现马赛克效果。分享给大家供大家参考,具体如下:最近要实现opencv视频打马赛克,在网上找了一下