浅谈keras 模型用于预测时的注意事项
作者:机器AI 发布时间:2022-10-16 13:23:04
为什么训练误差比测试误差高很多?
一个Keras的模型有两个模式:训练模式和测试模式。一些正则机制,如Dropout,L1/L2正则项在测试模式下将不被启用。
另外,训练误差是训练数据每个batch的误差的平均。在训练过程中,每个epoch起始时的batch的误差要大一些,而后面的batch的误差要小一些。另一方面,每个epoch结束时计算的测试误差是由模型在epoch结束时的状态决定的,这时候的网络将产生较小的误差。
【Tips】可以通过定义回调函数将每个epoch的训练误差和测试误差并作图,如果训练误差曲线和测试误差曲线之间有很大的空隙,说明你的模型可能有过拟合的问题。当然,这个问题与Keras无关。
在keras中文文档中指出了这一误区,笔者认为产生这一问题的原因在于网络实现的机制。即dropout层有前向实现和反向实现两种方式,这就决定了概率p是在训练时候设置还是测试的时候进行设置
利用预训练的权值进行Fine tune时的注意事项:
不能把自己添加的层进行将随机初始化后直接连接到前面预训练后的网络层
in order to perform fine-tuning, all layers should start with properly trained weights: for instance you should not slap a randomly initialized fully-connected network on top of a pre-trained convolutional base. This is because the large gradient updates triggered by the randomly initialized weights would wreck the learned weights in the convolutional base. In our case this is why we first train the top-level classifier, and only then start fine-tuning convolutional weights alongside it.
we choose to only fine-tune the last convolutional block rather than the entire network in order to prevent overfitting, since the entire network would have a very large entropic capacity and thus a strong tendency to overfit. The features learned by low-level convolutional blocks are more general, less abstract than those found higher-up, so it is sensible to keep the first few blocks fixed (more general features) and only fine-tune the last one (more specialized features).
fine-tuning should be done with a very slow learning rate, and typically with the SGD optimizer rather than an adaptative learning rate optimizer such as RMSProp. This is to make sure that the magnitude of the updates stays very small, so as not to wreck the previously learned features.
补充知识:keras框架中用keras.models.Model做的时候预测数据不是标签的问题
我们发现,在用Sequential去搭建网络的时候,其中有predict和predict_classes两个预测函数,前一个是返回的精度,后面的是返回的具体标签。但是,在使用keras.models.Model去做的时候,就会发现,它只有一个predict函数,没有返回标签的predict_classes函数,所以,针对这个问题,我们将其改写。改写如下:
def my_predict_classes(predict_data):
if predict_data.shape[-1] > 1:
return predict_data.argmax(axis=-1)
else:
return (predict_data > 0.5).astype('int32')
# 这里省略网络搭建部分。。。。
model = Model(data_input, label_output)
model.compile(loss='categorical_crossentropy',
optimizer=keras.optimizers.Nadam(lr=0.002),
metrics=['accuracy'])
model.summary()
y_predict = model.predict(X_test)
y_pre = my_predict_classes(y_predict)
这样,y_pre就是具体的标签了。
来源:https://blog.csdn.net/xiaojiajia007/article/details/73771311
猜你喜欢
- 我们都知道 vue 中可以使用 modal 来实现 input 内容数据的双向绑定。小程序好像没有提供相应的方法支持,就需要我们自己写了。原
- 本文采用拉普拉斯算子计算影像的模糊程度,小于阈值的影像被认为是模糊的,从而被移动到专门存放模糊影像的文件夹。本文只使用cv2和shutil库
- 本文实例讲述了Python3.5 Pandas模块缺失值处理和层次索引。分享给大家供大家参考,具体如下:1、pandas缺失值处理impor
- # -*- coding: utf-8 -*- import numpy as npimport matplotlib.pyplot as
- 本文实例讲述了python自动化测试之从命令行运行测试用例with verbosity,分享给大家供大家参考。具体如下:实例文件recipe
- 日常运维工作中,通常是邮件报警机制,但邮件可能不被及时查看,导致问题出现得不到及时有效处理。所以想到用Python实现发短信功能,当监控到问
- 我在程序中加入了分数显示,三种特殊食物,将贪吃蛇的游戏逻辑写到了SnakeGame的类中,而不是在Snake类中。特殊食物:1.绿色:普通,
- 先看看:css中class与id的区别及应用表单的name与id其实是同一个意思,都是为了标记对象名称。它们所不同的是:name是Netsc
- 如下所示:#随机数的使用import random #导入randomrandom.randint(0,9)#制定随机数0到9i=rando
- 前两天学习了一下socket编程,在向某大神请教问题时被嫌弃了,有一种还没学会走就想跑的感觉。大神说我现在的水平应该去做一些像是操作文件、序
- #!/usr/bin/python #-*- encoding: utf-8 -*- import types class NotInteg
- 上一篇博客介绍了 如何使用Python,OpenCV上下左右(或任意组合)平移图像。这篇博客将介绍如何使用OpenCV旋转图像任意角度。并演
- 导语承载童年的纸飞机你还会叠嘛?如果你是个80后或者90后,那你应该记得小时候玩的纸飞机。叠好后,哈口仙气,飞出去,感觉棒棒哒。虽然是一个极
- 合并在numpy中合并两个arraynumpy中可以通过concatenate,参数axis=0表示在垂直方向上合并两个数组,等价于np.v
- Python中格式化format()方法详解Python中格式化输出字符串使用format()函数, 字符串即类, 可以使用方法
- 题目:轮盘分为三部分: 一等奖, 二等奖和三等奖;轮盘转的时候是随机的,如果范围在[0,0.08)之间,代表一等奖,如果范围在[0.08,0
- 有时候需要在网页中某个div载入之后,动态引入一段javascript,IE下的解决方案: newjs. onreadystatechang
- 一、引言在编写调试Python代码过程中,我们经常需要记录日志,通常我们会采用python自带的内置标准库logging,但是使用该库,配置
- 前言在编程开发中,个人觉得,只要按照规范去做,很少会出问题。刚开始学习一门技术时,的确会遇到很多的坑。踩的坑多了,这是好事,会学到更多东西,
- 1、构建合理的HTTP请求标头。HTTP的请求头是一组属性和配置信息,当您发送一个请求到网络服务器时。因为浏览器和Python爬虫发送的请求