python神经网络Keras实现LSTM及其参数量详解
作者:Bubbliiiing 发布时间:2023-02-09 14:02:22
什么是LSTM
1、LSTM的结构
我们可以看出,在n时刻,LSTM的输入有三个:
当前时刻网络的输入值Xt;
上一时刻LSTM的输出值ht-1;
上一时刻的单元状态Ct-1。
LSTM的输出有两个:
当前时刻LSTM输出值ht;
当前时刻的单元状态Ct。
2、LSTM独特的门结构
LSTM用两个门来控制单元状态cn的内容:
遗忘门(forget gate),它决定了上一时刻的单元状态cn-1有多少保留到当前时刻;
输入门(input gate),它决定了当前时刻网络的输入c’n有多少保存到新的单元状态cn中。
LSTM用一个门来控制当前输出值hn的内容:
输出门(output gate),它利用当前时刻单元状态cn对hn的输出进行控制。
3、LSTM参数量计算
a、遗忘门
遗忘门这里需要结合ht-1和Xt来决定上一时刻的单元状态cn-1有多少保留到当前时刻;
由图我们可以得到,我们在这一环节需要计一个参数ft。
b、输入门
输入门这里需要结合ht-1和Xt来决定当前时刻网络的输入c’n有多少保存到单元状态cn中。
由图我们可以得到,我们在这一环节需要计算两个参数,分别是it。
和C’t
里面需要训练的参数分别是Wi、bi、WC和bC。
在定义LSTM的时候我们会使用到一个参数叫做units,其实就是神经元的个数,也就是LSTM的输出——ht的维度。
所以:
c、输出门
输出门利用当前时刻单元状态cn对hn的输出进行控制;
由图我们可以得到,我们在这一环节需要计一个参数ot。
里面需要训练的参数分别是Wo和bo。在定义LSTM的时候我们会使用到一个参数叫做units,其实就是神经元的个数,也就是LSTM的输出——ht的维度。所以:
d、全部参数量
所以所有的门总参数量为:
在Keras中实现LSTM
LSTM一般需要输入两个参数。
一个是unit、一个是input_shape。
LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))
unit用于指定神经元的数量。
input_shape用于指定输入的shape,分别指定TIME_STEPS和INPUT_SIZE。
实现代码
import numpy as np
from keras.models import Sequential
from keras.layers import Input,Activation,Dense
from keras.models import Model
from keras.datasets import mnist
from keras.layers.recurrent import LSTM
from keras.utils import np_utils
from keras.optimizers import Adam
TIME_STEPS = 28
INPUT_SIZE = 28
BATCH_SIZE = 50
index_start = 0
OUTPUT_SIZE = 10
CELL_SIZE = 75
LR = 1e-3
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28)/255
X_test = X_test.reshape(-1,28,28)/255
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
inputs = Input(shape=[TIME_STEPS,INPUT_SIZE])
x = LSTM(CELL_SIZE, input_shape = (TIME_STEPS,INPUT_SIZE))(inputs)
x = Dense(OUTPUT_SIZE)(x)
x = Activation("softmax")(x)
model = Model(inputs,x)
adam = Adam(LR)
model.summary()
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
for i in range(50000):
X_batch = X_train[index_start:index_start + BATCH_SIZE,:,:]
Y_batch = Y_train[index_start:index_start + BATCH_SIZE,:]
index_start += BATCH_SIZE
cost = model.train_on_batch(X_batch,Y_batch)
if index_start >= X_train.shape[0]:
index_start = 0
if i%100 == 0:
cost,accuracy = model.evaluate(X_test,Y_test,batch_size=50)
print("accuracy:",accuracy)
实现效果:
10000/10000 [==============================] - 3s 340us/step
accuracy: 0.14040000014007092
10000/10000 [==============================] - 3s 310us/step
accuracy: 0.6507000041007995
10000/10000 [==============================] - 3s 320us/step
accuracy: 0.7740999992191792
10000/10000 [==============================] - 3s 305us/step
accuracy: 0.8516999959945679
10000/10000 [==============================] - 3s 322us/step
accuracy: 0.8669999945163727
10000/10000 [==============================] - 3s 324us/step
accuracy: 0.889699995815754
10000/10000 [==============================] - 3s 307us/step
来源:https://blog.csdn.net/weixin_44791964/article/details/103896884
猜你喜欢
- Python最基本的数据结构是序列(列表/元组)。一个序列中的每个元素都分配有一个数字- 它的位置或索引。第一个索引是0,第二个
- 一、bs4解析import requestsfrom bs4 import BeautifulSoupimport datetimeif _
- 要求存在一个文件夹内有若干张图像,需要计算每张图片的RGB均值,并计算全部图像的RGB均值。代码# -*- coding: utf-8 -*
- 一个简单的验证码爬取程序本文介绍了在Python2.7环境下爬取网站验证码:思路就是获取验证码对应的url,然后发起requst请求,读取该
- 约定:import pandas as pdimport numpy as npfrom numpy import nan as NaN滤除
- 对于有的vps,系统默认安装了mysql。我们需要从我们的服务器、vps上卸载(移除)默认的mysql。那么如何(怎样)在ubuntu\De
- 一、类型1.变量没有类型,数据有类型例:num = 1 ---->num是没有类型的,1是int类型二、格式化输出2.na
- 前言大家应该都知道在编程语言中,定时任务是常用的一种调度形式,在Python中也涌现了非常多的调度模块,本文将简要介绍APScheduler
- Django默认Path转换器str:匹配任何非空字符串,但不含斜杠/,如果你没有专门指定转换器,那么这个是默认使用的;int:匹配0和正整
- 1、requests 的常见用法requests 除了 url 之外,还有 params, data 和 files 三个参数,用于和服务器
- 举例如下,一个服务器端的form 代码自动被解释成客户端代码:服务器端代码: &l
- 目录1. 前言2. 介绍及安装3. 实战一下3-1 创建爬虫项目3-2 创建爬虫 Ai
- 安装 iupdatable 包pip install iupdatableTimer类主要函数:获取 Unix 时间戳(精确到秒):time
- 最近利用tkinter+python+pyinstaller实现了小工具的项目,在此记录下pyinstaller相关参数以及爬过的坑。一、p
- 什么是条形图?条形图(bar chart)是用宽度相同的条形的高度或长短来表示数据多少的图形。条形图可以横置或纵置,纵置时也称为柱形图(co
- 一、subprocess以及常用的封装函数运行python的时候,我们都是在创建并运行一个进程。像Linux进程那样,一个进程可以fork一
- 本文实例介绍了php打包网站并在线压缩为zip的方法,分享给大家供大家参考,具体内容如下<?php//在URL后参加 ?pwd=密码
- 上班族经常会遇到这样情况,着急下班结果将关机误点成重启,或者临近下班又通知开会,开完会已经迟了还要去给电脑关机。今天使用PyQt5做了个自动
- 本文实例讲述了wxPython窗口的继承机制,分享给大家供大家参考。具体分析如下:示例代码如下:import wx class
- 一、mysqlcheck简介mysqlcheck客户端可以检查和修复MyISAM表。它还可以优化和分析表。mysqlcheck的功能类似my