Tensorflow 多线程与多进程数据加载实例
作者:aszxs 发布时间:2023-12-30 23:53:47
标签:Tensorflow,多线程,多进程
在项目中遇到需要处理超级大量的数据集,无法载入内存的问题就不用说了,单线程分批读取和处理(虽然这个处理也只是特别简单的首尾相连的操作)也会使瓶颈出现在CPU性能上,所以研究了一下多线程和多进程的数据读取和预处理,都是通过调用dataset api实现
1. 多线程数据读取
第一种方法是可以直接从csv里读取数据,但返回值是tensor,需要在sess里run一下才能返回真实值,无法实现真正的并行处理,但如果直接用csv文件或其他什么文件存了特征值,可以直接读取后进行训练,可使用这种方法.
import tensorflow as tf
#这里是返回的数据类型,具体内容无所谓,类型对应就好了,比如我这个,就是一个四维的向量,前三维是字符串类型 最后一维是int类型
record_defaults = [[""], [""], [""], [0]]
def decode_csv(line):
parsed_line = tf.decode_csv(line, record_defaults)
label = parsed_line[-1] # label
del parsed_line[-1] # delete the last element from the list
features = tf.stack(parsed_line) # Stack features so that you can later vectorize forward prop., etc.
#label = tf.stack(label) #NOT needed. Only if more than 1 column makes the label...
batch_to_return = features, label
return batch_to_return
filenames = tf.placeholder(tf.string, shape=[None])
dataset5 = tf.data.Dataset.from_tensor_slices(filenames)
#在这里设置线程数目
dataset5 = dataset5.flat_map(lambda filename: tf.data.TextLineDataset(filename).skip(1).map(decode_csv,num_parallel_calls=15))
dataset5 = dataset5.shuffle(buffer_size=1000)
dataset5 = dataset5.batch(32) #batch_size
iterator5 = dataset5.make_initializable_iterator()
next_element5 = iterator5.get_next()
#这里是需要加载的文件名
training_filenames = ["train.csv"]
validation_filenames = ["vali.csv"]
with tf.Session() as sess:
for _ in range(2):
#通过文件名初始化迭代器
sess.run(iterator5.initializer, feed_dict={filenames: training_filenames})
while True:
try:
#这里获得真实值
features, labels = sess.run(next_element5)
# Train...
# print("(train) features: ")
# print(features)
# print("(train) labels: ")
# print(labels)
except tf.errors.OutOfRangeError:
print("Out of range error triggered (looped through training set 1 time)")
break
# Validate (cost, accuracy) on train set
print("\nDone with the first iterator\n")
sess.run(iterator5.initializer, feed_dict={filenames: validation_filenames})
while True:
try:
features, labels = sess.run(next_element5)
# Validate (cost, accuracy) on dev set
# print("(dev) features: ")
# print(features)
# print("(dev) labels: ")
# print(labels)
except tf.errors.OutOfRangeError:
print("Out of range error triggered (looped through dev set 1 time only)")
break
第二种方法,基于生成器,可以进行预处理操作了,sess里run出来的结果可以直接进行输入训练,但需要自己写一个生成器,我使用的测试代码如下:
import tensorflow as tf
import random
import threading
import numpy as np
from data import load_image,load_wave
class SequenceData():
def __init__(self, path, batch_size=32):
self.path = path
self.batch_size = batch_size
f = open(path)
self.datas = f.readlines()
self.L = len(self.datas)
self.index = random.sample(range(self.L), self.L)
def __len__(self):
return self.L - self.batch_size
def __getitem__(self, idx):
batch_indexs = self.index[idx:(idx+self.batch_size)]
batch_datas = [self.datas[k] for k in batch_indexs]
img1s,img2s,audios,labels = self.data_generation(batch_datas)
return img1s,img2s,audios,labels
def gen(self):
for i in range(100000):
t = self.__getitem__(i)
yield t
def data_generation(self, batch_datas):
#预处理操作,数据在参数里
return img1s,img2s,audios,labels
#这里的type要和实际返回的数据类型对应,如果在自己的处理代码里已经考虑的batchszie,那这里的batch设为1即可
dataset = tf.data.Dataset().batch(1).from_generator(SequenceData('train.csv').gen,
output_types= (tf.float32,tf.float32,tf.float32,tf.int64))
dataset = dataset.map(lambda x,y,z,w : (x,y,z,w), num_parallel_calls=32).prefetch(buffer_size=1000)
X, y,z,w = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
for _ in range(100000):
a,b,c,d = sess.run([X,y,z,w])
print(a.shape)
不过python的多线程并不是真正的多线程,虽然看起来我是启动了32线程,但运行时的CPU占用如下所示:
还剩这么多核心空着,然后就是第三个版本了,使用了queue来缓存数据,训练需要数据时直接从queue中进行读取,是一个到多进程的过度版本(vscode没法debug多进程,坑啊,还以为代码写错了,在vscode里多进程直接就没法运行),在初始化时启动多个线程进行数据的预处理:
import tensorflow as tf
import random
import threading
import numpy as np
from data import load_image,load_wave
from queue import Queue
class SequenceData():
def __init__(self, path, batch_size=32):
self.path = path
self.batch_size = batch_size
f = open(path)
self.datas = f.readlines()
self.L = len(self.datas)
self.index = random.sample(range(self.L), self.L)
self.queue = Queue(maxsize=20)
for i in range(32):
threading.Thread(target=self.f).start()
def __len__(self):
return self.L - self.batch_size
def __getitem__(self, idx):
batch_indexs = self.index[idx:(idx+self.batch_size)]
batch_datas = [self.datas[k] for k in batch_indexs]
img1s,img2s,audios,labels = self.data_generation(batch_datas)
return img1s,img2s,audios,labels
def f(self):
for i in range(int(self.__len__()/self.batch_size)):
t = self.__getitem__(i)
self.queue.put(t)
def gen(self):
while 1:
yield self.queue.get()
def data_generation(self, batch_datas):
#数据预处理操作
return img1s,img2s,audios,labels
#这里的type要和实际返回的数据类型对应,如果在自己的处理代码里已经考虑的batchszie,那这里的batch设为1即可
dataset = tf.data.Dataset().batch(1).from_generator(SequenceData('train.csv').gen,
output_types= (tf.float32,tf.float32,tf.float32,tf.int64))
dataset = dataset.map(lambda x,y,z,w : (x,y,z,w), num_parallel_calls=1).prefetch(buffer_size=1000)
X, y,z,w = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
for _ in range(100000):
a,b,c,d = sess.run([X,y,z,w])
print(a.shape)
2. 多进程数据读取
这里的代码和多线程的第三个版本非常类似,修改为启动进程和进程类里的Queue即可,但千万不要在vscode里直接debug!在vscode里直接f5运行进程并不能启动.
from __future__ import unicode_literals
from functools import reduce
import tensorflow as tf
import numpy as np
import warnings
import argparse
import skimage.io
import skimage.transform
import skimage
import scipy.io.wavfile
from multiprocessing import Process,Queue
class SequenceData():
def __init__(self, path, batch_size=32):
self.path = path
self.batch_size = batch_size
f = open(path)
self.datas = f.readlines()
self.L = len(self.datas)
self.index = random.sample(range(self.L), self.L)
self.queue = Queue(maxsize=30)
self.Process_num=32
for i in range(self.Process_num):
print(i,'start')
ii = int(self.__len__()/self.Process_num)
t = Process(target=self.f,args=(i*ii,(i+1)*ii))
t.start()
def __len__(self):
return self.L - self.batch_size
def __getitem__(self, idx):
batch_indexs = self.index[idx:(idx+self.batch_size)]
batch_datas = [self.datas[k] for k in batch_indexs]
img1s,img2s,audios,labels = self.data_generation(batch_datas)
return img1s,img2s,audios,labels
def f(self,i_l,i_h):
for i in range(i_l,i_h):
t = self.__getitem__(i)
self.queue.put(t)
def gen(self):
while 1:
t = self.queue.get()
yield t[0],t[1],t[2],t[3]
def data_generation(self, batch_datas):
#数据预处理操作
return img1s,img2s,audios,labels
epochs = 2
data_g = SequenceData('train_1.csv',batch_size=48)
dataset = tf.data.Dataset().batch(1).from_generator(data_g.gen,
output_types= (tf.float32,tf.float32,tf.float32,tf.float32))
X, y,z,w = dataset.make_one_shot_iterator().get_next()
with tf.Session() as sess:
tf.global_variables_initializer().run()
for i in range(epochs):
for j in range(int(len(data_g)/(data_g.batch_size))):
face1,face2,voice, labels = sess.run([X,y,z,w])
print(face1.shape)
然后,最后实现的效果
来源:https://blog.csdn.net/qq_22033759/article/details/88772073
0
投稿
猜你喜欢
- 原始需求:例如有一个列表:l = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]希望把它转换成下面这种形式:[1, 2,
- 第一步:字母转数字英文字母转对应数字相对简单,可以在命令行输入一行需要转换的英文字母,然后对每一个字母在整个字母表中匹配,并返回相应的位数,
- 配置如下TEMPLATES = [下面'context_processors': [中添加'django.core.
- 步骤:1. 掌握几种对象及其关系2. 了解每类对象的基本操作方法3. 通过转化关系转化涉及对象1. datetime>>>
- 我使用的是anaconda。我推荐大家使用anaconda,对环境依赖关系处理的比较好。不用浪费太多时间在安装模块上。首先安装pyinsta
- 前两天拉取公司前端代码修改,发现在开发者工具的sources选项里边,居然没有列出来我要调试的js脚本,后来观察了一下,脚本是动态在页面里引
- 这篇文章主要介绍了python matplotlib给图中的点加标签,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习
- CategoricalDtype自定义排序当我们的透视表生成完毕后,有很多情况下需要我们对某列或某行值进行排序。排序有很多种方法。例如sor
- 线程和进程1、线程共享创建它的进程的地址空间,进程有自己的地址空间2、线程可以访问进程所有的数据,线程可以相互访问3、线程之间的数据是独立的
- 本文实例为大家分享了用KNN算法手写体识别的具体代码,供大家参考,具体内容如下#!/usr/bin/python #coding:utf-8
- 如果你只使用一个更新日志,你只须清空日志文件,然后移走旧的更新日志文件到一个备份中,然后启用新的更新日志。用下列方法可以强制服务器启用新的更
- 本文实例讲述了python关于矩阵重复赋值覆盖问题的解决方法。分享给大家供大家参考,具体如下:import itertoolsimport
- 本文实例为大家分享了python+opencv识别图片中足球的方法,供大家参考,具体内容如下先补充下霍夫圆变换的几个参数知识:dp,用来检测
- 在使用easyUI做前端样式展示时,遇到了文件上传的问题,而且是在弹出层中提交表单,想做到不刷新页面,所以选择了使用ajaxFileUplo
- 简介memory_profiler是第三方模块,用于监视进程的内存消耗以及python程序内存消耗的逐行分析。它是一个纯python模块,依
- 源码:#路飞骷髅import turtle as t#黄底帽子t.pu()t.goto(0,200)t.circle(-130,-80)t.
- 工作中遇到的,在一个.c文件中有很多函数,这个.c是自动生成的,需要将所有的函数通过extern放到.h中,每个函数都是UINT32 O_开
- 00. 什么是 freecache?freecache 是一个用 go 语言实现的本地缓存系统(类似于 lru)。相关的 github 地址
- blankzheng的blog:http://www.planabc.net/1、使用fieldset和legend标签在form中,我们经
- 如果要写一个程序,让x1为1,x2为2,然后直到x100为100,你会怎么做?在C这种静态语言里,变量名这个标识符实际上会被编译器直接翻译成