DenseNet121模型实现26个英文字母识别任务
作者:实力 发布时间:2023-08-22 13:15:22
一、任务概述
26个英文字母识别是一个基于计算机视觉的图像分类任务,旨在从包含26个不同字母图像的数据集中训练一个深度学习模型,以对输入的字母图像进行准确的分类预测。本文将使用DenseNet121模型实现该任务。
二、DenseNet介绍
DenseNet是一种用于图像分类的深度学习架构,它的核心思想是通过连接前一层所有特征图到当前层来增强信息流,从而使得网络更深,更准确。相比于传统的卷积神经网络架构(如AlexNet和VGG),DenseNet具有更少的参数,更好的模型泛化能力和更高的效率。
DenseNet的网络结构类似于ResNet,由多个密集块(Dense Block)组成,其中每个密集块都是由多个卷积层和批量归一化层组成。与ResNet不同的是,DenseNet中每一层的输入都包含前面所有层的输出,这种密集连接方式可以避免信息瓶颈和梯度消失问题,促进了信息的传递和利用。同时,DenseNet还引入了过渡层(Transition Layer)来调整特征图的大小,减少计算量和内存占用。DenseNet最终通过全局平均池化层和softmax输出层生成预测结果。
三、数据集介绍
在本任务中,我们使用EMNIST数据集中的26个大写字母图像来训练和测试模型,它们是由28x28像素大小的手写字符图片构成。该数据集包含340,000张图像,其中240,000张用于训练,60,000张用于验证和40,000张用于测试。
四、模型实现
在这里我们将使用TensorFlow2.0框架中的Keras库来实现模型。首先需要导入所需的库和模块。
import numpy as np
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.layers import Input, concatenate
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split
from PIL import Image
接着,定义一些超参数,例如batch_size、num_classes、epochs等。
batch_size = 128 # 批量大小
num_classes = 26 # 分类数目
epochs = 50 # 训练轮数
其次,加载EMNIST数据集。这里我们需要将数据集文件解压到指定路径,并读取所有图像和标签。
# 加载数据集
def load_dataset(path):
with np.load(path) as data:
X_train = data['X_train']
y_train = data['y_train']
X_test = data['X_test']
y_test = data['y_test']
return (X_train, y_train), (X_test, y_test)
# 加载数据集并进行归一化处理
def preprocess_data(X_train, y_train, X_test, y_test):
# 将图像矩阵归一化到0-1之间
X_train = X_train.astype('float32') / 255.
X_test = X_test.astype('float32') / 255.
# 将标签矩阵转换为one-hot编码
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
return X_train, y_train, X_test, y_test
# 加载训练和测试数据
(X_train_val, y_train_val), (X_test, y_test) = load_dataset('/data/emnist/mnist.npz')
# 划分训练集和验证集
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val,
test_size=0.2, random_state=42)
# 对数据进行归一化处理
X_train, y_train, X_val, y_val = preprocess_data(X_train, y_train, X_val, y_val)
X_test, y_test = preprocess_data(X_test, y_test, [], [])
在数据预处理后,我们需要定义DenseNet121模型。
# 定义dense_block函数
def dense_block(x, blocks, growth_rate):
for i in range(blocks):
x1 = BatchNormalization()(x)
x1 = Conv2D(growth_rate * 4, (1, 1), padding='same', activation='relu',
kernel_initializer='he_normal')(x1)
x1 = BatchNormalization()(x1)
x1 = Conv2D(growth_rate, (3, 3), padding='same', activation='relu',
kernel_initializer='he_normal')(x1)
x = concatenate([x, x1])
return x
# 定义transition_layer函数
def transition_layer(x, reduction):
x = BatchNormalization()(x)
x = Conv2D(int(x.shape.as_list()[-1] * reduction), (1, 1), activation='relu',
kernel_initializer='he_normal')(x)
x = MaxPooling2D((2, 2), strides=(2, 2))(x)
return x
# 构建DenseNet网络
def DenseNet(input_shape, num_classes, dense_blocks=3, dense_layers=-1,
growth_rate=12, reduction=0.5, dropout_rate=0.0, weight_decay=1e-4):
# 指定初始通道数和块数
depth = dense_blocks * dense_layers + 2
in_channels = 2 * growth_rate
inputs = Input(shape=input_shape)
# 第一层卷积
x = Conv2D(in_channels, (3, 3), padding='same', use_bias=False,
kernel_initializer='he_normal')(inputs)
# 堆叠密集块和过渡层
for i in range(dense_blocks):
x = dense_block(x, dense_layers, growth_rate)
in_channels += growth_rate * dense_layers
if i != dense_blocks - 1:
x = transition_layer(x, reduction)
# 全局平均池化
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = GlobalAveragePooling2D()(x)
# 输出层
outputs = Dense(num_classes, activation='softmax',
kernel_initializer='he_normal')(x)
# 定义模型
model = Model(inputs=inputs, outputs=outputs, name='DenseNet')
return model
# 构建DenseNet121网络
model = DenseNet(input_shape=(28, 28, 1), num_classes=num_classes, dense_blocks=3,
dense_layers=4, growth_rate=12, reduction=0.5, dropout_rate=0.0,
weight_decay=1e-4)
# 指定优化器、损失函数和评价指标
opt = Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
# 输出模型概况
model.summary()
在模型定义后,我们可以开始训练模型,使用EarlyStopping策略进行早停并保留最佳模型。
# 定义早停策略
earlystop = EarlyStopping(monitor='val_loss', min_delta=0.0001, patience=5,
verbose=1, mode='auto', restore_best_weights=True)
# 训练模型
history = model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs,
verbose=1, validation_data=(X_val, y_val), callbacks=[earlystop])
最后,我们可以对模型进行测试,并计算准确率等指标。
# 对模型进行评估
score = model.evaluate(X_test, y_test, verbose=0)
# 计算各项指标
test_accuracy = score[1]
print('Test accuracy:', test_accuracy)
# 保存模型
model.save('densenet121.h5')
五、实验结果与分析
使用上述代码,在EMNIST数据集上训练DenseNet121模型,输入28x28像素的字母图像,输出26种字母类别,并在测试集上评估最终性能。结果表明,该模型在测试集上达到96%以上的分类准确率,证明其较好的泛化能力和鲁棒性。
六、总结
本文介绍了基于DenseNet121模型实现26个英文字母识别任务的方法,主要涉及数据预处理、模型定义及训练、评估等步骤。DenseNet具有可解释性强、计算复杂度低等优点,能够有效提高模型精度和速度。值得注意的是,实际应用中还需要调整模型超参数、优化数据集和模型结构等方面,以进一步提升模型性能和普适性。
来源:https://juejin.cn/post/7225127360445005882
猜你喜欢
- 前言发现本站没有一个靠谱的tp6记录行为日志的教程,于是就整理了一下自己在项目中已经投入使用的行为日志中间件的详细配置步骤供大家参考提示:先
- 微信的小程序是一个很不错的体验,简单,上手快,这几天也在学习使用小程序,自己总结了三种用 Python 作为小程序后端的方式,供你参考。方法
- 1. # 可以使用LaTeX表示数学公式# 可以使用LaTeX表示数学公式from IPython.display import Latex
- 我就废话不多说了,大家还是直接看代码吧~'''Created on 2018-4-16'''
- 1、关于参数的区别实例方法:定义实例方法是最少有一个形参 ---> 实例对象,通常用 self类方法:定义类方法的时候最少有一个形参
- 1、python代码实现图片分割成九宫格需要包含的库,没有下载安装的,需要自己安装哦。实现原理很简单,就是用PIL库不断画小区域,切下来存储
- 在看到7yue博客——“换手来用”的思考 有这么一句话:RIA是一个更趋向于“体验”设计的领域,不仅仅包括“开发人员”,还包括“设计人员”,
- 作为前端开发工程师,平时对于Dom的查找遍历和操作是家常便饭。对于优秀的前端来说,也肯定早已有了自己的一套方法来封装这些重复的操作。但是,现
- 前言人类都是视觉动物,不管是男生还是女生看到漂亮的小姐姐、小哥哥就想截图保存下来。可是截图会对画质会产生损耗,截取的画面不规整,像素不高等问
- python中有的df列比较长head的时候会出现省略号,现在数据分析常用的就是基于anaconda的notebook和sypder,在sp
- 开始制作符合标准的站点,第一件事情就是声明符合自己需要的DOCTYPE。查看本站首页原代码,可以看到第一行就是:<!DOCTYPE h
- 安装完python之后,我们可以做两件事情,1.将安装目录中的Doc目录下的python331.chm使用手册复制到桌面上,方便学习和查阅2
- 前言最近因为工作需要要使用PHP 7,所以从网上找教程进行安装, 结果编译没问题, 安装的时候报了错误。错误如下cp -pR -f phar
- 一、读写txt文件1、打开txt文件file_handle=open('1.txt',mode='w')上述
- grid()函数概述grid()函数用于设置绘图区网格线。grid()的函数签名为matplotlib.pyplot.grid(b=None
- 本文实例讲述了Python实现将HTML转成PDF的方法。分享给大家供大家参考,具体如下:主要使用的是wkhtmltopdf的Python封
- 在IE 浏览器中使用 jquery的fadeIn() 效果 英文字符字体加粗的解决方法分享。<div id='tes
- 认识模块对于模块,在前面的一些举例中,已经涉及到了,比如曾经有过:import random (获取随机数模块)。为了能够对模块有一个清晰的
- 工作中最常见的配置文件有四种:普通key=value的配置文件、Json格式的配置文件、HTML格式的配置文件以及YMAML配置文件。这其中
- 1、引言小 * 丝:鱼哥,你说百度翻译的准确,还是google翻译的准确?小鱼:自己翻译的最准确。小 * 丝:你这… 抬杠。小