python神经网络Keras构建CNN网络训练
作者:Bubbliiiing 发布时间:2022-08-16 08:42:37
利用Keras构建完普通BP神经网络后,还要会构建CNN
Keras中构建CNN的重要函数
1、Conv2D
Conv2D用于在CNN中构建卷积层,在使用它之前需要在库函数处import它。
from keras.layers import Conv2D
在实际使用时,需要用到几个参数。
Conv2D(
nb_filter = 32,
nb_row = 5,
nb_col = 5,
border_mode = 'same',
input_shape = (28,28,1)
)
其中,nb_filter代表卷积层的输出有多少个channel,卷积之后图像会越来越厚,这就是卷积后图像的厚度。nb_row和nb_col的组合就是卷积器的大小,这里卷积器是(5,5)的大小。border_mode代表着padding的方式,same表示卷积前后图像的shape不变。input_shape代表输入的shape。
2、MaxPooling2D
MaxPooling2D指的是池化层,在使用它之前需要在库函数处import它。
from keras.layers import MaxPooling2D
在实际使用时,需要用到几个参数。
MaxPooling2D(
pool_size = (2,2),
strides = (2,2),
border_mode = 'same'
)
其中,pool_size表示池化器的大小,在这里,池化器的shape是(2,2)。strides是池化器的步长,这里在X和Y方向上都是2,池化后,输出比输入的shape小了1/2。border_mode代表着padding的方式。
3、Flatten
Flatten用于将卷积池化后最后的输出变为一维向量,这样才可以和全连接层连接,用于计算。在使用前需要用import导入。
from keras.layers import Flatten
在实际使用时,在最后一个池化层后直接添加层即可
model.add(Flatten())
全部代码
这是一个卷积神经网络的例子,用于识别手写体,其神经网络结构如下:
卷积层1->池化层1->卷积层2->池化层2->flatten->全连接层1->全连接层2->全连接层3。
单个样本的shape如下:
(28,28,1)->(28,28,32)->(14,14,32)->(14,14,64)->(7,7,64)->(3136)->(1024)->(256)
import numpy as np
from keras.models import Sequential
from keras.layers import Dense,Activation,Conv2D,MaxPooling2D,Flatten ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
X_train = X_train.reshape(-1,28,28,1)
X_test = X_test.reshape(-1,28,28,1)
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
model = Sequential()
# conv1
model.add(
Conv2D(
nb_filter = 32,
nb_row = 5,
nb_col = 5,
border_mode = 'same',
input_shape = (28,28,1)
)
)
model.add(Activation("relu"))
# pool1
model.add(
MaxPooling2D(
pool_size = (2,2),
strides = (2,2),
border_mode = 'same'
)
)
# conv2
model.add(
Conv2D(
nb_filter = 64,
nb_row = 5,
nb_col = 5,
border_mode = 'same'
)
)
model.add(Activation("relu"))
# pool2
model.add(
MaxPooling2D(
pool_size = (2,2),
strides = (2,2),
border_mode = 'same'
)
)
# 全连接层
model.add(Flatten())
model.add(Dense(1024))
model.add(Activation("relu"))
model.add(Dense(256))
model.add(Activation("relu"))
model.add(Dense(10))
model.add(Activation("softmax"))
adam = Adam(lr = 1e-4)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = adam,metrics = ['accuracy'])
## tarin
print("\ntraining")
cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 32)
print("\nTest")
## acc
cost,accuracy = model.evaluate(X_test,Y_test)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
实验结果为:
Epoch 1/2
60000/60000 [==============================] - 64s 1ms/step - loss: 0.7664 - acc: 0.9224
Epoch 2/2
60000/60000 [==============================] - 62s 1ms/step - loss: 0.0473 - acc: 0.9858
Test
10000/10000 [==============================] - 2s 169us/step
accuracy: 0.9856
来源:https://blog.csdn.net/weixin_44791964/article/details/101171576


猜你喜欢
- 引言经过函数学习之后我们会发现函数不被调用是不会直接执行的,我们在之前的函数调用之后发现运行的结果都是函数体内print()打印出来的结果,
- 下面通过文字说明和代码分析的方式给大家分享移动端图片上传之localResizeIMG先压缩后ajax无刷新上传,具体实现过程请看下文。现在
- 业务场景:在后台管理系统表格模块中,我们请求回来的数据类似性别等等,后台给我们返的不是男,或者女,而是给我们返回的是0和1,或者是A和B;但
- 前言多线程一般用于同时调用多个函数,cpu时间片轮流分配给多个任务。 优点是提高cpu的使用率,使计算机减少处理多个任务的总时间;缺点是如果
- 想做个和IBM公司一样的网站LOGO,试了半天也没有做出来,郁闷之下,只好求高手帮助!先在这里谢谢了!方法一1、写上IBM,调节字号颜色2、
- 效果图:作用:将页面中的电话号码生成图片格式。<%Public Sub Com_CreatValidCode(pT
- 本文为大家分享了python利用高阶函数实现剪枝函数的具体代码,供大家参考,具体内容如下案例:  
- 工作中我们经常需要判断某个变量/属性是否为undefined。通常有两种写法// 方式1 typeof age === 'undef
- 1. 什么是阻塞队列?阻塞队列(BlockingQueue)是一个支持两个附加操作的队列。这两个附加的操作是:在队列为空时,获取元素的线程会
- 1. 前言数组和矩阵是数值计算的基础元素。目前为止,我们都是使用NumPy的ndarray数据结构来表示数组,这是一种同构的容器,用于存储数
- 自从我用 Python 编写第一行代码以来,我就被它的简单性、出色的可读性和流行的单行代码所吸引。在下文中,我想介绍和解释其中的一些单行代码
- jinja2简介特征沙箱中执行强大的 HTML 自动转义系统保护系统免受 XSS模板继承及时编译最优的 python 代码可选提前编译模板的
- Cookie 对象是一种以文件(Cookie文件)的形式保存在客户端硬盘的Cookies文件夹中的数据信息(Cookie数据)。Cookie
- 点云生成 3D 网格的最快方法已经用 Python 编写了几个实现来从点云中获取网格。它们中的大多数
- 一、需求来源工作中需要一种树形菜单组件,经过两天的构思最终通过作用域插槽实现: 此组件将每个节点(插槽名为 node)暴露出来。通过插槽的
- 如下所示:pd.to_datetime(data[data['last_O_XLMC']==data['O_XLMC
- 本文实例讲述了MySQL触发器简单用法。分享给大家供大家参考,具体如下:mysql触发器和存储过程一样,是嵌入到mysql的一段程序,触发器
- 今天帮朋友做个python的小工具,发现系统上缺少ptyhon的支持库,返回如下信息ImportError: No module named
- 今天,我们来分享一个宠物桌面小程序,全程都是通过 PyQT 来制作的,对于 Python GUI 感兴趣的朋友,千万不要错过哦!我们先来看看
- 一、判断数据库表是否存在: 首先要拿到数据库连接conn,调用DatabaseMetaData dbmd = conn.getDataMet