keras回调函数的使用
作者:辛勤的小码农^_^ 发布时间:2022-08-22 11:42:41
标签:keras,回调函数
回调函数
回调函数是一个对象(实现了特定方法的类实例),它在调用fit()时被传入模型,并在训练过程中的不同时间点被模型调用
可以访问关于模型状态与模型性能的所有可用数据
模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前状态。
提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训练过程中的最佳模型)。
在训练过程中动态调节某些参数值:比如调节优化器的学习率。
在训练过程中记录训练指标和验证指标,或者将模型学到的表示可视化(这些表示在不断更新):fit()进度条实际上就是一个回调函数。
fit()方法中使用callbacks参数
# 这里有两个callback函数:早停和模型检查点
callbacks_list=[
keras.callbacks.EarlyStopping(
monitor="val_accuracy",#监控指标
patience=2 #两轮内不再改善中断训练
),
keras.callbacks.ModelCheckpoint(
filepath="checkpoint_path",
monitor="val_loss",
save_best_only=True
)
]
#模型获取
model=get_minist_model()
model.compile(optimizer="rmsprop",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"])
model.fit(train_images,train_labels,
epochs=10,callbacks=callbacks_list, #该参数使用回调函数
validation_data=(val_images,val_labels))
test_metrics=model.evaluate(test_images,test_labels)#计算模型在新数据上的损失和指标
predictions=model.predict(test_images)#计算模型在新数据上的分类概率
模型的保存和加载
#也可以在训练完成后手动保存模型,只需调用model.save('my_checkpoint_path')。
#重新加载模型
model_new=keras.models.load_model("checkpoint_path.keras")
通过对Callback类子类化来创建自定义回调函数
on_epoch_begin(epoch, logs) ←----在每轮开始时被调用
on_epoch_end(epoch, logs) ←----在每轮结束时被调用
on_batch_begin(batch, logs) ←----在处理每个批量之前被调用
on_batch_end(batch, logs) ←----在处理每个批量之后被调用
on_train_begin(logs) ←----在训练开始时被调用
on_train_end(logs ←----在训练结束时被调用
from matplotlib import pyplot as plt
# 实现记录每一轮中每个batch训练后的损失,并为每个epoch绘制一个图
class LossHistory(keras.callbacks.Callback):
def on_train_begin(self, logs):
self.per_batch_losses = []
def on_batch_end(self, batch, logs):
self.per_batch_losses.append(logs.get("loss"))
def on_epoch_end(self, epoch, logs):
plt.clf()
plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
label="Training loss for each batch")
plt.xlabel(f"Batch (epoch {epoch})")
plt.ylabel("Loss")
plt.legend()
plt.savefig(f"plot_at_epoch_{epoch}")
self.per_batch_losses = [] #清空,方便下一轮的技术
model = get_mnist_model()
model.compile(optimizer="rmsprop",
loss="sparse_categorical_crossentropy",
metrics=["accuracy"])
model.fit(train_images, train_labels,
epochs=10,
callbacks=[LossHistory()],
validation_data=(val_images, val_labels))
【其他】模型的定义 和 数据加载
def get_minist_model():
inputs=keras.Input(shape=(28*28,))
features=layers.Dense(512,activation="relu")(inputs)
features=layers.Dropout(0.5)(features)
outputs=layers.Dense(10,activation="softmax")(features)
model=keras.Model(inputs,outputs)
return model
#datset
from tensorflow.keras.datasets import mnist
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
train_images=train_images.reshape((60000,28*28)).astype("float32")/255
test_images=test_images.reshape((10000,28*28)).astype("float32")/255
train_images,val_images=train_images[10000:],train_images[:10000]
train_labels,val_labels=train_labels[10000:],train_labels[:10000]
来源:https://blog.csdn.net/qq_43787439/article/details/129438485


猜你喜欢
- 先附上官方文档:https://pandas.pydata.org/pandas-docs/stable/reference/api/pan
- 1.用CSS实现布局让我们一起来做一个页面,首先,我们需要一个布局。请使用CSS控制3个div,实现如下图的布局。考察应试者的基本布局知识—
- 在日常运维中,如果涉及到用户管理,就一定会用到给用户设置密码的工作,其实吧,平时脑子里觉得设置个密码没什么,但要真让你随手敲一个12位带特殊
- 官方说明链接:https://intellij-support.jetbrains.com/hc/en-us/community/posts
- 大家好,我是只谈技术不剪发的 Tony 老师。这次我们来介绍一个 MySQL 8.0 增加的新功能:检查约束(CHECK )。SQL 中的检
- 1、说明GIL规定一个Python解释程序只能同时由一个线程控制。在CPU限制类型和多线程代码中,GIL是一个性能瓶颈。GIL使Python
- 在windows+iis服务器上运行asp程序可能会出现数据库无法更新的情况,具体错误信息可能为: 1、Microsoft JET Data
- 本文介绍了一系列安装教程,具体如下1.安装Python版本选择是3.5.1,因为网上有些深度学习实例用的就是这个版本,跟他们一样的话可以避免
- 完成asp语言对XML文档中指定节点文本的增加、删除、修改、查看 <% '-------------------
- nonzero函数返回非零元素的目录。返回值为元组, 两个值分别为两个维度, 包含了相应维度上非零元素的目录值。 import
- 前言之前简单学习过python爬虫基础知识,并且用过scrapy框架爬取数据,都是直接能用xpath定位到目标区域然后爬取。可这次碰到的需求
- 本文实例讲述了python实现生成Word、docx文件的方法。分享给大家供大家参考,具体如下:http://python-docx.rea
- 以前没见过这个效果,滚动纵向滚动条看看效果就明白了这样的效果,广告商应该比较喜欢。<!DOCTYPE html PUBLIC &quo
- 这篇文章主要介绍了Python如何计算语句执行时间,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可
- 对于网站开发者来说,对展示内容增加一个滑动或者是轮播效果的是非常常见的需求。收费和免费的轮播插件多的是不胜枚举。其中很 多提供很多有用的配置
- 背景:因为工作需要,公司给每个员工都分配了一个邮箱 公司的各种业务都通过邮箱发送。虽然给每个员工的电脑都设置pop3登录但是他们的程序设定有
- 要写出一个五子棋游戏,我们最先要解决的,就是如何下子,如何判断已经五子连珠,而不是如何绘制画面,因此我们先确定棋盘五子棋采用15*15的棋盘
- <ul> <li> <input type="radio" name="radi
- SQL2005 Express 没了「企业管理器」和「查询分析器」 SQL2005 分五个版本,如下所列: 1.Enterprise(企业版
- pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的