浅谈keras 模型用于预测时的注意事项
作者:机器AI 发布时间:2022-10-16 13:23:04
为什么训练误差比测试误差高很多?
一个Keras的模型有两个模式:训练模式和测试模式。一些正则机制,如Dropout,L1/L2正则项在测试模式下将不被启用。
另外,训练误差是训练数据每个batch的误差的平均。在训练过程中,每个epoch起始时的batch的误差要大一些,而后面的batch的误差要小一些。另一方面,每个epoch结束时计算的测试误差是由模型在epoch结束时的状态决定的,这时候的网络将产生较小的误差。
【Tips】可以通过定义回调函数将每个epoch的训练误差和测试误差并作图,如果训练误差曲线和测试误差曲线之间有很大的空隙,说明你的模型可能有过拟合的问题。当然,这个问题与Keras无关。
在keras中文文档中指出了这一误区,笔者认为产生这一问题的原因在于网络实现的机制。即dropout层有前向实现和反向实现两种方式,这就决定了概率p是在训练时候设置还是测试的时候进行设置
利用预训练的权值进行Fine tune时的注意事项:
不能把自己添加的层进行将随机初始化后直接连接到前面预训练后的网络层
in order to perform fine-tuning, all layers should start with properly trained weights: for instance you should not slap a randomly initialized fully-connected network on top of a pre-trained convolutional base. This is because the large gradient updates triggered by the randomly initialized weights would wreck the learned weights in the convolutional base. In our case this is why we first train the top-level classifier, and only then start fine-tuning convolutional weights alongside it.
we choose to only fine-tune the last convolutional block rather than the entire network in order to prevent overfitting, since the entire network would have a very large entropic capacity and thus a strong tendency to overfit. The features learned by low-level convolutional blocks are more general, less abstract than those found higher-up, so it is sensible to keep the first few blocks fixed (more general features) and only fine-tune the last one (more specialized features).
fine-tuning should be done with a very slow learning rate, and typically with the SGD optimizer rather than an adaptative learning rate optimizer such as RMSProp. This is to make sure that the magnitude of the updates stays very small, so as not to wreck the previously learned features.
补充知识:keras框架中用keras.models.Model做的时候预测数据不是标签的问题
我们发现,在用Sequential去搭建网络的时候,其中有predict和predict_classes两个预测函数,前一个是返回的精度,后面的是返回的具体标签。但是,在使用keras.models.Model去做的时候,就会发现,它只有一个predict函数,没有返回标签的predict_classes函数,所以,针对这个问题,我们将其改写。改写如下:
def my_predict_classes(predict_data):
if predict_data.shape[-1] > 1:
return predict_data.argmax(axis=-1)
else:
return (predict_data > 0.5).astype('int32')
# 这里省略网络搭建部分。。。。
model = Model(data_input, label_output)
model.compile(loss='categorical_crossentropy',
optimizer=keras.optimizers.Nadam(lr=0.002),
metrics=['accuracy'])
model.summary()
y_predict = model.predict(X_test)
y_pre = my_predict_classes(y_predict)
这样,y_pre就是具体的标签了。
来源:https://blog.csdn.net/xiaojiajia007/article/details/73771311


猜你喜欢
- 本文实例讲述了Django2 连接MySQL及model测试。分享给大家供大家参考,具体如下:参考:https://www.jb51.net
- 有时候我们想要的数据合并结果是数据的轴向连接,在pandas中这可以通过concat来实现。操作的对象通常是Series。Ipython中的
- Mysql常用显示命令1、显示当前数据库服务器中的数据库列表:mysql> SHOW DATABASES;注意:mysql库里面有MY
- 在遥感应用中,我们经常需要对某一景遥感影像中的全部像元的像素值进行平均值求取——这一操作很好实现,基
- 在python中我们可以使用speech模块让计算机进行语音输出,我们需要使用如下代码安装该模块。对于如何在终端中安装python相应模块,
- 一、前言大多数编译型语言,变量在使用前必须先声明,其中C语言更加苛刻:变量声明必须位于代码块最开始,且在任何其他语句之前。其他语言,想C++
- 根据用户的权限,展示不同的菜单页。知识点路由守卫(使用了前置守卫):根据用户角色判断要添加的路由vuex:保存动态添加的路由难点每次路由发生
- 表单内有两个提交按钮,要实现当点击不同的提交按钮时,分别进行两个不同的处理过程,在这里有实现表单多按钮提交action的处理方法分享给大家。
- 在推行系统中,时不时会有用户提出希望系统能自动推送邮件,由于手头的工具和能力有限,不少需求都借助于sql server的邮件触发来实现。步骤
- 有序列表list>>> listTest = ['ha','test','yes&
- 前言本文的主要内容是使用 cpu 版本的 tensorflor-2.1 完成对 Auto MPG 数据集的回归预测任务。获取 Auto MP
- 1、选取最适用的字段属性MySQL 可以很好的支持大数据量的存取,但是一般说来,数据库中的表越小,在它上面执行的查询也就会越快。因此,在创建
- 一、前言程序访问MySQL数据库时,当查询出来的数据量特别大时,数据库驱动把加载到的数据全部加载到内存里,就有可能会导致内存溢出(OOM)。
- 一、基本概念查找(Searching)就是根据给定的某个值,在查找表中确定一个其关键字等于给定值的数据元素(或记录)。查找表(Search
- 本文实例讲述了python实现的用于搜索文件并进行内容替换的类。分享给大家供大家参考。具体实现方法如下:#!/usr/bin/python
- 一、常见模型分类1.1、循环服务器模型循环接收客户端请求,处理请求。同一时刻只能处理一个请求,处理完毕后再处理下一个。优点:实现简单,占用资
- 本文实例为大家分享了vue实现百度搜索功能的具体代码,供大家参考,具体内容如下最终效果:Baidusearch.vue所有代码:<te
- JS如何从一个数组中随机取出一个元素或者几个元素。假如数组为var items = ['1','2',
- 前言最近需要源码部署一个项目,因此探索一下保护源码的方式,由简单到复杂主要总结为以下三大类:代码混淆:主要是改变一些函数名、变量名代码打包:
- 问题说明(环境:windows7,MySql8.0)今天安装好MySql后启动MySql服务-->启动服务都失败的就不要往下看了,自行