Tensorflow 多线程与多进程数据加载实例
作者:aszxs 发布时间:2023-12-30 23:53:47
标签:Tensorflow,多线程,多进程
在项目中遇到需要处理超级大量的数据集,无法载入内存的问题就不用说了,单线程分批读取和处理(虽然这个处理也只是特别简单的首尾相连的操作)也会使瓶颈出现在CPU性能上,所以研究了一下多线程和多进程的数据读取和预处理,都是通过调用dataset api实现
1. 多线程数据读取
第一种方法是可以直接从csv里读取数据,但返回值是tensor,需要在sess里run一下才能返回真实值,无法实现真正的并行处理,但如果直接用csv文件或其他什么文件存了特征值,可以直接读取后进行训练,可使用这种方法.
import tensorflow as tf
#这里是返回的数据类型,具体内容无所谓,类型对应就好了,比如我这个,就是一个四维的向量,前三维是字符串类型 最后一维是int类型
record_defaults = [[""], [""], [""], [0]]
def decode_csv(line):
parsed_line = tf.decode_csv(line, record_defaults)
label = parsed_line[-1] # label
del parsed_line[-1] # delete the last element from the list
features = tf.stack(parsed_line) # Stack features so that you can later vectorize forward prop., etc.
#label = tf.stack(label) #NOT needed. Only if more than 1 column makes the label...
batch_to_return = features, label
return batch_to_return
filenames = tf.placeholder(tf.string, shape=[None])
dataset5 = tf.data.Dataset.from_tensor_slices(filenames)
#在这里设置线程数目
dataset5 = dataset5.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1).map(decode_csv,num_parallel_calls=15))
dataset5 = dataset5.shuffle(buffer_size=1000)
dataset5 = dataset5.batch(32) #batch_size
iterator5 = dataset5.make_initializable_iterator()
next_element5 = iterator5.get_next()
#这里是需要加载的文件名
training_filenames = ["train.csv"]
validation_filenames = ["vali.csv"]
with tf.Session() as sess:
for _ in range(2):
#通过文件名初始化迭代器
sess.run(iterator5.initializer, feed_dict={filenames: training_filenames})
while True:
try:
#这里获得真实值
features, labels = sess.run(next_element5)
# Train...
# print("(train) features: ")
# print(features)
# print("(train) labels: ")
# print(labels)
except tf.errors.OutOfRangeError:
print("Out of range error triggered (looped through training set 1 time)")
break
# Validate (cost, accuracy) on train set
print("\nDone with the first iterator\n")
sess.run(iterator5.initializer, feed_dict={filenames: validation_filenames})
while True:
try:
features, labels = sess.run(next_element5)
# Validate (cost, accuracy) on dev set
# print("(dev) features: ")
# print(features)
# print("(dev) labels: ")
# print(labels)
except tf.errors.OutOfRangeError:
print("Out of range error triggered (looped through dev set 1 time only)")
break
第二种方法,基于生成器,可以进行预处理操作了,sess里run出来的结果可以直接进行输入训练,但需要自己写一个生成器,我使用的测试代码如下:
import tensorflow as tf
import random
import threading
import numpy as np
from data import load_image,load_wave
class SequenceData():
def __init__(self, path, batch_size=32):
self.path = path
self.batch_size = batch_size
f = open(path)
self.datas = f.readlines()
self.L = len(self.datas)
self.index = random.sample(range(self.L), self.L)
def __len__(self):
return self.L - self.batch_size
def __getitem__(self, idx):
batch_indexs = self.index[idx:(idx+self.batch_size)]
batch_datas = [self.datas[k] for k in batch_indexs]
img1s,img2s,audios,labels = self.data_generation(batch_datas)
return img1s,img2s,audios,labels
def gen(self):
for i in range(100000):
t = self.__getitem__(i)
yield t
def data_generation(self, batch_datas):
#预处理操作,数据在参数里
return img1s,img2s,audios,labels
#这里的type要和实际返回的数据类型对应,如果在自己的处理代码里已经考虑的batchszie,那这里的batch设为1即可
dataset = tf.data.Dataset().batch(1).from_generator(SequenceData('train.csv').gen,
output_types= (tf.float32,tf.float32,tf.float32,tf.int64))
dataset = dataset.map(lambda x,y,z,w : (x,y,z,w), num_parallel_calls=32).prefetch(buffer_size=1000)
X, y,z,w = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
for _ in range(100000):
a,b,c,d = sess.run([X,y,z,w])
print(a.shape)
不过python的多线程并不是真正的多线程,虽然看起来我是启动了32线程,但运行时的CPU占用如下所示:
还剩这么多核心空着,然后就是第三个版本了,使用了queue来缓存数据,训练需要数据时直接从queue中进行读取,是一个到多进程的过度版本(vscode没法debug多进程,坑啊,还以为代码写错了,在vscode里多进程直接就没法运行),在初始化时启动多个线程进行数据的预处理:
import tensorflow as tf
import random
import threading
import numpy as np
from data import load_image,load_wave
from queue import Queue
class SequenceData():
def __init__(self, path, batch_size=32):
self.path = path
self.batch_size = batch_size
f = open(path)
self.datas = f.readlines()
self.L = len(self.datas)
self.index = random.sample(range(self.L), self.L)
self.queue = Queue(maxsize=20)
for i in range(32):
threading.Thread(target=self.f).start()
def __len__(self):
return self.L - self.batch_size
def __getitem__(self, idx):
batch_indexs = self.index[idx:(idx+self.batch_size)]
batch_datas = [self.datas[k] for k in batch_indexs]
img1s,img2s,audios,labels = self.data_generation(batch_datas)
return img1s,img2s,audios,labels
def f(self):
for i in range(int(self.__len__()/self.batch_size)):
t = self.__getitem__(i)
self.queue.put(t)
def gen(self):
while 1:
yield self.queue.get()
def data_generation(self, batch_datas):
#数据预处理操作
return img1s,img2s,audios,labels
#这里的type要和实际返回的数据类型对应,如果在自己的处理代码里已经考虑的batchszie,那这里的batch设为1即可
dataset = tf.data.Dataset().batch(1).from_generator(SequenceData('train.csv').gen,
output_types= (tf.float32,tf.float32,tf.float32,tf.int64))
dataset = dataset.map(lambda x,y,z,w : (x,y,z,w), num_parallel_calls=1).prefetch(buffer_size=1000)
X, y,z,w = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
for _ in range(100000):
a,b,c,d = sess.run([X,y,z,w])
print(a.shape)
2. 多进程数据读取
这里的代码和多线程的第三个版本非常类似,修改为启动进程和进程类里的Queue即可,但千万不要在vscode里直接debug!在vscode里直接f5运行进程并不能启动.
from __future__ import unicode_literals
from functools import reduce
import tensorflow as tf
import numpy as np
import warnings
import argparse
import skimage.io
import skimage.transform
import skimage
import scipy.io.wavfile
from multiprocessing import Process,Queue
class SequenceData():
def __init__(self, path, batch_size=32):
self.path = path
self.batch_size = batch_size
f = open(path)
self.datas = f.readlines()
self.L = len(self.datas)
self.index = random.sample(range(self.L), self.L)
self.queue = Queue(maxsize=30)
self.Process_num=32
for i in range(self.Process_num):
print(i,'start')
ii = int(self.__len__()/self.Process_num)
t = Process(target=self.f,args=(i*ii,(i+1)*ii))
t.start()
def __len__(self):
return self.L - self.batch_size
def __getitem__(self, idx):
batch_indexs = self.index[idx:(idx+self.batch_size)]
batch_datas = [self.datas[k] for k in batch_indexs]
img1s,img2s,audios,labels = self.data_generation(batch_datas)
return img1s,img2s,audios,labels
def f(self,i_l,i_h):
for i in range(i_l,i_h):
t = self.__getitem__(i)
self.queue.put(t)
def gen(self):
while 1:
t = self.queue.get()
yield t[0],t[1],t[2],t[3]
def data_generation(self, batch_datas):
#数据预处理操作
return img1s,img2s,audios,labels
epochs = 2
data_g = SequenceData('train_1.csv',batch_size=48)
dataset = tf.data.Dataset().batch(1).from_generator(data_g.gen,
output_types= (tf.float32,tf.float32,tf.float32,tf.float32))
X, y,z,w = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
tf.global_variables_initializer().run()
for i in range(epochs):
for j in range(int(len(data_g)/(data_g.batch_size))):
face1,face2,voice, labels = sess.run([X,y,z,w])
print(face1.shape)
然后,最后实现的效果
来源:https://blog.csdn.net/qq_22033759/article/details/88772073


猜你喜欢
- 据国外媒体报道,相较于IE8浏览器,微软最新一代浏览器IE9的最大改进就是硬件加速HTML5。微软承诺,通过利用IE9中的硬件加速功能,开发
- 一般来说一个系统最先出现瓶颈的点很可能是数据库。比如我们的生产系统并发量很高在跑一段时间后,数据库中某些表的数据量会越来越大。海量的数据会严
- 前端技术层(图片有点偏激,仅供参考)Javascript和DOM关系很暧昧,弄不明白!CSS和HTML
- 如何下载:我先去MySQL首页下载最新版本的MySQL-链接:https://www.mysql.com/downloads/进入此界面下载
- 【实用系列】-- 胖页面载入,加载JavaScript效果整理了一下代码,做了一些优化,算是最终版了。完全不需要对其他文件做任何修改,就是所
- 首先看官网的DataFrame.plot( )函数DataFrame.plot(x=None, y=None, kind='line
- 分别针对ie和火狐分别作了对xml文档和xml字符串的解析,所有代码都注释掉了,想看哪部分功能,去掉注释就可以了。至于在ajax环境下解析x
- 引言很多人都是从php转过来的吧,不知道你们有没有发现,go界的orm并没有像php的orm一样好用。这篇文章里,我们认真的讨论下这个问题,
- 场景:把一个时间字符串转成Date,存进Mysql。时间天数会比实际时间少1天,也可能是小时少了13-14小时Mysql的时区是CST(使用
- 主要原理:调整dicom的窗宽,使之各个像素点上的灰度值缩放至[0,255]范围内。使用到的python库:SimpleITK下面是一个将d
- 触发器是一种特殊的存储过程,触发器主要是通过事件进行触发而被自动调用执行,而存储过程必须通过存储过程的名称被调用。一、触发器的定义触发器是在
- 1、引言小 * 丝:鱼哥,你说咱们发快递时填写的地址信息,到后台怎么能看清楚写的对不对呢?小鱼:这种事情还要问? 你没在电商行业混过??小 * 丝:
- 目录1、准备基础数据2、一次性展示数据3、引入分页器附:drf分页器的使用1.1 PageNumberPagination1.2 Limit
- 本文较为详细的讲述了Python实现远程调用MetaSploit的方法,对Python的学习来说有很好的参考价值。具体实现方法如下:(1)安
- 1、残差连接是目前常用的组件,解决了大规模深度学习模型梯度消失和瓶颈问题。通常,在10层以上的模型中追加残差连接可能有帮助。from ker
- 前言go mod tidy的作用是把项目所需要的依赖添加到go.mod,并删除go.mod中,没有被项目使用的依赖。Tidy makes s
- XML 是严格又自由的标记语言。我们都习惯于它的自由特性,自己想怎么定义都行,设计上非常自由,从不会因为它的标记特性约束到设计灵感的发挥。对
- Oracle中大文本数据类型Clob 长文本类型 (MySQL中不支持,使用的是text)Blob 二进
- #最近在网上看代码时,出现了@???的代码,看了好久也不知道是什么意思,经过了解原来是装饰器,我给大家举个例子讲解一下,帮助大家快速理解:#
- 本文实例为大家分享了python实现录音功能的具体代码,供大家参考,具体内容如下# -*- coding: utf-8 -*-import