tensorflow使用range_input_producer多线程读取数据实例
作者:lyg5623 发布时间:2022-10-19 16:43:21
标签:tensorflow,多线程,读取,数据
先放关键代码:
i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
原理解析:
第一行会产生一个队列,队列包含0到NUM_EXPOCHES-1的元素,如果num_epochs有指定,则每个元素只产生num_epochs次,否则循环产生。shuffle指定是否打乱顺序,这里shuffle=False表示队列的元素是按0到NUM_EXPOCHES-1的顺序存储。在Graph运行的时候,每个线程从队列取出元素,假设值为i,然后按照第二行代码切出array的一小段数据作为一个batch。例如NUM_EXPOCHES=3,如果num_epochs=2,则队列的内容是这样子;
0,1,2,0,1,2
队列只有6个元素,这样在训练的时候只能产生6个batch,迭代6次以后训练就结束。
如果num_epochs不指定,则队列内容是这样子:
0,1,2,0,1,2,0,1,2,0,1,2...
队列可以一直生成元素,训练的时候可以产生无限的batch,需要自己控制什么时候停止训练。
下面是完整的演示代码。
数据文件test.txt内容:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
main.py内容:
import tensorflow as tf
import codecs
BATCH_SIZE = 6
NUM_EXPOCHES = 5
def input_producer():
array = codecs.open("test.txt").readlines()
array = map(lambda line: line.strip(), array)
i = tf.train.range_input_producer(NUM_EXPOCHES, num_epochs=1, shuffle=False).dequeue()
inputs = tf.slice(array, [i * BATCH_SIZE], [BATCH_SIZE])
return inputs
class Inputs(object):
def __init__(self):
self.inputs = input_producer()
def main(*args, **kwargs):
inputs = Inputs()
init = tf.group(tf.initialize_all_variables(),
tf.initialize_local_variables())
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
sess.run(init)
try:
index = 0
while not coord.should_stop() and index<10:
datalines = sess.run(inputs.inputs)
index += 1
print("step: %d, batch data: %s" % (index, str(datalines)))
except tf.errors.OutOfRangeError:
print("Done traing:-------Epoch limit reached")
except KeyboardInterrupt:
print("keyboard interrput detected, stop training")
finally:
coord.request_stop()
coord.join(threads)
sess.close()
del sess
if __name__ == "__main__":
main()
输出:
step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
Done traing:-------Epoch limit reached
如果range_input_producer去掉参数num_epochs=1,则输出:
step: 1, batch data: ['1' '2' '3' '4' '5' '6']
step: 2, batch data: ['7' '8' '9' '10' '11' '12']
step: 3, batch data: ['13' '14' '15' '16' '17' '18']
step: 4, batch data: ['19' '20' '21' '22' '23' '24']
step: 5, batch data: ['25' '26' '27' '28' '29' '30']
step: 6, batch data: ['1' '2' '3' '4' '5' '6']
step: 7, batch data: ['7' '8' '9' '10' '11' '12']
step: 8, batch data: ['13' '14' '15' '16' '17' '18']
step: 9, batch data: ['19' '20' '21' '22' '23' '24']
step: 10, batch data: ['25' '26' '27' '28' '29' '30']
有一点需要注意,文件总共有35条数据,BATCH_SIZE = 6表示每个batch包含6条数据,NUM_EXPOCHES = 5表示产生5个batch,如果NUM_EXPOCHES =6,则总共需要36条数据,就会报如下错误:
InvalidArgumentError (see above for traceback): Expected size[0] in [0, 5], but got 6
[[Node: Slice = Slice[Index=DT_INT32, T=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"](Slice/input, Slice/begin/_5, Slice/size)]]
错误信息的意思是35/BATCH_SIZE=5,即NUM_EXPOCHES 的取值能只能在0到5之间。
来源:https://blog.csdn.net/lyg5623/article/details/69387917


猜你喜欢
- MySQL的常见操作在这里先做一下总结,已经整合到代码里面,经过检验无误。/*创建一个数据库*/create database xuning
- 本文实例讲述了JS实现获取数组中最大值或最小值功能。分享给大家供大家参考,具体如下:方法一://最小值Array.prototype.min
- 这篇文章主要介绍了Python异常继承关系和自定义异常实现代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价
- 案例展示电影详情,传递电影的id.从search.vue传递到movie.vuemethods: {showMovie(e){var tra
- 前言值类型:所有像int、float、bool和string这些类型都属于值类型,使用这些类型的变量直接指向存在内存中的值,值类型的变量的值
- 本篇文章来介绍一道非常常见的面试题,到底有多常见呢?可能很多面试的开场白就是由此开始的。那就是 new 和 make 这两个内置函数的区别。
- Pytorch统计参数网络参数数量def get_parameter_number(net): total_num
- 本文实例为大家分享了PyQt5实现简单数据标注工具的具体代码,分类用,供大家参考,具体内容如下第一个最大的图片是当前要标注的类别,接下来的两
- 前言今天跟大家介绍一个开源项目:id-maker,主要功能是用来在分布式环境下生成唯一 ID。上周停更了一周,也是用来开发和测试这个项目的相
- 需求需要生成一个宣传的图片分享到朋友圈,这个宣传图片包含二维码,包含不同的背景图片和不同的文字。对于这种图片生成,我们考虑过使用服务端生成,
- 首先来看,ASP读取ACCESS数据库。代码如下:<% @language="VBScript"&nbs
- 如下所示:#获取模型权重for k, v in model_2.state_dict().iteritems(): print("
- 1.主要功能如下:1.classification分类2.Regression回归3.Clustering聚类4.Dimensionalit
- 以下插件是我在项目中经常使用的jQuery插件,不见得是最好的,但是我目前接触到的jQuery插件中最适合我的。01. jQuery.Fle
- 纵观各大编程语言在 2017 年的发展情况,我们会发现涌现出诸如 Go、Swift 这类后起之秀,而其中最为耀眼的当属 Python。之所以
- Socket服务器是网络服务中常用的服务器。使用go语言实现这个业务场景是很容易的。这样的网络通讯,需要一个服务端和至少一个客户端。我们计划
- 官方实现golang 1.8 及以上版本提供了一个创建共享库(shared object)的新工具,称为 Plugins。目前 Plugin
- 前言 大家周末好,今天给大家带来的是Python当中生成器和迭代器的使用。我当初第一次学到迭代器和生成器的时候,并没有太在意,只是觉得这是一
- 虚继承 的概念的提出主要是为了解决C++多继承的问题,举个最简单的例子:class animal{ &nb
- pandas将表中的字符串转成数值型在用pd.read_csv读数据时,将要转换数据类型的列名和类型名构成字典,传给dtypeimport