python之tensorflow手把手实例讲解斑马线识别实现
作者:鑫xing 发布时间:2021-11-11 05:53:19
一,斑马线的数据集
数据集的构成:
test | train |
---|---|
zebra corssing:56 | zebra corssing:168 |
other:54 | other:164 |
二,代码部分
1.导包
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import matplotlib.pyplot as plt
import keras
2.数据导入
train_dir=r'C:\Users\zx\深度学习\Zebra\train'
test_dir=r'C:\Users\zx\深度学习\Zebra\test'
train_datagen = ImageDataGenerator(rescale=1/255,
rotation_range=10, #旋转
horizontal_flip=True)
train_generator = train_datagen.flow_from_directory(train_dir,
(50,50),
batch_size=1,
class_mode='binary',
shuffle=False)
test_datagen = ImageDataGenerator(rescale=1/255)
test_generator = test_datagen.flow_from_directory(test_dir,
(50,50),
batch_size=1,
class_mode='binary',
shuffle=False)
3.搭建模型
模型的建立仁者见智,可自己调节寻找更好的模型。
model = tf.keras.models.Sequential([
# 第一层卷积,卷积核为,共16个,输入为150*150*1
tf.keras.layers.Conv2D(16,(3,3),activation='relu',padding='same',input_shape=(50,50,3)),
tf.keras.layers.MaxPooling2D((2,2)),
# 第二层卷积,卷积核为3*3,共32个,
tf.keras.layers.Conv2D(32,(3,3),activation='relu'),
tf.keras.layers.MaxPooling2D((2,2)),
# 第三层卷积,卷积核为3*3,共64个,
tf.keras.layers.Conv2D(64,(3,3),activation='relu'),
tf.keras.layers.MaxPooling2D((2,2)),
# 第四层卷积,卷积核为3*3,共128个
# tf.keras.layers.Conv2D(128,(3,3),activation='relu'),
# tf.keras.layers.MaxPooling2D((2,2)),
# 数据铺平
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(32,activation='relu'),
tf.keras.layers.Dense(16,activation='relu'),
tf.keras.layers.Dense(2,activation='softmax')
])
print(model.summary())
model.compile(optimize='adam',
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['acc'])
4,模型训练
history = model.fit(train_generator,
epochs=20,
verbose=1)
model.save('./Zebra.h5')
模型训练过程:
可以看到我们的模型在20轮的训练后acc从0.63上升到了0.96左右。
5,模型评估
model.evaluate(test_generator)
#可视化
plt.plot(history.history['acc'], label='accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.7, 1])
plt.legend(loc='lower right')
plt.title('acc')
plt.show()
6,模型预测
虽然我们的模型在训练过程中acc一度达到0.96,但测试集才是检验模型的唯一标准,在model.evaluate(test_generator)中的评分只有0.91左右,说明我们的模型已经能以很高的正确率来完成”斑马线“与“非斑马线”的二分类问题了,但我们还是要查看具体是哪些数据没有被模型正确得识别。
pred=model.predict(test_generator) #获取test集的输出
filenames = test_generator.filenames #获取test数据的文件名
错误输出过程:
1,循环测试集长度,通过if语句先判断others还是zebra,再通过one-hot编码判断是否预测正确。
2,根据labels可知others': 0, 'zebra crossing': 1,以此来判断是否预测正确。
3,对 filenames[0]='others\\103.png',进行切片处理。
4,找到others的‘s'或 zebra crossing的‘g',使用find()在基础上+2为正切片的起点(样本编号前有'\'符号,故+2才能正确取出编号)。
5,如 :将filenames[i]的值赋给a,a[int(a.find('s')+2):]则表示为 'xx.png'。
6,将取出的样本编号与路径拼接,读取后作图。
7,break跳出循环。
for i in range(len(filenames)):
if filenames[i][:6]=='others':
if np.argmax(pred[i]) != 0:
a=filenames[i]
plt.figure()
print('预测错误的图片:'+a[int(a.find('s')+2):])
print('错误识别为"zebra crossing",正确类型是"others"')
print('预测标签为:'+str(np.argmax(pred[i]))+',真实标签为:0')
img = plt.imread('Zebra/test/others/'+a[int(a.find('s')+2):])
plt.imshow(img)
plt.title(a[int(a.find('s')+2):])
plt.grid(False)
break
if filenames[i][:6]=='zebra ':
if np.argmax(pred[i]) != 1:
b= filenames[i]
plt.figure()
print('预测错误的图片:'+b[int(b.find('g')+2):])
print('错误识别为"others",正确类型是"zebra crossing"')
print('预测标签为:'+str(np.argmax(pred[i]))+',真实标签为:1')
img = plt.imread('Zebra/test/zebra crossing/'+b[int(b.find('g')+2):])
plt.imshow(img)
plt.title(b[int(b.find('g')+2):])
plt.grid(False)
break
看到这个错误样本,我猜想可能是因为斑马线的部分只占了图像的一半左右,所以预测错误了。
这里是我做预测判断的思路,本可以不这么复杂的可以用test_generator.labels来获取数据的标签,再做判断。
test_generator.labels
上面只输出了第一个错误的样本,所以接下来我们要看所有错误预测的样本
sum=0
for i in range(len(filenames)):
if filenames[i][:6]=='others':
if np.argmax(pred[i]) != 0:
a=filenames[i]
print('预测错误的图片:'+a[int(a.find('s')+2):]+',错误识别为"zebra crossing",正确类型是"others"')
sum=sum+1
if filenames[i][:6]=='zebra ':
if np.argmax(pred[i]) != 1:
b= filenames[i]
print('预测错误的图片:'+b[int(b.find('g')+2):]+',错误识别为"others",正确类型是"zebra crossing"')
sum=sum+1
print('错误率:'+str(sum/100)+'%')
print('正确率:'+str((10000-sum)/100)+'%')
三,分析
在构建模型时我尝试在最后一层只用一个神经元,用sigmoid激活函数,其他参数不变,在同样epochs=20的条件,也能很快收敛,达到很高的acc,测试集的评分也能在0.9左右,但是在最后输出全部错误样本的时候发现错误的样本远超过softmax,可能其中有些参数我没有根据sigmoid来调整,所以会有如此高的错误率,欢迎在评论区讨论。
来源:https://blog.csdn.net/weixin_52612318/article/details/117375576


猜你喜欢
- Pytorch的数据类型为各式各样的Tensor,Tensor可以理解为高维矩阵。与Numpy中的Array类似。Pytorch中的tens
- 本文实例讲述了mysql设置指定ip远程访问连接的方法,分享给大家供大家参考。具体实现方法如下:1. 授权用户root使用密码jb51从任意
- Python2.7已于2020年1月1日开始停用,之前RF做自动化都是基于Python2的版本。没办法,跟随时代的脚步,我们也不得不升级以应
- 如何在庞大的数据中高效的检索自己需要的东西?本篇内容介绍了Python做出一个大数据搜索引擎的原理和方法,以及中间进行数据分析的原理也给大家
- 本文实例讲述了python基于queue和threading实现多线程下载的方法,分享给大家供大家参考。具体方法如下:主代码如下: &nbs
- 今天做站时碰到个小问题:ASP正则获取文章内容图片地址,现在将此方法的思路拿出来分享下:Function RegExp_Execu
- 已经有很多年不使用SQLServer了,毕竟商业版本是个收费的,安装也不容易。最近因为想带领学生学习做个练习性的项目,参考了.net下的pe
- 写一个学生管理系统,最好用python。我都没学过python呢,只好开始临时抱佛脚,再到网上找找有没有例子看看,下面是我参照另一个博主写的
- 本文实例讲述了Python爬虫实现网页信息抓取功能。分享给大家供大家参考,具体如下:首先实现关于网页解析、读取等操作我们要用到以下几个模块i
- 一、re是什么?正则表达式是一个特殊的字符序列,能方便的检查一个字符串是否与某种模式匹配。re模块使得python拥有全部的正则表达式功能。
- #/usr/bin/env python#-*- coding:utf-8 -*-"""1.解析 cronta
- 功能super功能:super函数是子类用于调用父类(超类)的一个方法。用法1.在子类 __init__() 方法中正确的初始化父类,保证相
- 文件下载1.通过a标签点击直接下载<a href="https:xxx.xlsx" rel="exter
- Oracle公司6月9日宣布同意收购TimesTen公司,TimesTen是一家私营软件企业,其产品能提高用于股市和机票预订等需要快速响应时
- 前言文件上传漏洞大多出现在可以进行文件上传的地方,如用户头像上传,文档上传处等。该漏洞是一个危害十分大的漏洞,通过文件上传,攻击者可以上传w
- 一.windows系统的解决方法1.首先以系统管理员身份登陆系统。2.停止MySQL的服务。3.进入命令窗口,然后进入MySQL的安装目录,
- Python关于删除list中的某个元素,一般有两种方法,pop()和remove()。remove() 函数用于移除列表中某个值的第一个匹
- 首先,有个单例对象,它上面挂了很多静态工具方法。其中有一个是each,用来遍历数组或对象。var nativeForEach = [].fo
- 本文实例讲述了用python实现面向对像的ASP程序的方法。分享给大家供大家参考。具体实现方法如下:平时我们写ASP时,一般都用vbscri
- 今天看了一下数据结构的书,发现其实数据结构没有几种,线性表,数组,字符串,队列和栈,等等,其实是一回事,然后就是树结构,图结构。数据结构的理