Kears 使用:通过回调函数保存最佳准确率下的模型操作
作者:wardenjohn 发布时间:2023-02-24 12:36:56
1:首先,我给我的MixTest文件夹里面分好了类的图片进行重命名(因为分类的时候没有注意导致命名有点不好)
def load_data(path):
Rename the picture [a tool]
for eachone in os.listdir(path):
newname = eachone[7:]
os.rename(path+"\\"+eachone,path+"\\"+newname)
但是需要注意的是:我们按照类重命名了以后,系统其实会按照图片来排序。这个时候你会看到同一个类的被排序在了一块。这个时候你不要慌张,其实这个顺序是完全不用担心的。我们只是需要得到网络对某一个图片的输出是怎么样的判断标签。这个顺序对网络计算其权重完全是没有任何的影响的
2:我在Keras中使用InceptionV3这个模型进行训练,训练模型的过程啥的我在这里就不详细说了(毕竟这个东西有点像随记那样的东西)
我们在Keras的模型里面是可以通过
H.history["val_acc"]
H.history["val_loss"]
来的得到历史交叉准确率这样的指标
3:
对于每个epoch,我们都会计算一次val_acc和val_loss,我很希望保留下我最高的val_acc的模型,那该怎么办呢?
这个时候我就会使用keras的callback函数
H = model.fit_generator(train_datagen.flow(X_train, Y_train, batch_size=batchsize),
validation_data=(X_test, Y_test), steps_per_epoch=(X_train.shape[0]) // batchsize,
epochs=epoch, verbose=1, callbacks=[tb(log_dir='E:\John\log'),
save_function])
上面的参数先查查文档把。这里我就说说我的callbacks
callbacks=[tb(log_dir = 'E\John\log')]
这个是使用tensorboard来可视化训练过程的,后面是tensorboard的log输出文件夹的路径,在网络训练的时候,相对应的训练的状态就会保存在这个文件夹下
打开终端,输入
tensorboard --log_dir <your name of the log dir> --port <the port for tensorboard>
然后输入终端指示的网址在浏览器中打开,就可以在tensorboard中看到你训练的状态了
save_function:
这是一个类的实例化:
class Save(keras.callbacks.Callback):
def __init__(self):
self.max_acc = 0.0
def on_epoch_begin(self, epoch, logs=None):
pass
def on_epoch_end(self, epoch, logs=None):
self.val_acc = logs["val_acc"]
if epoch != 0:
if self.val_acc > self.max_acc and self.val_acc > 0.8:
model.save("kears_model_"+str(epoch)+ "_acc="+str(self.val_acc)+".h5")
self.max_acc = self.val_acc
save_function = Save()
这里继承了kears.callbacks.Callback
看看on_epoch_end:
在这个epoch结束的时候,我会得到它的val_acc
当这个val_acc为历史最大值的时候,我就保存这个模型
在训练结束以后,你就挑出acc最大的就好啦(当然,你可以命名为一样的,最后的到的模型就不用挑了,直接就是acc最大的模型了)
补充知识:Keras回调函数Callbacks使用详解及训练过程可视化
介绍
内容参考了keras中文文档
回调函数Callbacks
回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息。通过传递回调函数列表到模型的.fit()中,即可在给定的训练阶段调用该函数集中的函数。
【Tips】虽然我们称之为回调“函数”,但事实上Keras的回调函数是一个类,回调函数只是习惯性称呼
keras.callbacks.Callback()
这是回调函数的抽象类,定义新的回调函数必须继承自该类
类属性:
params:字典,训练参数集(如信息显示方法verbosity,batch大小,epoch数)
model:keras.models.Model对象,为正在训练的模型的引用
回调函数以字典logs为参数,该字典包含了一系列与当前batch或epoch相关的信息。
目前,模型的.fit()中有下列参数会被记录到logs中:
在每个epoch的结尾处(on_epoch_end),logs将包含训练的正确率和误差,acc和loss,如果指定了验证集,还会包含验证集正确率和误差val_acc)和val_loss,val_acc还额外需要在.compile中启用metrics=[‘accuracy']。
在每个batch的开始处(on_batch_begin):logs包含size,即当前batch的样本数
在每个batch的结尾处(on_batch_end):logs包含loss,若启用accuracy则还包含acc
from keras.callbacks import Callback
功能
History(训练可视化)
keras.callbacks.History()
该回调函数在Keras模型上会被自动调用,History对象即为fit方法的返回值,可以使用history中的存储的acc和loss数据对训练过程进行可视化画图,代码样例如下:
history=model.fit(X_train, Y_train, validation_data=(X_test,Y_test),
batch_size=16, epochs=20)
##或者
#history=model.fit(X_train,y_train,epochs=40,callbacks=callbacks, batch_size=32,validation_data=(X_test,y_test))
fig1, ax_acc = plt.subplots()
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Model - Accuracy')
plt.legend(['Training', 'Validation'], loc='lower right')
plt.show()
fig2, ax_loss = plt.subplots()
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Model- Loss')
plt.legend(['Training', 'Validation'], loc='upper right')
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.show()
EarlyStopping
keras.callbacks.EarlyStopping(monitor='val_loss', patience=0, verbose=0, mode='auto')
当监测值不再改善时,该回调函数将中止训练
参数
monitor:需要监视的量
patience:当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。
verbose:信息展示模式
verbose = 0 为不在标准输出流输出日志信息
verbose = 1 为输出进度条记录
verbose = 2 为每个epoch输出一行记录
默认为 1
mode:‘auto',‘min',‘max'之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。
ModelCheckpoint
该回调函数将在每个epoch后保存模型到filepath
filepath可以是格式化的字符串,里面的占位符将会被epoch值和传入on_epoch_end的logs关键字所填入
例如,filepath若为weights.{epoch:02d-{val_loss:.2f}}.hdf5,则会生成对应epoch和验证集loss的多个文件。
参数
filename:字符串,保存模型的路径
monitor:需要监视的值
verbose:信息展示模式,0或1
save_best_only:当设置为True时,将只保存在验证集上性能最好的模型
mode:‘auto',‘min',‘max'之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)
period:CheckPoint之间的间隔的epoch数
Callbacks中可以同时使用多个以上两个功能,举例如下
callbacks = [EarlyStopping(monitor='val_loss', patience=8),
ModelCheckpoint(filepath='best_model.h5', monitor='val_loss', save_best_only=True)]
history=model.fit(X_train, y_train,epochs=40,callbacks=callbacks, batch_size=32,validation_data=(X_test,y_test))
在样例中,EarlyStopping设置衡量标注为val_loss,如果其连续4次没有下降就提前停止 ,ModelCheckpoint设置衡量标准为val_loss,设置只保存最佳模型,保存路径为best——model.h5
ReduceLROnPlateau
keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0)
当评价指标不在提升时,减少学习率
当学习停滞时,减少2倍或10倍的学习率常常能获得较好的效果。该回调函数检测指标的情况,如果在patience个epoch中看不到模型性能提升,则减少学习率
参数
monitor:被监测的量 factor:每次减少学习率的因子,学习率将以lr = lr*factor的形式被减少
patience:当patience个epoch过去而模型性能不提升时,学习率减少的动作会被触发
mode:‘auto',‘min',‘max'之一,在min模式下,如果检测值触发学习率减少。在max模式下,当检测值不再上升则触发学习率减少。
epsilon:阈值,用来确定是否进入检测值的“平原区”
cooldown:学习率减少后,会经过cooldown个epoch才重新进行正常操作 min_lr:学习率的下限
使用样例如下:
callbacks_test = [
keras.callbacks.ReduceLROnPlateau(
#以val_loss作为衡量标准
monitor='val_loss',
# 学习率乘以factor
factor=0.1,
# It will get triggered after the validation loss has stopped improving
# 当被检测的衡量标准经过几次没有改善后就减小学习率
patience=10,
)
]
model.fit(x, y,epochs=20,batch_size=16,
callbacks=callbacks_test,
validation_data=(x_val, y_val))
CSVLogger
keras.callbacks.CSVLogger(filename, separator=',', append=False)
将epoch的训练结果保存在csv文件中,支持所有可被转换为string的值,包括1D的可迭代数值如np.ndarray.
参数
fiename:保存的csv文件名,如run/log.csv
separator:字符串,csv分隔符
append:默认为False,为True时csv文件如果存在则继续写入,为False时总是覆盖csv文件
来源:https://blog.csdn.net/wardenjohn/article/details/84030248
猜你喜欢
- 只想回答一个问题: 当编译器要读取obj.field时, 发生了什么?看似简单的属性访问, 其过程还蛮曲折的. 总共有以下几个step: 1
- 题目1、 请输入一个整数 , 若该数是偶数 , 输出 “ 是偶数” ”
- 学习编写简练、优化的CSS需要大量的实践和一种不自觉的强迫性清洁的渴望。然而让你的CSS保持整洁并不仅仅是你对清洁的疯狂的心理需求,尤其对于
- 前言:话说,我一直没能理解美工究竟是什么(这是一篇投稿)。因为要求确实很多。目前,我只能理解成,是前端开发+页面设计+用户体验设计的全能手。
- 安 * oostpython调用C/C++的方法有很多,本文使用boost.python。考虑到后期有好多在boost上的开发工作,所以boo
- 本文python代码实现的是最小二乘法线性拟合,并且包含自己造的轮子与别人造的轮子的结果比较。问题:对直线附近的带有噪声的数据进行线性拟合,
- raw# row方法:(掺杂着原生sql和orm来执行的操作)res = CookBook.objects.raw('select
- 1.安装Pillowpip install Pillow2.安装tesseract-ocrgithub地址: https://gi
- 1、这里只是简单介绍一下Django的view如何跟js进行交互,首先,进入用户明细的时候会进入一个页面,叫用户信息表,里面包含了用户学习的
- 如何制作一个搜索引擎链接程序?多收集几个网站的,然后我们引用它到自己的页面中。接下来,我们要创建页面用于搜索:<center>&
- gojson是快速解析json数据的一个golang包,你使用它可以快速的查找json内的数据安装 go get github.com/wi
- 前言matplotlib是Python中的一个第三方库。主要用于开发2D图表,以渐进式、交互式的方式实现数据可视化,可以更直观的呈现数据,使
- DTD实际上可以看作一个或多个XML文件的模板,这些XML文件中的元素、元素的属性、元素的排列方式/顺序、元素能够包含的内容等,都必须符合D
- 本文实例讲述了PHP面向对象程序设计子类扩展父类(子类重新载入父类)操作。分享给大家供大家参考,具体如下:在PHP中,会遇到这样的情况,子类
- 弄个随机数的东西,直接从网上找了一个现成的,简单看了两眼,感觉算法应该是对的,但今天测试下来,是不对的;网上大多数人用的写法是这样的:fun
- PyType_Type和PyBaseObject_TypePyObject和PyTypeObject内容的最后指出下图中对实例对象和类型对象
- 下面给大家分享Python爬虫后获取重定向url的两种方法,具体内容如下所示;方法(一)# 获得重定向url from urllib imp
- 本文实例讲述了Python实现列表转换成字典数据结构的方法。分享给大家供大家参考,具体如下:'''[ {
- Python 多进程默认不能共享全局变量主进程与子进程是并发执行的,进程之间默认是不能共享全局变量的(子进程不能改变主进程中全局变量的值)。
- 有时候我们用的一些pdf资料是没有目录的,这样找寻我们想到的东西比较麻烦。本篇文章就为大家带来python来生成pdf目录书签的方法。首先,