Tensorflow 多线程设置方式
作者:jingxian 发布时间:2021-09-29 21:53:50
一. 通过 ConfigProto 设置多线程
(具体参数功能及描述见 tensorflow/core/protobuf/config.proto)
在进行 tf.ConfigProto() 初始化时,可以通过设置相应的参数,来控制每个操作符 op 并行计算的线程个数或 session 线程池的线程数。主要涉及的参数有以下三个:
1. intra_op_parallelism_threads 控制运算符op内部的并行
当运算符 op 为单一运算符,并且内部可以实现并行时,如矩阵乘法,reduce_sum 之类的操作,可以通过设置 intra_op_parallelism_threads 参数来并行。
2. inter_op_parallelism_threads 控制多个运算符op之间的并行计算
当有多个运算符 op,并且他们之间比较独立,运算符和运算符之间没有直接的路径 Path 相连。Tensorflow会尝试并行地计算他们,使用由 inter_op_parallelism_threads 参数来控制数量的一个线程池。
在第一次创建会话将设置将来所有会话的线程数,除非是配置了 session_inter_op_thread_pool 选项。
3. session_inter_op_thread_pool 配置会话线程池。
如果会话线程池的 num_threads 为 0,使用 inter_op_parallelism_threads 选项。
二. 通过队列进行数据读取时设置多线程
(具体函数功能及描述见 tensorflow/python/training/input.py)
1. 通过以下函数进行样本批处理时,可以通过设置 num_threads 来设置单个 Reader 多线程读取
1) batch(tensors, batch_size, num_threads=1, capacity=32,
enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None, name=None)
2) maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32,
enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None, name=None)
3) shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
num_threads=1, seed=None, enqueue_many=False, shapes=None,
allow_smaller_final_batch=False, shared_name=None, name=None)
4) maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
keep_input, num_threads=1, seed=None,
enqueue_many=False, shapes=None,
allow_smaller_final_batch=False, shared_name=None,
name=None)
例:
import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
# 生成一个先入先出队列和一个 QueueRunner,生成文件名队列
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
# 定义 Reader 和 Decoder
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])
# 使用tf.train.batch() 会为 graph 添加一个样本队列和一个 QueueRunner。
# 经过 Reader 读取文件和 Decoder 解码后数据会进入这个队列,再批量出队。
# tf.train.batch() 这里只有一个 Reader,可以设置多线程
example_batch, label_batch = tf.train.batch([example, label], batch_size=5)
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(10):
e_val,l_val = sess.run([example_batch,label_batch])
print e_val,l_val
coord.request_stop()
coord.join(threads)
2. 通过以下函数进行样本批处理时,可以通过设置 Decoder 和 Reader 的个数来设置多 Reader 读取,其中每个 Reader 使用一个线程
1) batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,
shapes=None, dynamic_pad=False, allow_smaller_final_batch=False,
shared_name=None, name=None):
2) maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32,
enqueue_many=False, shapes=None, dynamic_pad=False,
allow_smaller_final_batch=False, shared_name=None,
name=None)
3) shuffle_batch_join(tensors_list, batch_size, capacity,
min_after_dequeue, seed=None, enqueue_many=False,
shapes=None, allow_smaller_final_batch=False,
shared_name=None, name=None)
4) maybe_shuffle_batch_join(tensors_list, batch_size, capacity,
min_after_dequeue, keep_input, seed=None,
enqueue_many=False, shapes=None,
allow_smaller_final_batch=False, shared_name=None,
name=None)
例:
import tensorflow as tf
filenames = ['A.csv', 'B.csv', 'C.csv']
# 生成一个先入先出队列和一个 QueueRunner,生成文件名队列
filename_queue = tf.train.string_input_producer(filenames, shuffle=False)
# 定义 Reader
reader = tf.TextLineReader()
key, value = reader.read(filename_queue)
#定义了多个 Decoder, 每个 Decoder 跟一个 Reader 相连, 即有多个 Reader
example_list = [tf.decode_csv(value, record_defaults=[['null'], ['null']])
for _ in range(2)] # Decoder 和 Reader 为 2
# 使用tf.train.batch_join() 会为 graph 添加一个样本队列和一个 QueueRunner。
# 经过多个 Reader 读取文件和 Decoder 解码后数据会进入这个队列,再批量出队。
# 使用 tf.train.batch_join(), 可以使用多个 Reader 并行读取数据。每个 Reader 使用一个线程
example_batch, label_batch = tf.train.batch_join(example_list, batch_size=5)
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for i in range(10):
e_val,l_val = sess.run([example_batch,label_batch])
print e_val,l_val
coord.request_stop()
coord.join(threads)
来源:https://blog.csdn.net/s_sunnyy/article/details/72924317


猜你喜欢
- 1、创建存储过程 create or replace procedure test(var_name_1 in type,var_name_
- 数字滤波分为 IIR 滤波,和FIR 滤波。FIR 滤波:import scipy.signal as signalimport numpy
- 本文实例讲述了Python多线程下载文件的方法。分享给大家供大家参考。具体实现方法如下:import httplibimport urlli
- <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN&
- pil版:from PIL import Imagefilename = r'E:\data\yangben\0.jpg'i
- <!DOCTYPE html PUBLIC "-//W3C//DTD X
- Golang浮点数比较和运算会出现误差。浮点数储存至内存中时,2的-1、-2……-n次方不能精确的表示小数部分,所以再把这个数从地址中取出来
- 1、登录SMTP服务器首先使用网上的方法(这里使用163邮箱,smtp.163.com是smtp服务器地址,25为端口号):import s
- 前言开发过程中有时需要使用路径数据,虽然python有自己的svg或其他矢量库,但这里只是出于实验的目的,没必要深入研究,所以采用一些简单的
- 前言最近在用python写一个项目,发现一个很恶心的bug,就是同由一个类生成的两个实例之间的数据竟然会相互影响,这让我非常不解。后来联想到
- 这些导航菜单来自于Dribbble网站,出自于世界各地的优秀设计师之手,涵盖了各种不同的风格,个个都非常精美。这里我将这些导航菜单展示出来,
- GetObject 函数返回对文件中 Automation 对象的引用。GetObject([pathname] [, class])参数P
- python 的虚拟环境可以为一个 python 项目提供独立的解释环境、依赖包等资源,既能够很好的隔离不同项目使用不同 python 版本
- 一、Base64编码原理步骤1:将所有字符转化为ASCII码;步骤2:将ASCII码转化为8位二进制;步骤3:将二进制3个归成一组(不足3个
- 修改my.ini或my.conf,将sql-mode="STRICT_TRANS_TABLES,NO_AUTO_CREATE_US
- 前言所谓模糊查询就是不需要用户完整的输入或者说全部输入信息即可提供查询服务,也就是用户可以在边输入的同时边看到提示的信息(其实是查询出来匹配
- 1、argparse是一个python模块,用途是:命令行选项、参数和子命令的解释。2、使用步骤:导入argparse模块,并创建解释器添加
- 一、现象最近在数据库中删除了一张表,重新执行python manage.py migrate时出错,提示不存在这张表。通过查找相关的资料,最
- 前言在此之前,我认为 Python 的类型提示就是一个花瓶,看起来好看,但并没有实质的作用,因为即使类型写错了,或者传错了,程序仍然可以运行
- 最近也是学习了一些爬虫方面的知识。以我自己的理解,通常我们用浏览器查看网页时,是通过浏览器向服务器发送请求,然后服务器响应以后返回一些代码数