将自己的数据集制作成TFRecord格式教程
作者:v1_vivian 发布时间:2022-02-01 14:49:37
标签:数据集,TFRecord
在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入
此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格
式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。
1. 原本的数据集
此时,我有两类图片,分别是xiansu100,xiansu60,每一类中有10张图片。
2.制作成TFRecord格式
tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。
#生成整数型的属性
def _int64_feature(value):
return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
#生成字符串类型的属性
def _bytes_feature(value):
return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
#制作TFRecord格式
def createTFRecord(filename,mapfile):
class_map = {}
data_dir = '/home/wc/DataSet/traffic/testTFRecord/'
classes = {'xiansu60','xiansu100'}
#输出TFRecord文件的地址
writer = tf.python_io.TFRecordWriter(filename)
for index,name in enumerate(classes):
class_path=data_dir+name+'/'
class_map[index] = name
for img_name in os.listdir(class_path):
img_path = class_path + img_name #每个图片的地址
img = Image.open(img_path)
img= img.resize((224,224))
img_raw = img.tobytes() #将图片转化成二进制格式
example = tf.train.Example(features = tf.train.Features(feature = {
'label':_int64_feature(index),
'image_raw': _bytes_feature(img_raw)
}))
writer.write(example.SerializeToString())
writer.close()
txtfile = open(mapfile,'w+')
for key in class_map.keys():
txtfile.writelines(str(key)+":"+class_map[key]+"\n")
txtfile.close()
此段代码,运行完后会产生生成的.tfrecord文件。
3. 读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程
#读取train.tfrecord中的数据
def read_and_decode(filename):
#创建一个reader来读取TFRecord文件中的样例
reader = tf.TFRecordReader()
#创建一个队列来维护输入文件列表
filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
#从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
_,serialized_example = reader.read(filename_queue)
# print _,serialized_example
#解析读入的一个样例,如果需要解析多个,可以用parse_example
features = tf.parse_single_example(
serialized_example,
features = {'label':tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string),})
#将字符串解析成图像对应的像素数组
img = tf.decode_raw(features['image_raw'], tf.uint8)
img = tf.reshape(img,[224, 224, 3]) #reshape为128*128*3通道图片
img = tf.image.per_image_standardization(img)
labels = tf.cast(features['label'], tf.int32)
return img, labels
4. 将图片几个一打包,形成batch
def createBatch(filename,batchsize):
images,labels = read_and_decode(filename)
min_after_dequeue = 10
capacity = min_after_dequeue + 3 * batchsize
image_batch, label_batch = tf.train.shuffle_batch([images, labels],
batch_size=batchsize,
capacity=capacity,
min_after_dequeue=min_after_dequeue
)
label_batch = tf.one_hot(label_batch,depth=2)
return image_batch, label_batch
5.主函数
if __name__ =="__main__":
#训练图片两张为一个batch,进行训练,测试图片一起进行测试
mapfile = "/home/wc/DataSet/traffic/testTFRecord/classmap.txt"
train_filename = "/home/wc/DataSet/traffic/testTFRecord/train.tfrecords"
# createTFRecord(train_filename,mapfile)
test_filename = "/home/wc/DataSet/traffic/testTFRecord/test.tfrecords"
# createTFRecord(test_filename,mapfile)
image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
with tf.Session() as sess:
initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(initop)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess = sess, coord = coord)
try:
step = 0
while 1:
_image_batch,_label_batch = sess.run([image_batch,label_batch])
step += 1
print step
print (_label_batch)
except tf.errors.OutOfRangeError:
print (" trainData done!")
try:
step = 0
while 1:
_test_images,_test_labels = sess.run([test_images,test_labels])
step += 1
print step
# print _image_batch.shape
print (_test_labels)
except tf.errors.OutOfRangeError:
print (" TEST done!")
coord.request_stop()
coord.join(threads)
此时,生成的batch,就可以feed进网络了。
来源:https://blog.csdn.net/v1_vivian/article/details/77898414


猜你喜欢
- 之前的文章介绍了python抓取网页数据并将数据保存到本地excel文件,后续可以将数据保存到数据库(SqlServer、mysql等)中,
- 本文实例讲述了django 框架实现的用户注册、登录、退出功能。分享给大家供大家参考,具体如下:1 用户注册:from django.con
- pandas有groupby分组函数和sort_values排序函数,但是如何对dataframe分组之后排序呢?In [70]: df =
- 在这篇文章里,我们会聊一聊为什么 Python 决定不支持 switch 语句。为什么想要聊这个话题呢?主要是因为 switch 在其它语言
- 本文实例讲述了Python分支语句与循环语句应用。分享给大家供大家参考,具体如下:一、分支语句1、if else语句语法:if 条件判断:
- 在为一个客户排除死锁问题时我遇到了一个有趣的包括InnoDB间隙锁的情形。对于一个WHERE子句不匹配任何行的非插入的写操作中,
- 其实发这篇博感觉并没有什么用,太简单了,会的人不屑看,不会的人自已动动脑子也想到了。但是看着自已的博客已经这么久没更,真心疼~。粗略算下一篇
- 前言大家好,我是苏凉,在前面我们已经学习了网络爬虫并且获取到了数据,接下来当然是对数据进行分析啦,本篇文章带大家进入新的模块:pyhon数据
- 人类学是关于人的研究;社会人类学(social anthropology)是研究人类社会的学科。社会人类学还可以理解成“文化翻译”(the
- 表结构学生表如下:CREATE TABLE `t_student` ( `id` int NOT NULL AUTO_INCRE
- Python编程语言的优点非常多,它的编程特色主要体现在可扩充性方面。那么,在接下来的这篇文章中,我们将会为大家详细介绍一下有关Python
- 最近用python的正则表达式处理了一些文本数据,需要把结果写到文件里面,但是由于文件比较大,所以运行起来花费的时间很长。但是打开任务管理器
- 本文实例讲述了Python tkinter实现的图片移动碰撞动画效果。分享给大家供大家参考,具体如下:先来看看运行效果:具体代码如下:#!/
- 一、问题描述使用vscode,在markdown的预览模式下无法预览网络图片二、本机环境该问题与电脑硬件以及操作系统环境无关。本机markd
- mutations的调用方法直接通过$store.commit调用<button @click="$store.commit
- 思路:在腾讯疫情数据网站F12解析网站结构,使用Python爬取当日疫情数据和历史疫情数据,分别存储到details和history两个my
- 4款JavaScript放大镜特效脚本。准确的说,Anythingzoomer和Bezoom才是正宗的放大镜特效,当鼠标悬浮在图片上时,能放
- 如果在session级保存一个dictionary对象会降低系统的性能,而在application级保存一个dictionary对象会导致w
- 如何利用pandas读取csv数据并绘图导包,常用的numpy和pandas,绘图模块matplotlib,import matplotli
- 前言大家应该都有所体会,随着硬件层面的发展,linux系统多核已经是普通趋势,而mysql是单进程多线程,所以先天上对多进程的利用不是很高,