Tensorflow使用tfrecord输入数据格式
作者:ruyiweicas 发布时间:2022-06-18 22:55:40
Tensorflow 提供了一种统一的格式来存储数据,这个格式就是TFRecord,上一篇文章中所提到的方法当数据的来源更复杂,每个样例中的信息更丰富的时候就很难有效的记录输入数据中的信息了,于是Tensorflow提供了TFRecord来统一存储数据,接下来我们就来介绍如何使用TFRecord来同意输入数据的格式。
1. TFRecord格式介绍
TFRecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Example的定义
message Example {
Features features = 1;
};
message Features{
map<string,Feature> featrue = 1;
};
message Feature{
oneof kind{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
从上述代码可以看到,ft.train.Example 的数据结构相对简洁。tf.train.Example中包含了一个从属性名称到取值的字典,其中属性名称为一个字符串,属性的取值可以为字符串(BytesList ),实数列表(FloatList )或整数列表(Int64List )。例如我们可以将解码前的图片作为字符串,图像对应的类别标号作为整数列表。
2. 将自己的数据转化为TFRecord格式
准备数据
在上一篇中,我们为了像伟大的MNIST致敬,所以选择图像的前缀来进行不同类别的分类依据,但是大多数的情况下,在进行分类任务的过程中,不同的类别都会放在不同的文件夹下,而且类别的个数往往浮动性又很大,所以针对这样的情况,我们现在利用不同类别在不同文件夹中的图像来生成TFRecord.
我们在Iris&Contact这个文件夹下有两个文件夹,分别为iris,contact。对于每个文件夹中存放的是对应的图片
转换数据
数据准备好以后,就开始准备生成TFRecord,具体代码如下:
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
cwd='/home/ruyiwei/Documents/Iris&Contact/'
classes={'iris','contact'}
writer= tf.python_io.TFRecordWriter("iris_contact.tfrecords")
for index,name in enumerate(classes):
class_path=cwd+name+'/'
for img_name in os.listdir(class_path):
img_path=class_path+img_name
img=Image.open(img_path)
img= img.resize((512,80))
img_raw=img.tobytes()
#plt.imshow(img) # if you want to check you image,please delete '#'
#plt.show()
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
3. Tensorflow从TFRecord中读取数据
def read_and_decode(filename): # read iris_contact.tfrecords
filename_queue = tf.train.string_input_producer([filename])# create a queue
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)#return file_name and file
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})#return image and label
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [512, 80, 3]) #reshape image to 512*80*3
img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #throw img tensor
label = tf.cast(features['label'], tf.int32) #throw label tensor
return img, label
4. 将TFRecord中的数据保存为图片
filename_queue = tf.train.string_input_producer(["iris_contact.tfrecords"])
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) #return file and file_name
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [512, 80, 3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess:
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord=tf.train.Coordinator()
threads= tf.train.start_queue_runners(coord=coord)
for i in range(20):
example, l = sess.run([image,label])#take out image and label
img=Image.fromarray(example, 'RGB')
img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#save image
print(example, l)
coord.request_stop()
coord.join(threads)
来源:https://blog.csdn.net/best_coder/article/details/70146441


猜你喜欢
- 数组排序排序是指将元素按有序顺序排列。有序序列是拥有与元素相对应的顺序的任何序列,例如数字或字母、升序或降序。NumPy ndarray 对
- 1.先检查系统是否装有mysqlrpm -qa | grep mysql2.下载mysql的repo源(5.7)wget -i -c htt
- 哪行哪业都少不了基本功,都说“马步”要扎得稳。在都快说烂了的以目标用户为中心设计的今天,还是要勤练基本功的。不多说了,先了解下“设计的3个C
- 图片外框特征参数: ①dashed:虚线②dotted:点虚线③solid:实线④double:双线⑤groove:沟
- 存储函数也是过程式对象之一,与存储过程相似。他们都是由SQL和过程式语句组成的代码片段,并且可以从应用程序和SQL中调用。然而,他们也有一些
- 前言大家都知道,一条查询语句走了索引和没走索引的查询效率是非常大的,在我们建好了表,建好了索引后,但是一些不好的sql会导致我们的索引失效,
- 一、oracle oracle服务器有Oracle instace 和Oracle database instance有memory str
- explain显示了mysql如何使用索引来处理select语句以及连接表.可以帮助选择更好的索引和写出更优化的查询语句.使用方法:在sel
- 问题:因为有的友情连接的网站关闭或者网络连接较慢导致连接的LOGO图片显示不出来或者显示很慢.在IE下面老是提示剩下几项没打开,看起来很不舒
- 一、前言听说python很流行,因为有很多模块资源,而且导入模块,操作和理解起来很简单。所以在这里记录一下学习python的过程,我相信最重
- 面部识别----考勤打卡、注册登录、面部支付等等...感觉很高大上,又很方便,下面用python中的框架--django完成一个注册登录的功
- 本文实例讲述了Python基于Socket实现的简单聊天程序。分享给大家供大家参考,具体如下:需求:SCIENCE 和MOOD两个人软件专业
- 坑:在python3.7环境下,通过官方文档安装sanic即扩展插件,但是 sanic-ext包不起作用,具体的表现为:无法打开路由/doc
- 记录遇到的问题;在aliyun上安装MySQL时由于上次错误卸载mysql 导致校验文件出问题;处理方式有几种1到mysql官网下载校验文件
- 前言在python基础知识中有说过,字典是可变的数据类型,其参数又是键对值。setdefault()方法和字典的get()方法在一些地方比较
- for循环是一个循环控制结构,可以有效地编写需要执行的特定次数的循环。语法for循环在Go编程语言中的语法是:for [condition
- 版本更新,原来user里的password字段已经变更为authentication_string版本更新 缘故,好多网上的教程都不适用了,
- 运行效果:完整源码:##import libraryfrom tkinter import *import timefrom playsou
- Git/GitHub/GitHub Desktop相关概念1、GitGit是一款免费的、开源的、最先进的分布式版本控制系统,可以有效、高速地
- 网易最近出的一款自动化UI测试工具:Airtest 挺火的,还受到谷歌的推荐。我试着用了一下,感觉优缺点还是蛮明显的。对初学者来说,能用到的