keras和tensorflow使用fit_generator 批次训练操作
作者:zhang0peter 发布时间:2023-04-09 04:55:59
fit_generator 是 keras 提供的用来进行批次训练的函数,使用方法如下:
model.fit_generator(generator, steps_per_epoch=None, epochs=1,
verbose=1, callbacks=None, validation_data=None, validation_steps=None,
class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False,
shuffle=True, initial_epoch=0)
参数说明:
generator: 一个生成器,或者一个 Sequence (keras.utils.Sequence) 对象的实例, 以在使用多进程时避免数据的重复。 生成器的输出应该为以下之一:
一个(inputs, targets) 元组
一个 (inputs, targets, sample_weights) 元组。
这个元组(生成器的单个输出)组成了单个的 batch。 因此,这个元组中的所有数组长度必须相同(与这一个 batch 的大小相等)。 不同的 batch 可能大小不同。 例如,一个 epoch 的最后一个 batch 往往比其他 batch 要小, 如果数据集的尺寸不能被 batch size 整除。 生成器将无限地在数据集上循环。当运行到第steps_per_epoch 时,记一个 epoch 结束。
steps_per_epoch: 在声明一个 epoch 完成并开始下一个 epoch 之前从 generator产生的总步数(批次样本)。 它通常应该等于你的数据集的样本数量除以批量大小。 对于Sequence,它是可选的:如果未指定,将使用len(generator)作为步数。
epochs: 整数。训练模型的迭代总轮数。一个 epoch 是对所提供的整个数据的一轮迭代,如 steps_per_epoch 所定义。注意,与 initial_epoch 一起使用,epoch 应被理解为「最后一轮」。模型没有经历由 epochs 给出的多次迭代的训练,而仅仅是直到达到索引 epoch 的轮次。
verbose: 0, 1 或 2。日志显示模式。 0 = 安静模式, 1 = 进度条, 2 = 每轮一行。
callbacks: keras.callbacks.Callback 实例的列表。在训练时调用的一系列回调函数。
validation_data: 它可以是以下之一:
验证数据的生成器或Sequence实例
一个(inputs, targets) 元组
一个(inputs, targets, sample_weights) 元组。
在每个 epoch 结束时评估损失和任何模型指标。该模型不会对此数据进行训练。
validation_steps: 仅当 validation_data 是一个生成器时才可用。 在停止前 generator 生成的总步数(样本批数)。 对于 Sequence,它是可选的:如果未指定,将使用 len(generator) 作为步数。
class_weight: 可选的将类索引(整数)映射到权重(浮点)值的字典,用于加权损失函数(仅在训练期间)。 这可以用来告诉模型「更多地关注」来自代表性不足的类的样本。
max_queue_size: 整数。生成器队列的最大尺寸。 如未指定,max_queue_size 将默认为 10。
workers: 整数。使用的最大进程数量,如果使用基于进程的多线程。 如未指定,workers 将默认为 1。如果为 0,将在主线程上执行生成器。
use_multiprocessing: 布尔值。如果 True,则使用基于进程的多线程。 如未指定, use_multiprocessing 将默认为 False。 请注意,由于此实现依赖于多进程,所以不应将不可传递的参数传递给生成器,因为它们不能被轻易地传递给子进程。
shuffle: 是否在每轮迭代之前打乱 batch 的顺序。 只能与 Sequence (keras.utils.Sequence) 实例同用。
initial_epoch: 开始训练的轮次(有助于恢复之前的训练)。
补充知识:Keras中fit_generator 的多个分支输入时,需注意generator的格式 以及 输入序列的顺序
需要注意迭代器 yeild返回不能是[x1,x2],y 这样,而是要完整的字典格式的:
yield ({'input_1': x1, 'input_2': x2}, {'output': y})
这也不算坑 追进去 fit_generator也能看到示例
def generate_batch(x_train,y_train,batch_size,x_train2,randomFlag=True):
ylen = len(y_train)
loopcount = ylen // batch_size
i=-1
while True:
if randomFlag:
i = random.randint(0,loopcount-1)
else:
i=i+1
i=i%loopcount
yield ({'lstmInput': x_train[i*batch_size:(i+1)*batch_size],
'bgInput': x_train2[i*batch_size:(i+1)*batch_size]},
{'prediction': y_train[i*batch_size:(i+1)*batch_size]})
ps: 因为要是tuple yield后的括号不能省
需注意的坑1是,validation data中如果用【】组成数组进行输入,是要按顺序的,按编译model前的设置model = Model(inputs=[simInput,lstmInput,bgInput], outputs=predictions),中数组的顺序来编译
需注意的坑2是,多输入input时,以后都用 inputs1=Input(batch_shape=(batchSize,TPeriod,dimIn,),name='input1LSTM')指定batchSize,不然跟stateful lstm结合时,会提示不匹配。
history=model.fit_generator(generate_batch(trainX,trainY,batchSize,trainX2),
steps_per_epoch=len(trainX)//batchSize,
validation_data=([testX,testX2],testY),
epochs=epochs,
callbacks=[tensorboard,checkpoint],initial_epoch=0,verbose=1) # Fit the LSTM network/拟合LSTM网络
来源:https://blog.csdn.net/zhangpeterx/article/details/90900118
猜你喜欢
- 下面是用python写的,使用lxml来做html分析,从网上看到的,说是分析速度最快的哦,不过没有验证过。好了,上代码。 import u
- 本文实例讲述了python读写二进制文件的方法。分享给大家供大家参考。具体如下:初学python,现在要读一个二进制文件,查找doc只发现
- 背景:大约有3K家商家需要重新确认信息并签订合同。合同是统一的Word版本。每个供应商需要修改合同内的金额部分。人工处理方式需要每个复制粘贴
- 本文实例讲述了python自动zip压缩目录的方法。分享给大家供大家参考。具体实现方法如下:这段代码来压缩数据库备份文件,没有使用pytho
- 许多服务器管理员都知道,MySQL数据库管理系统(RDBMS)是高度灵活的软件块,带有范围广阔的启动选项,可以用来修改相关行为。然而,大部分
- 一、函数初识1、定义:将一组语句的集合通过一个名字(函数名)封装起来,要想执行这个函数,只需调用其函数名即可。2、好处:代码重用;保持一致性
- 本文利用Python3启动简单的HTTP服务器,以实现在同一网络中共享本地文件。启动HTTP服务器打开终端,转入目标文件所在文件夹,键入以下
- JavaScript 有三种弹窗 Alert (只有确定按钮), Confirmation (确定,取消等按钮), Prompt (有输入对
- 这只是个asp小技巧类的东西,它虽然适合在每个不同文件名里调用这个函数,但是也是有前提的,下面让我们来仔细看看其中的原委。 &n
- 为什么要模拟登录有些网站是需要登录之后才能访问的,即便是同一个网站,在用户登录前后页面所展示的内容也可能会大不相同,例如,未登录时访问Git
- vbscript脚本中,fso对象CreateTextFile方法调用时可能会报“无效的过程调用或参数”错误,在使用ASP生成静态页面时,如
- 视图代码lis = []#设置一个空列表用来存放发送的验证码,用来验证def yzm1(): res1 = &qu
- 破解百度翻译翻译是一件麻烦的事情,如果可以写一个爬虫程序直接爬取百度翻译的翻译结果就好了,可当我打开百度翻译的页面,输入要翻译的词时突然发现
- Mysql Explain 详解一.语法explain < table_name >例如: explain select * f
- 本文实例讲述了使用symfony命令创建项目的方法。分享给大家供大家参考,具体如下:概况这一章节描述一个Symfony项目的合理结构框架,并
- 又是一年春来到,看各大网站的新年Logo也成为了我们必不可少的新年餐点,为此,我们特别整理了部分网站的新年Logo秀,如果你看到了更加有意思
- 我不知道没有他们我该如何生活我编写Python已有5年以上了,我的工具集通常变得越来越小,而不是越来越大。 许多工具不是必需的或无用的,而其
- Python list内置sort()方法用来排序,也可以用python内置的全局sorted()方法来对可迭代的序列排序生成新的序列。1)
- php文件 <?php class xpathExtension{ public static function getNodes($
- Ruby 是一门通用的语言,不仅仅是一门应用于WEB开发的语言,但 Ruby 在WEB应用及WEB工具中的开发是最常见的。使用Ruby您不仅