浅谈keras2 predict和fit_generator的坑
作者:BYR_jiandong 发布时间:2021-05-13 16:30:36
1、使用predict时,必须设置batch_size,否则效率奇低。
查看keras文档中,predict函数原型:
predict(self, x, batch_size=32, verbose=0)
说明:
只使用batch_size=32,也就是说每次将batch_size=32的数据通过PCI总线传到GPU,然后进行预测。在一些问题中,batch_size=32明显是非常小的。而通过PCI传数据是非常耗时的。
所以,使用的时候会发现预测数据时效率奇低,其原因就是batch_size太小了。
经验:
使用predict时,必须人为设置好batch_size,否则PCI总线之间的数据传输次数过多,性能会非常低下。
2、fit_generator
说明:keras 中 fit_generator参数steps_per_epoch已经改变含义了,目前的含义是一个epoch分成多少个batch_size。旧版的含义是一个epoch的样本数目。
如果说训练样本树N=1000,steps_per_epoch = 10,那么相当于一个batch_size=100,如果还是按照旧版来设置,那么相当于
batch_size = 1,会性能非常低。
经验:
必须明确fit_generator参数steps_per_epoch
补充知识:Keras:创建自己的generator(适用于model.fit_generator),解决内存问题
为什么要使用model.fit_generator?
在现实的机器学习中,训练一个model往往需要数量巨大的数据,如果使用fit进行数据训练,很有可能导致内存不够,无法进行训练。
fit_generator的定义如下:
fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
其中各项的具体解释,请参考Keras中文文档
我们重点关注的是generator参数:
generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:
一个 (inputs, targets) 元组
一个 (inputs, targets, sample_weights) 元组。
那么,问题来了,如何构建这个generator呢?有以下几种办法:
自己创建一个generator生成器
自己定义一个 Sequence (keras.utils.Sequence) 对象
使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory来生成一个generator
1.自己创建一个generator生成器
使用Keras自带的ImageDataGenerator和.flow/.flow_from_dataframe/.flow_from_directory 灵活度不高,只有当数据集满足一定格式(例如,按照分类文件夹存放)或者具备一定条件时,使用才使用才较为方便。
此时,自己创建一个generator就很重要了,关于python的generator是什么原理,怎么使用,就不加赘述,可以查看python的基本语法。
此处,我们用yield来返回数据组,标签组,从而使fit_generator可以调用我们的generator来成批处理数据。
具体实现如下:
def myGenerator(batch_size):
# loading data
X_train,Y_train=load_data(...)
# data processing
# ................
total_size=X_train.size
#batch_size means how many data you want to train one step
while 1:
for i in range(total_size//batch_size):
yield x_train[i*batch_size:(i+1)*batch_size], y[i*batch_size:(i+1)*batch_size]
return myGenerator
接着你可以调用该生成器:
self._model.fit_generator(myGenerator(batch_size),steps_per_epoch=total_size//batch_size, epochs=epoch_num)
来源:https://blog.csdn.net/lujiandong1/article/details/73556163


猜你喜欢
- Python时间戳操作很多,每次用点时候总是去查,查的麻烦,现在自己也好好归纳一下。我现在刚好有个需求需要获取当天零点时间戳,但是网上查的大
- 前言:在做一个商城项目的时候,需要实现商品搜索功能。说到搜索,第一时间想到的是数据库的 select * from tb_sku where
- 本文介绍了python OpenCV学习笔记直方图反向投影的实现,分享给大家,具体如下:官方文档 – https://docs.opencv
- bcp是SQL Server中负责导入导出数据的一个命令行工具,它是基于DB-Library的,并且能以并行的方式高效地导入导出大批量的数据
- 一、前言准备编写一个篮球游戏,运动员带球跑,跳起投篮。在每帧图片中包括运动员和篮球,使用多帧图片,实现运动员运球跑动的效果。运动员运球跑动作
- 在网上查找大量资料,经过自己的不懈努力,终于测试成功了。原来要在服务器上安装mysql odbc 3.51 ,还有数据库用户名及密码,用下面
- 1.反变换法设需产生分布函数为F(x)的连续随机数X。若已有[0,1]区间均匀分布随机数R,则产生X的反变换公式为:F(x)=r, 即x=F
- 1. goland配置Dockerfile项目中新建Dockerfile文件配置Dockerfile在项目中新建Dockerfile 文件,
- 先把我的browser信息说明一下:这是在opera里about中显示的“浏览器识别”Opera/9.62 (Windows NT 5.1;
- MySQL 如何从表中取出随机数据 以前在群里讨论过这个问题,比较的有意思.mysql的语法真好玩. 他们原来都想用P
- 引言日常开发中,我们经常会使用到group by。亲爱的小伙伴,你是否知道group by的工作原理呢?group by和having有什么
- 本文实例讲述了python异常处理用法。分享给大家供大家参考,具体如下:之前用Java的时候,在容易出错的地方我们经常使用try…catch
- 一:js原型继承四步曲//js模拟类的创建以及继承 //动物(Animal),有头这个属性,eat方法 //名字这个属性 //猫有名字属性,
- string操作在编程中具有极高的频率,那么string中有哪些有用的方法呢?使用strings直接操作Comparefunc Compar
- 前言shape函数是Numpy中的函数,它的功能是读取矩阵的长度,比如shape[0]就是读取矩阵第一维度的长度。直接用.shape可以快速
- 1. 数据抽取的概念2. 数据的分类3. JSON数据概述及解析3.1 JSON数据格式3.2 解析库jsonjson模块是Python内置
- 本文实例讲述了JavaScript观察者模式(publish/subscribe)原理与实现方法。分享给大家供大家参考,具体如下:观察者模式
- 在vue使用echarts时,可能会遇到这样的问题,就是直接刷新浏览器,或者数据变化时,echarts不更新? &nb
- OpenAI,由诸多硅谷大亨联合建立的人工智能非营利组织。2015年马斯克与其他硅谷科技大亨进行连续对话后,决定共同创建OpenAI,希望能
- 背景今天在工作中,同事遇到一个上传图片的问题:系统要求的图片大小不能超过512KB。但是同事又有很多照片。这要是每一个照片都用ps压缩的话,