python之tensorflow手把手实例讲解猫狗识别实现
作者:鑫xing 发布时间:2021-12-26 14:46:11
标签:python,tensorflow,猫狗识别,卷积神经网络
目录
一,猫狗数据集数目构成
二,数据导入
三,数据集构建
四,模型搭建
五,模型训练
六,模型测试
作为tensorflow初学的大三学生,本次课程作业的使用猫狗数据集做一个二分类模型。
一,猫狗数据集数目构成
train | cats:1000 ,dogs:1000 |
---|---|
test | cats: 500,dogs:500 |
validation | cats:500,dogs:500 |
二,数据导入
train_dir = 'Data/train'
test_dir = 'Data/test'
validation_dir = 'Data/validation'
train_datagen = ImageDataGenerator(rescale=1/255,
rotation_range=10,
width_shift_range=0.2, #图片水平偏移的角度
height_shift_range=0.2, #图片数值偏移的角度
shear_range=0.2, #剪切强度
zoom_range=0.2, #随机缩放的幅度
horizontal_flip=True, #是否进行随机水平翻转
# fill_mode='nearest'
)
train_generator = train_datagen.flow_from_directory(train_dir,
(224,224),batch_size=1,class_mode='binary',shuffle=False)
test_datagen = ImageDataGenerator(rescale=1/255)
test_generator = test_datagen.flow_from_directory(test_dir,
(224,224),batch_size=1,class_mode='binary',shuffle=True)
validation_datagen = ImageDataGenerator(rescale=1/255)
validation_generator = validation_datagen.flow_from_directory(
validation_dir,(224,224),batch_size=1,class_mode='binary')
print(train_datagen)
print(test_datagen)
print(train_datagen)
三,数据集构建
我这里是将ImageDataGenerator类里的数据提取出来,将数据与标签分别存放在两个列表,后面在转为np.array,也可以使用model.fit_generator,我将数据放在内存为了后续调参数时模型训练能更快读取到数据,不用每次训练一整轮都去读一次数据(应该是这样的…我是这样理解…)
注意我这里的数据集构建后,三种数据都是存放在内存中的,我电脑内存是16g的可以存放下。
train_data=[]
train_labels=[]
a=0
for data_train, labels_train in train_generator:
train_data.append(data_train)
train_labels.append(labels_train)
a=a+1
if a>1999:
break
x_train=np.array(train_data)
y_train=np.array(train_labels)
x_train=x_train.reshape(2000,224,224,3)
test_data=[]
test_labels=[]
a=0
for data_test, labels_test in test_generator:
test_data.append(data_test)
test_labels.append(labels_test)
a=a+1
if a>999:
break
x_test=np.array(test_data)
y_test=np.array(test_labels)
x_test=x_test.reshape(1000,224,224,3)
validation_data=[]
validation_labels=[]
a=0
for data_validation, labels_validation in validation_generator:
validation_data.append(data_validation)
validation_labels.append(labels_validation)
a=a+1
if a>999:
break
x_validation=np.array(validation_data)
y_validation=np.array(validation_labels)
x_validation=x_validation.reshape(1000,224,224,3)
四,模型搭建
model1 = tf.keras.models.Sequential([
# 第一层卷积,卷积核为,共16个,输入为150*150*1
tf.keras.layers.Conv2D(16,(3,3),activation='relu',padding='same',input_shape=(224,224,3)),
tf.keras.layers.MaxPooling2D((2,2)),
# 第二层卷积,卷积核为3*3,共32个,
tf.keras.layers.Conv2D(32,(3,3),activation='relu',padding='same'),
tf.keras.layers.MaxPooling2D((2,2)),
# 第三层卷积,卷积核为3*3,共64个,
tf.keras.layers.Conv2D(64,(3,3),activation='relu',padding='same'),
tf.keras.layers.MaxPooling2D((2,2)),
# 数据铺平
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64,activation='relu'),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(1,activation='sigmoid')
])
print(model1.summary())
模型summary:
五,模型训练
model1.compile(optimize=tf.keras.optimizers.SGD(0.00001),
loss=tf.keras.losses.binary_crossentropy,
metrics=['acc'])
history1=model1.fit(x_train,y_train,
# validation_split=(0~1) 选择一定的比例用于验证集,可被validation_data覆盖
validation_data=(x_validation,y_validation),
batch_size=10,
shuffle=True,
epochs=10)
model1.save('cats_and_dogs_plain1.h5')
print(history1)
plt.plot(history1.epoch,history1.history.get('acc'),label='acc')
plt.plot(history1.epoch,history1.history.get('val_acc'),label='val_acc')
plt.title('正确率')
plt.legend()
可以看到我们的模型泛化能力还是有点差,测试集的acc能达到0.85以上,验证集却在0.65~0.70之前跳动。
六,模型测试
model1.evaluate(x_validation,y_validation)
最后我们的模型在测试集上的正确率为0.67,可以说还不够好,有点过拟合,可能是训练数据不够多,后续可以数据增广或者从验证集、测试集中调取一部分数据用于训练模型,可能效果好一些。
来源:https://blog.csdn.net/weixin_52612318/article/details/117376964


猜你喜欢
- 1 简介kepler.gl作为开源地理空间数据可视化神器,也一直处于活跃的迭代开发状态下。而在前不久,kepler.gl正式发布了其2.4.
- 一、系统环境yum update升级以后的系统版本为[root@yl-web yl]# cat /etc/redhat-release Ce
- 代码如下# -*- coding = utf-8 -*-# @time:2020/5/28/028 21:00# Author:cyx# @
- 在使用 SQLAlchemy 时,那些看似很小的选择可能对这种对象关系映射工具包的性能产生重要影响。对象关系映射Object-relatio
- 通用load/write方法手动指定选项Spark SQL的DataFrame接口支持多种数据源的操作。一个DataFrame可以进行RDD
- SQL2005 Express 没了「企业管理器」和「查询分析器」 SQL2005 分五个版本,如下所列: 1.Enterprise(企业版
- 当你标记了翻译字符串,你就需要写出(或获取已有的)对应的语言翻译信息。 这里就是它如何工作的。地域限制Django不支持把你的应用本地化到一
- 前言数据清洗是一项复杂且繁琐(kubi)的工作,同时也是整个数据分析过程中最为重要的环节。有人说一个分析项目80%的时间都是在清洗数据,这听
- 新标准的熟悉和入门内容: 还在用 HTML 编写文档?如果是的话,就不符合当前标准了。2000 年&
- 本文实例讲述了Python实现基本数据结构中栈的操作。分享给大家供大家参考,具体如下:#! /usr/bin/env python#codi
- Mysql迁移历史数据记录一下工作中由于业务需要以及系统的数据库模型变更,导致需要做一下历史数据迁移的解决办法需求陈述一共涉及到三张表,分别
- 随着现在宽屏显示器的流行,Flash的全屏模式下,越来越需要考虑到普屏显示器与宽屏显示器的差别。Flash全屏模式有以下特点:窗口最大化,且
- 环境: windows 7 + Python 3.5.2 + Selenium 3.4.2 + Chrome Driver 2.29 + C
- 这篇文章主要介绍了python如何实现不可变字典inmutabledict,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参
- 微信小程序-拍照或选择图片并上传文件调用拍照API:https://mp.weixin.qq.com/debug/wxadoc/dev/ap
- 什么是迭代器能被 next 指针调用,并不断返回下一个值的对象,叫做迭代器。表示为Iterator,迭代器是一个对象类型数据。概念迭代器指的
- Zeroc Ice简介 Zeroc ICE(Internet Communications Engine ,互联网通信引擎)是目前功能比较
- vue控制mock在开发环境使用,在生产环境禁用说下原因mock拦截所有的axios请求,根据请求,做出相应的响应。平时前后端分离开发,我们
- 学习前言我发现不仅有很多的Keras模型,还有很多的PyTorch模型,还是学学Pytorch吧,我也想了解以下tensor到底是个啥。Py
- 一般来说,一个真正的、完整的站点是离不开数据库的,因为实际应用中,需要保存的数据很多,而且这些数据之间往往还有关联,利用数据库来管理这些数据