浅谈Keras中fit()和fit_generator()的区别及其参数的坑
作者:MrLeaper 发布时间:2022-04-18 07:22:26
1、fit和fit_generator的区别
首先Keras中的fit()函数传入的x_train和y_train是被完整的加载进内存的,当然用起来很方便,但是如果我们数据量很大,那么是不可能将所有数据载入内存的,必将导致内存泄漏,这时候我们可以用fit_generator函数来进行训练。
下面是fit传参的例子:
history = model.fit(x_train, y_train, epochs=10,batch_size=32,
validation_split=0.2)
这里需要给出epochs和batch_size,epoch是这个数据集要被轮多少次,batch_size是指这个数据集被分成多少个batch进行处理。
最后可以给出交叉验证集的大小,这里的0.2是指在训练集上占比20%。
fit_generator函数必须传入一个生成器,我们的训练数据也是通过生成器产生的,下面给出一个简单的生成器函数:
batch_size = 128
def generator():
while 1:
row = np.random.randint(0,len(x_train),size=batch_size)
x = np.zeros((batch_size,x_train.shape[-1]))
y = np.zeros((batch_size,))
x = x_train[row]
y = y_train[row]
yield x,y
这里的生成器函数我产生的是一个batch_size为128大小的数据,这只是一个demo。如果我在生成器里没有规定batch_size的大小,就是每次产生一个数据,那么在用fit_generator时候里面的参数steps_per_epoch是不一样的。
这里的坑我困惑了好久,虽然不是什么大问题
下面是fit_generator函数的传参:
history = model.fit_generator(generator(),epochs=epochs,steps_per_epoch=len(x_train)//(batch_size*epochs))
2、batch_size和steps_per_epoch的区别
首先batch_size = 数据集大小/steps_per_epoch的,如果我们在生成函数里设置了batch_size的大小,那么在fit_generator传参的时候,,steps_per_epoch=len(x_train)//(batch_size*epochs)
我得完整demo代码:
from keras.datasets import imdb
from keras.preprocessing.sequence import pad_sequences
from keras.models import Sequential
from keras import layers
import numpy as np
import random
from sklearn.metrics import f1_score,accuracy_score
max_features = 10000
maxlen = 500
batch_size = 32
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
x_train = pad_sequences(x_train,maxlen=maxlen)
x_test = pad_sequences(x_test,maxlen=maxlen)
def generator():
while 1:
row = np.random.randint(0,len(x_train),size=batch_size)
x = np.zeros((batch_size,x_train.shape[-1]))
y = np.zeros((batch_size,))
x = x_train[row]
y = y_train[row]
yield x,y
# generator()
model = Sequential()
model.add(layers.Embedding(max_features,32,input_length=maxlen))
model.add(layers.GRU(64,return_sequences=True))
model.add(layers.GRU(32))
# model.add(layers.Flatten())
# model.add(layers.Dense(32,activation='relu'))
model.add(layers.Dense(1,activation='sigmoid'))
model.compile(optimizer='rmsprop',loss='binary_crossentropy',metrics=['acc'])
print(model.summary())
# history = model.fit(x_train, y_train, epochs=1,batch_size=32, validation_split=0.2)
history = model.fit_generator(generator(),epochs=1,steps_per_epoch=len(x_train)//(batch_size))
print(model.evaluate(x_test,y_test))
y = model.predict_classes(x_test)
print(accuracy_score(y_test,y))
补充:model.fit_generator()详细解读
如下所示:
from keras import models
model = models.Sequential()
首先
利用keras,搭建顺序模型,具体搭建步骤省略。完成搭建后,我们需要将数据送入模型进行训练,送入数据的方式有很多种,models.fit_generator()是其中一种方式。
具体说,model.fit_generator()是利用生成器,分批次向模型送入数据的方式,可以有效节省单次内存的消耗。
具体函数形式如下:
fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, \
callbacks=None, validation_data=None, validation_steps=None,\
class_weight=None, max_q_size=10, workers=1, pickle_safe=False, initial_epoch=0)
参数解释:
generator:一般是一个生成器函数;
steps_per_epochs:是指在每个epoch中生成器执行生成数据的次数,若设定steps_per_epochs=100,这情况如下图所示;
epochs:指训练过程中需要迭代的次数;
verbose:默认值为1,是指在训练过程中日志的显示模式,取 1 时表示“进度条模式”,取2时表示“每轮一行”,取0时表示“安静模式”;
validation_data, validation_steps指验证集的情况,使用方式和generator, steps_per_epoch相同;
models.fit_generator()会返回一个history对象,history.history 属性记录训练过程中,连续 epoch 训练损失和评估值,以及验证集损失和评估值,可以通过以下方式调取这些值!
acc = history.history["acc"]
val_acc = history.history["val_acc"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
来源:https://blog.csdn.net/mlp750303040/article/details/89207658


猜你喜欢
- 迭代数组NumPy中引入了 nditer 对象来提供一种对于数组元素的访问方式。一、单数组迭代1. 使用 nditer 访问数组的每个元素&
- 1. 实例描述通过爬虫获取网页的信息时,有时需要登录网页后才可以获取网页中的可用数据,例如获取 GitHub 网页中的注册号码时,就需要先登
- 说明:该篇博客是博主一字一码编写的,实属不易,请尊重原创,谢谢大家!一丶说明测试条件:需要有GitHub账号以及在本地安装了Git工具,无论
- Pygame 中提供了一个draw模块用来绘制一些简单的图形状,比如矩形、多边形、圆形、直线、弧线等。pygame.draw模块的常用方法如
- 为什么我也要说SQL Server的并行:这几天园子里写关于SQL Server并行的文章很多,不管怎么样,都让人对并行操作有了更深刻的认识
- $("input").attr("checked","checked") 设置以
- 本文实例讲述了python使用xmlrpclib模块实现对百度google的ping功能。分享给大家供大家参考。具体分析如下:最近在做SEO
- 一般情况下,mysql会默认提供多种存储引擎,可以通过下面的查看:1)查看mysql是否安装了innodb插件。通过下面的命令结果可知,已经
- 由于Rosenblatt感知器的局限性,对于非线性分类的效果不理想。为了对线性分类无法区分的数据进行分类,需要构建多层感知器结构对数据进行分
- 1.使用open()函数打开文件夹在读取一个文件的内容之前,需要先打开这个文件。在Python程序中可以通过内置函数open()来打开一个文
- 1.实例方法Python 的实例方法用得最多,也最常见。我们先来看 Python 的实例方法。class Kls(object): &nbs
- 参数Parameters解析响应时间resolveTimeout 数据类型:长整型。简单地说就是程序对目标主机的名字解析解析的一个过程时间。
- 简而言之,channel维护了一个带指针的接受和发送的队列,其中包含mutex锁保证并发安全,数据类型,元素个数,元素大小,channel状
- try ...except 是最常见的捕获处理异常的结构,其主要作用是将可能出现问题的代码块用try :包裹起来,不至于出现错误让程序崩溃,
- 一:绑定方法:其特点是调用方本身自动作为第一个参数传入1.绑定到对象的方法:调用方是一个对象,该对象自动传入2.方法绑定到类:调用方是类,类
- SELECT TABLE_SCHEMA,TABLE_NAME FROM information_schema.`COLUMNS` WHERE
- 本游戏程序实现的功能为本地二人对弈中国象棋,实现语言为javascript+VML,在windows 2000 pro+IE 6sp1的环境
- 一 multiprocessing模块介绍python中的多线程无法利用多核优势,如果想要充分地使用多核CPU的资源(os.cpu\_cou
- createTrackbar是Opencv中的API,其可在显示图像的窗口中快速创建一个滑动控件,用于手动调节阈值,具有非常直观的效果。具体
- 目录前言什么是socket?如何在 Python 中创建 socket 对象?Python 的套接字库中有多少种可用的套接字方法?服务器套接