Tensorflow2.1实现Fashion图像分类示例详解
作者:我是王大你是谁 发布时间:2021-01-28 03:59:52
实现思路和详细解读
1. 获取 Fashion 数据、处理数据
(1)本次实践项目用到的是 Fashion 数据集,包含 10 个类别的服饰灰度图片,共 70000 张,每张图片展示一件衣物,每张图片都是低分辨率的 28x28 像素(其实就是28*28的整数矩阵)。部分效果如下:
本次实践中我们对数据集进行训练集和测试集的划分,使用 60000 张图片来训练模型,使用 10000 张图片来评估模型对服饰图片的分类的准确程度。
直接通过 tensorflow 内置的接口函数从网络上下载数据集,其中 (train_images, train_labels) 分别是训练集中的图片和标签,(test_images, test_labels) 分别是测试集中的图片和标签。
fashion = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion.load_data()
(2)这里主要是使用列表来保存数据集中出现过的所有衣服种类的名字。
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
(3)定义一个函数,主要用来显示某张照片。
def showImage(image):
plt.figure()
plt.imshow(image, cmap=plt.cm.binary)
plt.colorbar()
plt.grid(False)
plt.show()
(4)这里主要展示训练集中第一张原始图片的样子,我们可以看到每张图片都是低分辨率的 28x28 像素(其实就是28*28的矩阵数字),而且每个像素是在 0-255 之间的数字
showImage(train_images[0])
效果如下:
因为每张图片的每个像素都是 0-255 之间的整数,所以我们为了模型训练的加快收敛,将所有的图片都进行归一化操作。
train_images = train_images / 255.0
test_images = test_images / 255.0
这里主要展示训练集中第一张原始图片的样子经过归一化操作的结果,可以看到每个像素点都是 0-1 之间的小数。
showImage(train_images[0])
效果如下:
2. 使用 tensorflow 2.1 搭建模型
(1)因为每张图片的输入是 28*28 的像素点,所以第一层的是输入设置为 input_shape=(28, 28) ,输出的是一个 784 维的向量,该操作可以看作是将 input_shape 多维数组中的值,重新拼接到一起整合成了一个一维数组。
(2)第二层、第三层都是通过激活函数 relu 的非线性变化,输出一个 64 维向量的全连接层操作,当然这个网络结构的层数、激活函数、每层的输出维度可以自行随意调整,其大小会影响最后的模型评估的指标,理论上结构越复杂效果越好,但是训练速度越慢,而且这也会引起过拟合的现象,这个度的把握需要不断通过输出的指标来进行调整。
(3)第三层是输出一个 10 维度的全连接层操作,其实就是该输入图片分别属于这十个类别的对应的概率分布。
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
3. 配置并编译模型
(1)这里我们选择了 Adam 优化器,这是一个比较成熟且广泛使用的优化器。
(2)这里的损失函数我们选择了比较常见的交叉熵 SparseCategoricalCrossentropy 。
(3)这里我们选用了最为常用的模型评估指标准确率 accuracy。
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
4. 训练模型
(1)我们使用训练数据的图片和标签来对模型进行训练,并且设置 epochs 为 5 ,也就是将所有的训练集从头到尾反复训练 5 次,如果模型还没有收敛那你也可以把 epochs 的值设置的大一些,配合较为复杂的网络结构,最后模型在训练阶段的准确率应该能达到 98% 以上。
(2)在模型的训练期间,命令行会显示模型整体损失值和准确率评估指标的情况,这些都是 tensorflow 内部函数写好的输出格式,你也可以自己写代码改变。
model.fit(train_images, train_labels, epochs=5)
训练过程输出如下所示:
Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 5s 78us/sample - loss: 0.5140 - accuracy: 0.8180
Epoch 2/5
60000/60000 [==============================] - 4s 73us/sample - loss: 0.3724 - accuracy: 0.8654
Epoch 3/5
60000/60000 [==============================] - 4s 74us/sample - loss: 0.3388 - accuracy: 0.8763
Epoch 4/5
60000/60000 [==============================] - 4s 70us/sample - loss: 0.3165 - accuracy: 0.8831
Epoch 5/5
60000/60000 [==============================] - 4s 74us/sample - loss: 0.2985 - accuracy: 0.8902
5. 评估模型
(1)这里我们使用测试数据来对模型进行评估,评估的指标也就是之前规定的准确率。
(2)verbose=2 只是为了规定结果输出形式,可以选择 0、1、2 中的任意一个 。
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print('Test accuracy:%f, Test loss:%f'%(acc, loss))
结果输出如下所示,表明所训练的模型在测试集的评估下准确率能达到 0.8725 :
10000/1 - 1s - loss: 0.2279 - accuracy: 0.8725
Test accuracy:0.872500, Test loss:0.352148
6. 使用模型进行预测
(1)上面模型在预测每张图片时最后输出的都是一个 10 维的杂乱无章的浮点数数组,为了保证输出的结果是被人容易理解的,我们在上面模型的的最后加上了一层 Softmax 。
(2)Softmax 的功能很简单,就是将这 10 个杂乱无章的浮点数,转化成 10 个概率,每个概率值在 0-1 之间,这 10 个概率的和为 1 ,这样我们取其中最大概率的值对应的衣服类别作为我们预测的结果。
model_by_softmax = tf.keras.Sequential([model, tf.keras.layers.Softmax()])
predictions = model_by_softmax.predict(test_images)
predict_for_one_image = predictions[3]
predict_for_one_image
输出概率分布如下所示:
array([4.1939876e-07, 9.9996161e-01, 8.4085507e-08, 3.7719459e-05,
3.1557637e-08, 1.0500006e-13, 1.5945717e-07, 1.3569163e-13,
1.8028586e-09, 4.5183642e-11], dtype=float32)
(3)查看一张测试图片在这 10 个衣服类别中的概率分布,我们发现 9.9996161e-01 最大,表明该图片成为第二类衣服款式的概率最大。
(4)我们可以输出预测的衣服种类,通过查找最大概率对应的衣服种类,发现该图片被预测的衣服类别是 Trouser 。
class_names[np.argmax(predict_for_one_image)]
输出结果如下:
'Trouser'
(5)我们将原始图片绘制出来发现确实是一条裤子,说明模型预测正确。
showImage(test_images[3])
展示效果如下:
来源:https://juejin.cn/post/7164994875732000782


猜你喜欢
- 1 介绍在设计到数据库的开发中,难免要将图片或音频文件插入到数据库中的情况。一般来说,我们可以同过插入图片文件相应的存储位置,而不是文件本身
- 如何使用Iframe实现本页提交?例:chunfeng.html< html>< head>&n
- SQL 多条件查询以后我们做多条件查询,一种是排列结合,另一种是动态拼接SQL如:我们要有两个条件,一个日期@addDate,一个是@nam
- 前言为了便于精准排查问题,需要将当前的请求信息与当前执行的 SQL 信息设置对应关系记录下来,记录的 SQL 信息包括:执行 SQL 的当前
- 为什么要实现分页?在大部分网站中分页的功能都是必要的,尤其是在后台管理中分页更是不可或缺分页能带给用户更好的体验,也能减轻服务器的压力对于分
- 原因是:It looks like you need to flush stdout periodically (e.g. sys.stdo
- 关于递归函数:函数内部调用自身的函数。以n阶乘为例:f(n) = n ! = 1 x 2 x 3 x 4 x...x(n-1)x(n) =
- 添加jar包 官网下载jar包idea导入jar包:检查官网下载jar包官网地址:MySQL :: Download Connec
- 单表操作增加数据auther_obj = {"auther_name":"崔皓然","au
- 什么是MySql数据库?通常意义上,数据库也就是数据的集合,具体到计算机上数据库可以是存储器上一些文件的集合或者一些内存数据的集合。我们通常
- 一、Python安装Window系统下,python的安装很简单。访问python.org/download,下载最新版本,安装过程与其他w
- 我们可向函数传递动态参数,*args,**kwargs,首先我们来看*args,示例如下:1.show(*args)def show(*ar
- 如下所示:#-*- coding: utf-8 -*-import pandas as pdimport numpy as npfrom p
- eval() 函数用来执行一个字符串表达式,并返回表达式的值。eval函数功能:将字符串str当成有效的表达式来求值并返回计算结果。eval
- 前言Django项目本身就可以启动运行,为什么还需要部署到Apache或者Nginx上呢?初学者都会遇到这个问题,我们来看看官方解释:It&
- 第三方库 binarytree其使用环境、安装方法及二叉树的相关知识,请见:《Python 初识二叉树,新手也秒懂!》不能导入的请安装:pi
- 使用Python内置函数:bin()、oct()、int()、hex()可实现进制转换。先看Python官方文档中对这几个内置函数的描述:b
- vue封装常用工具类公司要新开一个项目,我来分享一下简单封装常用的工具类首先在util目录下创建一个Common.js文件然后开始封装1.验
- 1,exists和in的理解exists:如果子查询中包括某一行,那么就为TRUE in:如果操作数为TRUE等于表达式列表中的一个,那么就
- 一、相关代码数据库配置类 MongoDBConn.py#encoding=utf-8'''Mongo Conn连接类