DenseNet121模型实现26个英文字母识别任务
作者:实力 发布时间:2023-08-22 13:15:22
一、任务概述
26个英文字母识别是一个基于计算机视觉的图像分类任务,旨在从包含26个不同字母图像的数据集中训练一个深度学习模型,以对输入的字母图像进行准确的分类预测。本文将使用DenseNet121模型实现该任务。
二、DenseNet介绍
DenseNet是一种用于图像分类的深度学习架构,它的核心思想是通过连接前一层所有特征图到当前层来增强信息流,从而使得网络更深,更准确。相比于传统的卷积神经网络架构(如AlexNet和VGG),DenseNet具有更少的参数,更好的模型泛化能力和更高的效率。
DenseNet的网络结构类似于ResNet,由多个密集块(Dense Block)组成,其中每个密集块都是由多个卷积层和批量归一化层组成。与ResNet不同的是,DenseNet中每一层的输入都包含前面所有层的输出,这种密集连接方式可以避免信息瓶颈和梯度消失问题,促进了信息的传递和利用。同时,DenseNet还引入了过渡层(Transition Layer)来调整特征图的大小,减少计算量和内存占用。DenseNet最终通过全局平均池化层和softmax输出层生成预测结果。
三、数据集介绍
在本任务中,我们使用EMNIST数据集中的26个大写字母图像来训练和测试模型,它们是由28x28像素大小的手写字符图片构成。该数据集包含340,000张图像,其中240,000张用于训练,60,000张用于验证和40,000张用于测试。
四、模型实现
在这里我们将使用TensorFlow2.0框架中的Keras库来实现模型。首先需要导入所需的库和模块。
import numpy as np
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.layers import Input, concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from PIL import Image
接着,定义一些超参数,例如batch_size、num_classes、epochs等。
batch_size = 128 # 批量大小
num_classes = 26 # 分类数目
epochs = 50 # 训练轮数
其次,加载EMNIST数据集。这里我们需要将数据集文件解压到指定路径,并读取所有图像和标签。
# 加载数据集
def load_dataset(path):
with np.load(path) as data:
X_train = data['X_train']
y_train = data['y_train']
X_test = data['X_test']
y_test = data['y_test']
return (X_train, y_train), (X_test, y_test)
# 加载数据集并进行归一化处理
def preprocess_data(X_train, y_train, X_test, y_test):
# 将图像矩阵归一化到0-1之间
X_train = X_train.astype('float32') / 255.
X_test = X_test.astype('float32') / 255.
# 将标签矩阵转换为one-hot编码
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
return X_train, y_train, X_test, y_test
# 加载训练和测试数据
(X_train_val, y_train_val), (X_test, y_test) = load_dataset('/data/emnist/mnist.npz')
# 划分训练集和验证集
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val,
test_size=0.2, random_state=42)
# 对数据进行归一化处理
X_train, y_train, X_val, y_val = preprocess_data(X_train, y_train, X_val, y_val)
X_test, y_test = preprocess_data(X_test, y_test, [], [])
在数据预处理后,我们需要定义DenseNet121模型。
# 定义dense_block函数
def dense_block(x, blocks, growth_rate):
for i in range(blocks):
x1 = BatchNormalization()(x)
x1 = Conv2D(growth_rate * 4, (1, 1), padding='same', activation='relu',
kernel_initializer='he_normal')(x1)
x1 = BatchNormalization()(x1)
x1 = Conv2D(growth_rate, (3, 3), padding='same', activation='relu',
kernel_initializer='he_normal')(x1)
x = concatenate([x, x1])
return x
# 定义transition_layer函数
def transition_layer(x, reduction):
x = BatchNormalization()(x)
x = Conv2D(int(x.shape.as_list()[-1] * reduction), (1, 1), activation='relu',
kernel_initializer='he_normal')(x)
x = MaxPooling2D((2, 2), strides=(2, 2))(x)
return x
# 构建DenseNet网络
def DenseNet(input_shape, num_classes, dense_blocks=3, dense_layers=-1,
growth_rate=12, reduction=0.5, dropout_rate=0.0, weight_decay=1e-4):
# 指定初始通道数和块数
depth = dense_blocks * dense_layers + 2
in_channels = 2 * growth_rate
inputs = Input(shape=input_shape)
# 第一层卷积
x = Conv2D(in_channels, (3, 3), padding='same', use_bias=False,
kernel_initializer='he_normal')(inputs)
# 堆叠密集块和过渡层
for i in range(dense_blocks):
x = dense_block(x, dense_layers, growth_rate)
in_channels += growth_rate * dense_layers
if i != dense_blocks - 1:
x = transition_layer(x, reduction)
# 全局平均池化
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = GlobalAveragePooling2D()(x)
# 输出层
outputs = Dense(num_classes, activation='softmax',
kernel_initializer='he_normal')(x)
# 定义模型
model = Model(inputs=inputs, outputs=outputs, name='DenseNet')
return model
# 构建DenseNet121网络
model = DenseNet(input_shape=(28, 28, 1), num_classes=num_classes, dense_blocks=3,
dense_layers=4, growth_rate=12, reduction=0.5, dropout_rate=0.0,
weight_decay=1e-4)
# 指定优化器、损失函数和评价指标
opt = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
# 输出模型概况
model.summary()
在模型定义后,我们可以开始训练模型,使用EarlyStopping策略进行早停并保留最佳模型。
# 定义早停策略
earlystop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=5,
verbose=1, mode='auto', restore_best_weights=True)
# 训练模型
history = model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs,
verbose=1, validation_data=(X_val, y_val), callbacks=[earlystop])
最后,我们可以对模型进行测试,并计算准确率等指标。
# 对模型进行评估
score = model.evaluate(X_test, y_test, verbose=0)
# 计算各项指标
test_accuracy = score[1]
print('Test accuracy:', test_accuracy)
# 保存模型
model.save('densenet121.h5')
五、实验结果与分析
使用上述代码,在EMNIST数据集上训练DenseNet121模型,输入28x28像素的字母图像,输出26种字母类别,并在测试集上评估最终性能。结果表明,该模型在测试集上达到96%以上的分类准确率,证明其较好的泛化能力和鲁棒性。
六、总结
本文介绍了基于DenseNet121模型实现26个英文字母识别任务的方法,主要涉及数据预处理、模型定义及训练、评估等步骤。DenseNet具有可解释性强、计算复杂度低等优点,能够有效提高模型精度和速度。值得注意的是,实际应用中还需要调整模型超参数、优化数据集和模型结构等方面,以进一步提升模型性能和普适性。
来源:https://juejin.cn/post/7225127360445005882


猜你喜欢
- 此BUG最初是在《前端观察》网站刊登,这里再描述一下,代码如下:<style>*{ padding:0; m
- 本文实例代码主要实现的是python遍历文件目录的操作,有三种方法,具体代码如下。#coding:utf-8 # 方法1:递归遍历目录 im
- 一、使用sklearn转换器处理sklearn提供了model_selection模型选择模块、preprocessing数据预处理模块、d
- 使用sql的计划任务可以处理一些特殊环境的数据,除了使用windows系统的计划任务来定时处理,不过要配合程序才行,有些事情可以直接使用sq
- 一、基本概念(查询语句)①基本语句1、“select * from 表名;”,—
- 本文实例讲述了Python编程之序列操作。分享给大家供大家参考,具体如下:#coding=utf8''''&
- 1、前言 MySQL 是完全网络化的跨平台关系型数据库系统,同时是具有客户机/服务器体系结构的分布式数据库管理系统。它具有功能强、使用简便、
- 使用 types 增强vscode中javascript代码提示功能微软的vscode编辑器是开发typescript项目的不二首选,其本身
- 1。注意用SQL分析器可以看select出来的东西select right(convert(varchar(30),getdate(),12
- 1.安装数据库1)yum -y install mysql-server(简单)yum命令自动从网上寻找mysql服务资源,下载至本地并完成
- 1. 简介 追踪某些软件运行时所发生事件的方法, 可以在代码中调用日志中某些方法来记录发生的事情一个事件可以用一个可包含可选变量数
- 遇到了这个问题,意思是你的 CPU 支持AVX AVX2 (可以加速CPU计算),但你安装的 TensorFlow 版本不支持解决:1. 如
- 问题:我想每日从数据库里导出一些数据,内容基本上都是一样的,只是时间不同,比如导出一张表wjzcreate table wjz(id int
- 1、简单的按钮js事件 用于判断和显示提示 <script type="text/javascript&
- sqlserver2008不支持关键字limit ,所以它的分页sql查询语句将不能用MySQL的方式进行,幸好sqlserver2008提
- 对我当前工程进行全部测试需要花费不少时间。既然有 26 GB 空闲内存,为何不让其发挥余热呢? tmpfs 可以通过把文件系统保
- 前言Go 数组的长度不可改变,在特定场景中这样的集合就不太适用,Go中提供了一种灵活,功能强悍的内置类型切片("动态数组"
- 本文实例讲述了django+js+ajax实现刷新页面的方法。分享给大家供大家参考,具体如下:在服务器开发的时候,为了方便将服务器对外开一个
- 序列化模块import pickle序列化和反序列化把不能直接存储的数据变得可存储,这个过程叫做序列化。把文件中的数据拿出来,回复称原来的数
- Python是一门非常酷的语言,因为很少的Python代码可以在短时间内做很多事情,并且,Python很容易就能支持多任务和多重处理。py&