使用Keras构造简单的CNN网络实例
作者:tina_ttl 发布时间:2023-08-23 04:38:21
1. 导入各种模块
基本形式为:
import 模块名
from 某个文件 import 某个模块
2. 导入数据(以两类分类问题为例,即numClass = 2)
训练集数据data
可以看到,data是一个四维的ndarray
训练集的标签
3. 将导入的数据转化我keras可以接受的数据格式
keras要求的label格式应该为binary class matrices,所以,需要对输入的label数据进行转化,利用keras提高的to_categorical函数
label = np_utils.to_categorical(label, numClass
此时的label变为了如下形式
(注:PyCharm无法显示那么多的数据,所以下面才只显示了1000个数据,实际上该例子所示的数据集有1223个数据)
4. 建立CNN模型
以下图所示的CNN网络为例
#生成一个model
model = Sequential()
#layer1-conv1
model.add(Convolution2D(16, 3, 3, border_mode='valid',input_shape=data.shape[-3:]))
model.add(Activation('tanh'))#tanh
# layer2-conv2
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
# layer3-conv3
model.add(Convolution2D(32, 3, 3, border_mode='valid'))
model.add(Activation('tanh'))#tanh
# layer4
model.add(Flatten())
model.add(Dense(128, init='normal'))
model.add(Activation('tanh'))#tanh
# layer5-fully connect
model.add(Dense(numClass, init='normal'))
model.add(Activation('softmax'))
#
sgd = SGD(l2=0.1,lr=0.001, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy', optimizer=sgd,class_mode="categorical")
5. 开始训练model
利用model.train_on_batch或者model.fit
补充知识:keras 多分类一些函数参数设置
用Lenet-5 识别Mnist数据集为例子:
采用下载好的Mnist数据压缩包转换成PNG图片数据集,加载图片采用keras图像预处理模块中的ImageDataGenerator。
首先import所需要的模块
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import MaxPooling2D,Input,Convolution2D
from keras.layers import Dropout, Flatten, Dense
from keras import backend as K
定义图像数据信息及训练参数
img_width, img_height = 28, 28
train_data_dir = 'dataMnist/train' #train data directory
validation_data_dir = 'dataMnist/validation'# validation data directory
nb_train_samples = 60000
nb_validation_samples = 10000
epochs = 50
batch_size = 32
判断使用的后台
if K.image_dim_ordering() == 'th':
input_shape = (3, img_width, img_height)
else:
input_shape = (img_width, img_height, 3)
网络模型定义
主要注意最后的输出层定义
比如Mnist数据集是要对0~9这10种手写字符进行分类,那么网络的输出层就应该输出一个10维的向量,10维向量的每一维代表该类别的预测概率,所以此处输出层的定义为:
x = Dense(10,activation='softmax')(x)
此处因为是多分类问题,Dense()的第一个参数代表输出层节点数,要输出10类则此项值为10,激活函数采用softmax,如果是二分类问题第一个参数可以是1,激活函数可选sigmoid
img_input=Input(shape=input_shape)
x=Convolution2D(32, 3, 3, activation='relu', border_mode='same')(img_input)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)
x=Convolution2D(32,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)
x=Convolution2D(64,3,3,activation='relu',border_mode='same')(x)
x=MaxPooling2D((2,2),strides=(2, 2),border_mode='same')(x)
x = Flatten(name='flatten')(x)
x = Dense(64, activation='relu')(x)
x= Dropout(0.5)(x)
x = Dense(10,activation='softmax')(x)
model=Model(img_input,x)
model.compile(loss='binary_crossentropy',
optimizer='rmsprop',
metrics=['accuracy'])
model.summary()
利用ImageDataGenerator传入图像数据集
注意用ImageDataGenerator的方法.flow_from_directory()加载图片数据流时,参数class_mode要设为‘categorical',如果是二分类问题该值可设为‘binary',另外要设置classes参数为10种类别数字所在文件夹的名字,以列表的形式传入。
train_datagen = ImageDataGenerator(
rescale=1. / 255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
# this is the augmentation configuration we will use for testing:
# only rescaling
test_datagen = ImageDataGenerator(rescale=1. / 255)
train_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical', #多分类问题设为'categorical'
classes=['0','1','2','3','4','5','6','7','8','9'] #十种数字图片所在文件夹的名字
)
validation_generator = test_datagen.flow_from_directory(
validation_data_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical'
)
训练和保存模型及权值
model.fit_generator(
train_generator,
samples_per_epoch=nb_train_samples,
nb_epoch=epochs,
validation_data=validation_generator,
nb_val_samples=nb_validation_samples
)
model.save_weights('Mnist123weight.h5')
model.save('Mnist123model.h5')
至此训练结束
图片预测
注意model.save()可以将模型以及权值一起保存,而model.save_weights()只保存了网络权值,此时如果要进行预测,必须定义有和训练出该权值所用的网络结构一模一样的一个网络。
此处利用keras.models中的load_model方法加载model.save()所保存的模型,以恢复网络结构和参数。
from keras.models import load_model
from keras.preprocessing.image import img_to_array, load_img
import numpy as np
classes=['0','1','2','3','4','5','6','7','8','9']
model=load_model('Mnist123model.h5')
while True:
img_addr=input('Please input your image address:')
if img_addr=="exit":
break
else:
img = load_img(img_addr, False, target_size=(28, 28))
x = img_to_array(img) / 255.0
x = np.expand_dims(x, axis=0)
result = model.predict(x)
ind=np.argmax(result,1)
print('this is a ', classes[ind])
来源:https://blog.csdn.net/tina_ttl/article/details/51034821


猜你喜欢
- forms组件django框架提供了一个Form类,来进行web开发中的表单提交数据的处理工作。导入相关模块from django impo
- 目录简介开发工具实现代码爬取效果Github地址:简介使用Python Tkinter开发一个爬取B站直播弹幕的工具,启动后在弹窗中输入房间
- 先说一下JS的获取方法,其要比JQUERY的方法麻烦很多,后面以JQUERY的方法作对比JS的方法会比JQUERY麻烦很多,主要则是因为FF
- 卸载MySQL1、在控制面板,卸载MySQL的所有组件控制面板——》所有控制面板项——》程序和功能,卸载所有和MySQL有关的程序2、找到你
- 严格来说,Having并不需要一个子表,但没有子表的Having并没有实际意义。如果你只需要一个表,那么你可以用Where子句达到一切目的。
- 关于数据库性能的故事面试时多多少少会讲到数据库上的事情,“你对数据库的掌握如何?”,什么时候最考验数据库的性能,答应主要方面上讲就是大数据量
- 表达式的优先级表达式(Expression)是运算符(operator)和操作数(operand)所构成的序列代码段a = 1b = 2c
- 当我们修改一份代码的时候,也许会碰到修改后的代码还不如修改之前的代码能够满足自己的需求,那么这个时候我们就需要对代码进行回滚,下面我们来看一
- 在Python中可以存储很大的值,如下面的Python示例程序:x = 1000000000000000000000000000000000
- lambda 函数Python 支持一种有趣的语法,它允许你快速定义单行的最小函数。这些叫做 lambda 的函数,是从 Lisp 借用来的
- 说明1、Task是Future的子类,Task是对协程的封装,我们把多个Task放在循环调度列表中,等待调度执行。2、Task对象可以跟踪任
- TextRank 是一种基于 PageRank 的算法,常用于关键词提取和文本摘要。在本文中,我将通过一个关键字提取示例帮助您了解 Text
- QThread是Qt的线程类中最核心的底层类。由于PyQt的的跨平台特性,QThread要隐藏所有与平台相关的代码要使用的QThread开始
- 我们利用linux系统中yum安装Apache+MySQL+PHP是非常的简单哦,只需要几步就可以完成,具体如下:一、脚本YUM源安装:1.
- 前言第一次用mysql,打开mysql的图形化界面要连接时,出现2003错误。究其原因,可能是mysql的服务没有启动。本文章主要围绕这个解
- Golang有很多第三方包,其中的 viper 支持读取多种配置文件信息。本文只是做一个小小demo,用来学习入门用的。1、安装go get
- 目录1、功能介绍2、关键代码2.1 主页功能2.2 添加商品信息2.3 数据库设计商品表前言:  
- 本文实例讲述了PHP模拟asp中response类的方法。分享给大家供大家参考。具体如下:习惯了asp或是asp.net开发的人, 他们会经
- 代码如下:SELECT * FROM Orders WHERE OrderGUID IN('BC71D821-9E25-
- 写SQL语句的时候很多时候会用到filter筛选掉一些记录,SQL对筛选条件简称:SARG(search argument/SARG) wh