Tensorflow 实现分批量读取数据
作者:freedom098 发布时间:2023-09-23 23:04:44
之前的博客里使用tf读取数据都是每次fetch一条记录,实际上大部分时候需要fetch到一个batch的小批量数据,在tf中这一操作的明显变化就是tensor的rank发生了变化,我目前使用的人脸数据集是灰度图像,因此大小是92*112的,所以最开始fetch拿到的图像数据集经过reshape之后就是一个rank为2的tensor,大小是92*112的(如果考虑通道,也可以reshape为rank为3的,即92*112*1)。
如果加入batch,比如batch大小为5,那么拿到的tensor的rank就变成了3,大小为5*92*112。
下面规则化的写一下读取数据的一般流程,按照官网的实例,一般把读取数据拆分成两个大部分,一个是函数专门负责读取数据和解码数据,一个函数则负责生产batch。
import tensorflow as tf
def read_data(fileNameQue):
reader = tf.TFRecordReader()
key, value = reader.read(fileNameQue)
features = tf.parse_single_example(value, features={'label': tf.FixedLenFeature([], tf.int64),
'img': tf.FixedLenFeature([], tf.string),})
img = tf.decode_raw(features["img"], tf.uint8)
img = tf.reshape(img, [92,112]) # 恢复图像原始大小
label = tf.cast(features["label"], tf.int32)
return img, label
def batch_input(filename, batchSize):
fileNameQue = tf.train.string_input_producer([filename], shuffle=True)
img, label = read_data(fileNameQue) # fetch图像和label
min_after_dequeue = 1000
capacity = min_after_dequeue+3*batchSize
# 预取图像和label并随机打乱,组成batch,此时tensor rank发生了变化,多了一个batch大小的维度
exampleBatch,labelBatch = tf.train.shuffle_batch([img, label],batch_size=batchSize, capacity=capacity,
min_after_dequeue=min_after_dequeue)
return exampleBatch,labelBatch
if __name__ == "__main__":
init = tf.initialize_all_variables()
exampleBatch, labelBatch = batch_input("./data/faceTF.tfrecords", batchSize=10)
with tf.Session() as sess:
sess.run(init)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(100):
example, label = sess.run([exampleBatch, labelBatch])
print(example.shape)
coord.request_stop()
coord.join(threads)
读取数据和解码数据与之前基本相同,针对不同格式数据集使用不同阅读器和解码器即可,后面是产生batch,核心是tf.train.shuffle_batch这个函数,它相当于一个蓄水池的功能,第一个参数代表蓄水池的入水口,也就是逐个读取到的记录,batch_size自然就是batch的大小了,capacity是蓄水池的容量,表示能容纳多少个样本,min_after_dequeue是指出队操作后还可以供随机采样出批量数据的样本池大小,显然,capacity要大于min_after_dequeue,官网推荐:min_after_dequeue + (num_threads + a small safety margin) * batch_size,还有一个参数就是num_threads,表示所用线程数目。
min_after_dequeue这个值越大,随机采样的效果越好,但是消耗的内存也越大。
来源:https://blog.csdn.net/freedom098/article/details/56013625
猜你喜欢
- 这篇论坛文章详细的讲解了使用SQL Server 2008管理非结构化数据的具体方法,更多内容请参考下文:microsoft SQL Ser
- 文末附完整源代码实现过程...想实现这样一个功能,然后pyqt5中又没有现成的组件可以使用,于是就想着只能通过绘图的方式来实现。说到绘图的话
- 大家一定使用过 phpmyadmin 里面的数据库导入,导出功能,非常方便。但是在实际应用中,我发现如下几个问题: 1、数据库超过一定尺寸,
- Python操作Excel之openpyxlopenpyxl是一个Python库,用来读写Excel2010 xlsx/xlsm/xltx/
- 支付宝或者微信支付导出的收款二维码,除了二维码部分,还有很大一块背景图案,例如下面就是微信支付的收款二维码:有时候我们仅仅只想要图片中间的方
- 1.认识数组数组就是某类数据的集合,数据类型可以是整型、字符串、甚至是对象Javascript不支持多维数组,但是因为数组里面可以包含对象(
- 分页是每一个程序需要去理解的东西,学习过的几门语言中我发现分页原理都是一样的,下面为php初学者分析一下php分页实现与最后面补充了一个超级
- 我就废话不多说了,大家还是直接看代码吧!#coding=utf-8import threadingimport timeimport cx_
- 为什么要使用滤波消除图像中的噪声成分叫作图像的平滑化或滤波操作。信号或图像的能量大部分集中在幅度谱的低频和中频段是很常见的,而在较高频段,感
- 对于部署在百度应用引擎BAE上的项目,使用百度云存储BCS(Baidu Cloud Storage)是不错的存储方案。百度云存储已有Pyth
- 前言我们前面对matplotlib模块底层结构学习,对其pyplot类(脚本层)类提供的绘制折线图、柱状图、饼图、直方图等统计图表的相关方法
- php获取图片的exif信息,php自带一个exif_read_data函数可以用来读取图片的exif信息,代码来自php手册<?ph
- 目前计算机中用得最广泛的字符集及其编码,是由美国国家标准局(ANSI)制定的ASCII码(American Stand ard C
- BeautifulSoup简介Beautiful Soup是python的一个库,最主要的功能是从网页抓取数据。官方解释如下:Beautif
- 看了下网上有很多关于模拟登录淘宝,但是基本都是使用scrapy、pyppeteer、selenium等库来模拟登录,但是目前我们还没有讲到这
- oracle命令删除用户:connect / as sysdba; shutdown abort; startup;&n
- js监听浏览器回车事件,可以支持ie6+,火狐,谷歌等浏览器。<html><head><script type
- 原文地址:30 Days of Mootools 1.2 Tutorials - Day 15 - SlidersMooTools 1.2的
- 背景:准备给长辈买个手机,有关手机大小,网购平台基本只有手机尺寸和分辨率的文本数据,因而对手机屏幕大小没有直观感受,虽然网上有比较手机大小的
- Profile 和 cProfile在 Python 标准库里面有两个模块可以用来做性能测试。1. 一个是 Profile,它是一个纯 Py