Tensorflow分类器项目自定义数据读入的实现
作者:泠音 发布时间:2023-06-21 20:21:53
标签:Tensorflow,分类器,数据读入
在照着Tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。
首先提一下需要用到的模块:
import os
import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:
IMG_SIZE_X = 30
IMG_SIZE_Y = 30
其次确定你图片的方式目录:
image_path = r'D:\Projects\ImageClassifier\data\set'
path = ".\data"
# 你也可以使用相对路径的方式
# image_path =os.path.join(path, "set")
目录下的结构如下:
相应的label.txt如下:
动漫
风景
美女
物语
樱花
接下来是接在labels.txt,如下:
label_name = "labels.txt"
label_path = os.path.join(path, label_name)
class_names = np.loadtxt(label_path, type(""))
这里简便起见,直接利用了numpy的loadtxt函数直接加载。
之后便是正式处理图片数据了,注释就写在里面了:
re_load = False
re_build = False
# re_load = True
re_build = True
data_name = "data.npz"
data_path = os.path.join(path, data_name)
model_name = "model.h5"
model_path = os.path.join(path, model_name)
count = 0
# 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
if not os.path.exists(data_path) or re_load:
labels = []
images = []
print('Handle images')
# 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
for index, name in enumerate(class_names):
# 这里是拼接后的子目录path
classpath = os.path.join(image_path, name)
# 先判断一下是否是目录
if not os.path.isdir(classpath):
continue
# limit是测试时候用的这里可以去除
limit = 0
for image_name in os.listdir(classpath):
if limit >= max_size:
break
# 这里是拼接后的待处理的图片path
imagepath = os.path.join(classpath, image_name)
count = count + 1
limit = limit + 1
# 利用Image打开图片
img = Image.open(imagepath)
# 缩放到你最初确定要处理的图片分辨率大小
img = img.resize((IMG_SIZE_X, IMG_SIZE_Y))
# 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
img = img.convert("L")
# 转为numpy数组
img = np.array(img)
# 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
img = np.reshape(img, (1, IMG_SIZE_X, IMG_SIZE_Y))
# 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
labels.append([index])
# 添加到images中,最后统一处理
images.append(img)
# 循环中一些状态的输出,可以去除
print("{} class: {} {} limit: {} {}"
.format(count, index + 1, class_names[index], limit, imagepath))
# 最后一次性将images和labels都转换成numpy数组
npy_data = np.array(images)
npy_labels = np.array(labels)
# 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
np.savez(data_path, x=npy_data, y=npy_labels)
print("Save images by npz")
else:
# 如果存在序列化号的数据,便直接读取,提高速度
npy_data = np.load(data_path)["x"]
npy_labels = np.load(data_path)["y"]
print("Load images by npz")
image_data = npy_data
labels_data = npy_labels
到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:
# 最后一步就是将原始数据分成训练数据和测试数据
train_images, test_images, train_labels, test_labels = \
train_test_split(image_data, labels_data, test_size=0.2, random_state=6)
这里将相关信息打印的方法也附上:
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Image Data", image_data.shape))
print("%-28s %-s" % ("Labels Data", labels_data.shape))
print("=================================================================")
print('Split train and test data,p=%')
print("_________________________________________________________________")
print("%-28s %-s" % ("Name", "Shape"))
print("=================================================================")
print("%-28s %-s" % ("Train Images", train_images.shape))
print("%-28s %-s" % ("Test Images", test_images.shape))
print("%-28s %-s" % ("Train Labels", train_labels.shape))
print("%-28s %-s" % ("Test Labels", test_labels.shape))
print("=================================================================")
之后别忘了归一化哟:
print("Normalize images")
train_images = train_images / 255.0
test_images = test_images / 255.0
最后附上读取自定义数据的完整代码:
import os
import keras
import matplotlib.pyplot as plt
from PIL import Image
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
re_load = False
re_build = False
# re_load = True
re_build = True
epochs = 50
batch_size = 5
count = 0
max_size = 2000000000
来源:https://segmentfault.com/a/1190000018099185
0
投稿
猜你喜欢
- Web 标准要求一览表Russ WeakleyJjgod Jiang14-Aug-2004目录1 Web 标准,不仅仅是“不用表格的站点”2
- 简单的仿图片验证码,适合新手简单的仿图片验证码演示,很容易被破解,实用性不大,但拿出来给新手学习一下还是不错的:JScript.Asp代码示
- 很多朋友想用SQL2000数据库的编程方法,但是却又苦于自己是学ACCESS的,对SQL只是一点点的了解而已,这里我给大家提供以下参考---
- 可怜我的C盘本来只有8.XG,所以不得不卸载掉它。卸载掉本身没啥问题,只是昨晚突然发现 Sql Server 2008 R2 Managem
- linux平台及windows平台mysql重启方 * inux下重启MySQL的正确方法:1、通过rpm包安装的MySQLservice m
- 很多时候,我们都在说设计需要引导用户,尤其是在对初级用户的引导上,很大程度决定着产品能否快速聚拢用户的可能;但同样很多时候,用户并不需要引导
- 下面是模板的一般形式,显示了指定 SQL 查询和 XPath 查询的方式: <ROOT xmlns:sql="ur
- TXT文本文件,WORD文档点击后弹出另存为对话框,然后保存下载,而不是在浏览器中打开的asp实现方法,使用了asp中的stream对象,同
- 之前写过的组织结构和组织体系都太抽象了,读到标签系统我才有那种“略懂”的感觉。哈哈…书上提到的标签包括:导航情境式链接:常见的“更多”这种用
- 很多朋友都有过制作网页的经历,如今,众多网页的设计都用到了表格。这样不仅有利于网页的维护,同时,提高了网页的观赏性。在众多网页制作风格中,细
- 如何做一个密码“生成器”?randompassword.asp<% Dim i, intNum,&nbs
- 有没有想过用尺子来直接量网页上的区块间距,文字行高?屏幕标尺就是干这个的。这个功能非常适合F2E在调试样式尺寸的时候使用。打开屏幕标尺,屏幕
- 本文介绍了在js和asp中使用FileSystemObject(fso)来: 创建、添加或删除数据,以及读取文件; 移动、复制和删除文件;创
- 一、定位 oracle分两大块,一块是开发,一块是管理。开发主要是写写存储过程、触发器什么的,还有就是用Oracle的Develop工具做f
- 在XHTML标签中有一些标签的作用是相似的,当然这里的相似是指语义相似,以至于很多人都不清楚这些相似的标签如何使用,那么今天的主题就是分解相
- 如何实现在下拉菜单里输入文字? 用这个代码试试看,应该可以的:<script>function pp(){se.opt
- 先看看CSS框架的利与弊前段时间一直在讨论CSS框架。很多朋友看了那三篇文章后提了不少自己意见。特别是一位北京的朋友A君,他有一个小的团体,
- 代码如下:SELECT [StartDate] FROM [dbo].[udf_Week](2012,2012) WHERE [
- 如果你正在运行使用MySQL的Web应用程序,那么你把密码或者其他敏感信息保存在应用程序里的机会就很大。保护这些数据免受黑客或者窥探者的获取
- 前两天有一位网友问我一个关于Javascript中++操作符的问题,他的代码大致是这样的ADS.addEvent(window,'c