python之tensorflow手把手实例讲解猫狗识别实现
作者:鑫xing 发布时间:2021-12-26 14:46:11
标签:python,tensorflow,猫狗识别,卷积神经网络
目录
一,猫狗数据集数目构成
二,数据导入
三,数据集构建
四,模型搭建
五,模型训练
六,模型测试
作为tensorflow初学的大三学生,本次课程作业的使用猫狗数据集做一个二分类模型。
一,猫狗数据集数目构成
train | cats:1000 ,dogs:1000 |
---|---|
test | cats: 500,dogs:500 |
validation | cats:500,dogs:500 |
二,数据导入
train_dir = 'Data/train'
test_dir = 'Data/test'
validation_dir = 'Data/validation'
train_datagen = ImageDataGenerator(rescale=1/255,
rotation_range=10,
width_shift_range=0.2, #图片水平偏移的角度
height_shift_range=0.2, #图片数值偏移的角度
shear_range=0.2, #剪切强度
zoom_range=0.2, #随机缩放的幅度
horizontal_flip=True, #是否进行随机水平翻转
# fill_mode='nearest'
)
train_generator = train_datagen.flow_from_directory(train_dir,
(224,224),batch_size=1,class_mode='binary',shuffle=False)
test_datagen = ImageDataGenerator(rescale=1/255)
test_generator = test_datagen.flow_from_directory(test_dir,
(224,224),batch_size=1,class_mode='binary',shuffle=True)
validation_datagen = ImageDataGenerator(rescale=1/255)
validation_generator = validation_datagen.flow_from_directory(
validation_dir,(224,224),batch_size=1,class_mode='binary')
print(train_datagen)
print(test_datagen)
print(train_datagen)
三,数据集构建
我这里是将ImageDataGenerator类里的数据提取出来,将数据与标签分别存放在两个列表,后面在转为np.array,也可以使用model.fit_generator,我将数据放在内存为了后续调参数时模型训练能更快读取到数据,不用每次训练一整轮都去读一次数据(应该是这样的…我是这样理解…)
注意我这里的数据集构建后,三种数据都是存放在内存中的,我电脑内存是16g的可以存放下。
train_data=[]
train_labels=[]
a=0
for data_train, labels_train in train_generator:
train_data.append(data_train)
train_labels.append(labels_train)
a=a+1
if a>1999:
break
x_train=np.array(train_data)
y_train=np.array(train_labels)
x_train=x_train.reshape(2000,224,224,3)
test_data=[]
test_labels=[]
a=0
for data_test, labels_test in test_generator:
test_data.append(data_test)
test_labels.append(labels_test)
a=a+1
if a>999:
break
x_test=np.array(test_data)
y_test=np.array(test_labels)
x_test=x_test.reshape(1000,224,224,3)
validation_data=[]
validation_labels=[]
a=0
for data_validation, labels_validation in validation_generator:
validation_data.append(data_validation)
validation_labels.append(labels_validation)
a=a+1
if a>999:
break
x_validation=np.array(validation_data)
y_validation=np.array(validation_labels)
x_validation=x_validation.reshape(1000,224,224,3)
四,模型搭建
model1 = tf.keras.models.Sequential([
# 第一层卷积,卷积核为,共16个,输入为150*150*1
tf.keras.layers.Conv2D(16,(3,3),activation='relu',padding='same',input_shape=(224,224,3)),
tf.keras.layers.MaxPooling2D((2,2)),
# 第二层卷积,卷积核为3*3,共32个,
tf.keras.layers.Conv2D(32,(3,3),activation='relu',padding='same'),
tf.keras.layers.MaxPooling2D((2,2)),
# 第三层卷积,卷积核为3*3,共64个,
tf.keras.layers.Conv2D(64,(3,3),activation='relu',padding='same'),
tf.keras.layers.MaxPooling2D((2,2)),
# 数据铺平
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64,activation='relu'),
tf.keras.layers.Dropout(0.5),
tf.keras.layers.Dense(1,activation='sigmoid')
])
print(model1.summary())
模型summary:
五,模型训练
model1.compile(optimize=tf.keras.optimizers.SGD(0.00001),
loss=tf.keras.losses.binary_crossentropy,
metrics=['acc'])
history1=model1.fit(x_train,y_train,
# validation_split=(0~1) 选择一定的比例用于验证集,可被validation_data覆盖
validation_data=(x_validation,y_validation),
batch_size=10,
shuffle=True,
epochs=10)
model1.save('cats_and_dogs_plain1.h5')
print(history1)
plt.plot(history1.epoch,history1.history.get('acc'),label='acc')
plt.plot(history1.epoch,history1.history.get('val_acc'),label='val_acc')
plt.title('正确率')
plt.legend()
可以看到我们的模型泛化能力还是有点差,测试集的acc能达到0.85以上,验证集却在0.65~0.70之前跳动。
六,模型测试
model1.evaluate(x_validation,y_validation)
最后我们的模型在测试集上的正确率为0.67,可以说还不够好,有点过拟合,可能是训练数据不够多,后续可以数据增广或者从验证集、测试集中调取一部分数据用于训练模型,可能效果好一些。
来源:https://blog.csdn.net/weixin_52612318/article/details/117376964
0
投稿
猜你喜欢
- 新建一个项目 app02在 app02/ 下创建 urls.py:from django.conf.urls import urlfrom
- 1.1全部php生成结构1.2html中嵌套php总结如下:html和php混写规则:php代码必须包在<?php ?>html
- 根据代码中运行的结果来看,主要由以下几种:1. sum():将array中每个元素相加的结果2. axis对应的是维度的相加。比如:1、ax
- 如下所示:import cv2import osimport numpy as nproot_path = "I:/Images/
- 微软今天宣布正式发布SQL Server 2008服务器软件,这将帮助微软与Oracle 11g,IBM DB2 9.5数据库产品对抗.此前
- 半透明效果有时候会给页面增加不少色彩,特别是Vista盛行之后,半透明效果更加受推崇。在诸多可用于Web浏览的图片格式中,只有PNG格式和G
- 最近在做项目的时候经常会用到定时任务,由于我的项目是使用Java来开发,用的是SpringBoot框架,因此要实现这个定时任务其实并不难。后
- 接口测试中,上传文件的测试场景非常常见。例如:上传头像(图片)、上传文件、上传视频等。下面以一个上传图片的例子为大家讲解如何通过 pytho
- 问题描述给出一个整数数组 nums,请返回其中位数为偶数的数字的个数。示例 1:输入:nums = [12,345,2,6,7896]输出:
- 这一版,对虹软的功能进行了一些封装,添加了人脸特征比对,比对结果保存到文件,和从文件提取特征进行比对,大体功能基本都已经实现,可以进行下一步
- 阅读上一篇:WEB2.0网页制作标准教程(11)不用表格的菜单辛苦了好多天,我们努力学习使用XHTML+CSS来重新设计我们的网站。那么我们
- 大家是否还记得1983年任天堂的著名游戏《超级玛丽》里那个留着胡子的意大利水管工人,还有日本konami公司1987年发行的射击游戏《魂斗罗
- 初学者可以看看。在的img标签有两个属性分别为alt和title,对于很多初学者而言对这两个属性的正确使用都还抱有迷惑,当然这其中一部分原因
- 本文实例讲述了Ubuntu下使用Python实现游戏制作中的切分图片功能。分享给大家供大家参考,具体如下:why拿到一个人物行走的素材,要用
- 学习https://matplotlib.org/gallery/index.html 记录,描述不一定准确,具体请参考官网Matplotl
- 一、数据容器:list(列表)列表内的每一个数据,称之为元素以 [] 作为标识列表内每一个元素之间用, 逗号隔开定义语法:[元素1, 元素2
- 今天使用python写了一个简单的爬虫,用来下载taptap网站的游戏截图。下面说下具体的实现方法。在搜索框中搜索“原神”打开浏览器的开发者
- 文件上传是所有UI自动化测试都要面对的一个头疼问题,今天博主在这里给大家分享下自己处理文件上传的经验,希望能够帮助到广大被文件上传坑住的se
- 天下武功,唯快不破。编程也不例外,你的代码跑的快,你能快速找出代码慢的原因,你的码功就高。安装pip install pyinstrumen
- MySQL服务器有几个影响其操作的参数(变量)。如果缺省的参数值不合适,可以将其修改为对服务器运行环境更合适的值。例如,如果您有大量的内存,