keras使用Sequence类调用大规模数据集进行训练的实现
作者:aszxs 发布时间:2021-01-03 20:24:35
标签:keras,Sequence,数据集,训练
使用Keras如果要使用大规模数据集对网络进行训练,就没办法先加载进内存再从内存直接传到显存了,除了使用Sequence类以外,还可以使用迭代器去生成数据,但迭代器无法在fit_generation里开启多进程,会影响数据的读取和预处理效率,在本文中就不在叙述了,有需要的可以另外去百度。
下面是我所使用的代码
class SequenceData(Sequence):
def __init__(self, path, batch_size=32):
self.path = path
self.batch_size = batch_size
f = open(path)
self.datas = f.readlines()
self.L = len(self.datas)
self.index = random.sample(range(self.L), self.L)
#返回长度,通过len(<你的实例>)调用
def __len__(self):
return self.L - self.batch_size
#即通过索引获取a[0],a[1]这种
def __getitem__(self, idx):
batch_indexs = self.index[idx:(idx+self.batch_size)]
batch_datas = [self.datas[k] for k in batch_indexs]
img1s,img2s,audios,labels = self.data_generation(batch_datas)
return ({'face1_input_1': img1s, 'face2_input_2': img2s, 'input_3':audios},{'activation_7':labels})
def data_generation(self, batch_datas):
#预处理操作
return img1s,img2s,audios,labels
然后在代码里通过fit_generation函数调用并训练
这里要注意,use_multiprocessing参数是是否开启多进程,由于python的多线程不是真的多线程,所以多进程还是会获得比较客观的加速,但不支持windows,windows下python无法使用多进程。
D = SequenceData('train.csv')
model_train.fit_generator(generator=D,steps_per_epoch=int(len(D)),
epochs=2, workers=20, #callbacks=[checkpoint],
use_multiprocessing=True, validation_data=SequenceData('vali.csv'),validation_steps=int(20000/32))
同样的,也可以在测试的时候使用
model.evaluate_generator(generator=SequenceData('face_test.csv'),steps=int(125100/32),workers=32)
补充知识:keras数据自动生成器,继承keras.utils.Sequence,结合fit_generator实现节约内存训练
我就废话不多说了,大家还是直接看代码吧~
#coding=utf-8
'''
Created on 2018-7-10
'''
import keras
import math
import os
import cv2
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
class DataGenerator(keras.utils.Sequence):
def __init__(self, datas, batch_size=1, shuffle=True):
self.batch_size = batch_size
self.datas = datas
self.indexes = np.arange(len(self.datas))
self.shuffle = shuffle
def __len__(self):
#计算每一个epoch的迭代次数
return math.ceil(len(self.datas) / float(self.batch_size))
def __getitem__(self, index):
#生成每个batch数据,这里就根据自己对数据的读取方式进行发挥了
# 生成batch_size个索引
batch_indexs = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
# 根据索引获取datas集合中的数据
batch_datas = [self.datas[k] for k in batch_indexs]
# 生成数据
X, y = self.data_generation(batch_datas)
return X, y
def on_epoch_end(self):
#在每一次epoch结束是否需要进行一次随机,重新随机一下index
if self.shuffle == True:
np.random.shuffle(self.indexes)
def data_generation(self, batch_datas):
images = []
labels = []
# 生成数据
for i, data in enumerate(batch_datas):
#x_train数据
image = cv2.imread(data)
image = list(image)
images.append(image)
#y_train数据
right = data.rfind("\\",0)
left = data.rfind("\\",0,right)+1
class_name = data[left:right]
if class_name=="dog":
labels.append([0,1])
else:
labels.append([1,0])
#如果为多输出模型,Y的格式要变一下,外层list格式包裹numpy格式是list[numpy_out1,numpy_out2,numpy_out3]
return np.array(images), np.array(labels)
# 读取样本名称,然后根据样本名称去读取数据
class_num = 0
train_datas = []
for file in os.listdir("D:/xxx"):
file_path = os.path.join("D:/xxx", file)
if os.path.isdir(file_path):
class_num = class_num + 1
for sub_file in os.listdir(file_path):
train_datas.append(os.path.join(file_path, sub_file))
# 数据生成器
training_generator = DataGenerator(train_datas)
#构建网络
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=2, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(training_generator, epochs=50,max_queue_size=10,workers=1)
来源:https://blog.csdn.net/qq_22033759/article/details/88798423


猜你喜欢
- 背景我们使用MySQL存储了FriendFeed的所有数据。数据库随着用户基数的增长而增长了很多。现在已经存储了超过2.5亿条记
- 如下所示:coupon = models.ForeignKey("Coupon", on_delete=models.C
- 前言go语言并没有面向对象的相关概念,go语言提到的接口和java、c++等语言提到的接口不同,它不会显示的说明实现了接口,没有继承、子类、
- 代码如下:USE TestDB declare @conversation uniqueidentifier while exists (s
- 前言在前几年,如果你和嵌入式开发人员推荐Python,大概会是这样一种场景:A:”诶,老王,你看Python开发这么方便
- 一、前言数据库的数据量达到一定程度之后,为避免带来系统性能上的瓶颈。需要进行数据的处理,采用的手段是分区、分片、分库、分表。二、分片(类似分
- 训练的时候当然用gpu,速度快呀。我想用cpu版的tensorflow跑一下,结果报错,这个错误不太容易看懂。大概意思是没找到一些节点。后来
- python除了关键字(keywords)和内置的类型和函数(builtins),更多的功能是通过libraries(即modules)来提
- 废话不多说了,关键代码如下所示:<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 T
- 如果您还没看过段正淳的css笔记(1)分类之间的横竖线,可以先看看!1、css圆角的做法.为了这个圆角,前段开发们付出的努力是在是太多了.又
- 本文实例讲述了python自动zip压缩目录的方法。分享给大家供大家参考。具体实现方法如下:这段代码来压缩数据库备份文件,没有使用pytho
- 常见的一种应用场景:条件:假设A的shape为[4, 2],B的shape为[5, 2]目的:实现A中的每一行, 减去B中的所有行(broa
- networkx是Python的一个包,用于构建和操作复杂的图结构,提供分析图的算法。图是由顶点、边和可选的属性构成的数据结构,顶点表示数据
- Javascript刷新页面的几种方法: 1. history.go(0) 2. location.reload() 3. location
- 下面我摘录了SQL Server官方教程中的一段关于触发器的文字,确实有用的一点文字描述。 可以定义一个无论何时用INSERT语句向表中插入
- 这里inference两个程序的连接,如目标检测,可以利用一个程序提取候选框,然后把候选框输入到分类cnn网络中。这里常需要进行一定的连接。
- 前言这篇文章主要介绍了Go语言使用swagger生成接口文档的方法,希望能够对大家的学习或工作具有一定的帮助,需要的朋友可以参考下。在前后端
- 导入库和数据首先,我们需要导入PyTorch和PyG库,然后准备好我们的数据。例如,我们可以使用以下方式生成一个简单的随机数据集:from
- 本文列出了HTML4标签的默认样式列表,对网页设计者来说这个应该很有用。原文来自:W3C (http://www.w3.org/TR/CSS
- 在现实的图像操作软件中,经常碰到的不是给出放大多少倍,而是由用户在软件的界面上选择多大的区域,或者选择几个点,那么这样情况下,怎么样来计算出