keras做CNN的训练误差loss的下降操作
作者:fitzgerald0 发布时间:2023-09-03 07:41:07
采用二值判断如果确认是噪声,用该点上面一个灰度进行替换。
噪声点处理:对原点周围的八个点进行扫描,比较。当该点像素值与周围8个点的值小于N时,此点为噪点 。
处理后的文件大小只有原文件小的三分之一,前后的图片内容肉眼几乎无法察觉。
但是这样处理后图片放入CNN中在其他条件不变的情况下,模型loss无法下降,二分类图片,loss一直在8-9之间。准确率维持在0.5,同时,测试集的训练误差持续下降,但是准确率也在0.5徘徊。大概真是需要误差,让优化方法从局部最优跳出来。
使用的activation function是relu,full connection layer是softmax分类函数,优化方法为RMsprop
难到是需要加入噪音更好,CNN中加入高斯噪音不是让模型更稳健的吗?还有让模型跳出局部最优的好处,方便训练。
原意:降噪的目的是因为这批数据是样本较少,用复印机 扫面出来的图片,想着放入更干净的数据,模型更容易学习到本质特征。
结果事与愿违,但是在keras中是可以加入noise的,比如加入高斯噪音
form keras.layers.noise import GaussianNoise
我在全连接层中加入
model.add(GaussianNoise(0.125))
后来查看了BatchNormalization的作用,发现在这个大杀器之后,好像很少有人用到初始化和其他的tricks,就可以让模型表现的很好。
在第一层的Maxpooling后面加上,model.add(BatchNormalization()),效果非常显著,第一次epoch的loss值只有0.63,acc也迅速上升,不会出现之前的卡在8.354一直不动,哪怕更换 leraning rate和使用Adagrad,都是一样的,如果前面的5个epoch完,还是没有太大的变化,后面几乎不会收敛。
1,leraning rate的设置
#导入模块,以rmsprop为例
from keras.optimizers import rmsprop
rmsprop=rmsprop(lr=0.1)#只是更改了学习率,其他的参数没有更改,默认学习率是0.001
2.BatchNormalization()的设置
from keras.layers.normalization import BatchNormalization
#网上不少人说,批规范化 加在输入层的激活函数(层)的前面
model.add(BatchNormalization())
也有看到每一个隐藏层的激活函数前面全部加上BN的,但是我这个实验中,效果很差。
3.在输入数据的时候,依然加上train_x = data/255.0,对像素矩阵的取值放小到0-1之间,否则训练将很艰难。
其实在我自己的实验中,后来调整成:
train_x-= np.mean(train_x, axis = 0)
发现效果更好
4.如果第一次的epoch的loss在个位数,则很可能需要返回去重新构建模型,加入更多的trick,如果最后的loss值依然没有达到小数,则也可能是难于训练,也需要加入其他的技巧。或者模型搭建的有问题,需要慎重检查。
5. 建议使用网格搜索,从最重要的参数开始,搭建一个简单的模型,然后取合理的超参数,逐一进行。
6 .也可以在卷积层中加正则化,比如:
C1 = Convolution2D(8 3, 3, border_mode='valid', init='he_uniform', activation='relu',W_regularizer=l2(regularizer_params))
7.有看到在kaggle中使用集成cnn的,分类错误率确实有下降。
8 使用ReduceLROnPlateau 对学习率进行衰减,当下降很慢时,学习率自动调整,可以起到一部分作用,
我在模型中使用的是RMSprop ,RMSprop本身带有学习率的自动调整,但是,我加上ReduceLROnPlateau ,依然可以看到学习率变化很慢时,设置的这个ReduceLROnPlateau 有调整。
9 用数据增强的时候,也需要小心,图片调整的幅度等均会对模型的正确率有影响。
10,对3个颜色的图像转换为gray以后,分类准确率稳定在 0.5左右,几乎就是废掉了,说明图像的像素对于模型的影响巨大,后来了解到有“图像超分辨率重建Super-Resolution”其实是可以对图像做像素的分辨率更高。当然也是可以手工用PS进行插值等修图。查了下,像mnist这样的数据集都是经过处理后才放入模型中的,所以,不能完全指望着CNN卷积池化就把所有的问题都解决掉,尽管图像分类和识别正在像CNN转移。
keras遇到的坑(可能是水平的问题,总之有困惑)
(1) 多次运行会在上一次运行过的数据上起作用,比如,
train_x , val_x , train_y, val_y = train_test_split(train_x, train_y, test_size=0.1, random_state=random_seed)
如果多次运行,则1000个数据,900个训练集的,下一次变成,900*0.9=810个数据,同时,还发现,
train_y = to_categorical(label, num_classes =2),这里也可能出现问题,比如,二分类,在第一次运行后是,2行
第二次运行就变成4行
(2) 在做交叉验证时
新版本epoch的写法是epochs=
estimator = KerasClassifier(build_fn=baseline_model, epochs=20, batch_size=32, verbose=2)
如果用成下面老版本,则n_epoch无法读取,运行的时候,默认的是1所以我定义的 n_epoch=20是失效。
estimator = KerasClassifier(build_fn=baseline_model, n_epoch=20, batch_size=32, verbose=2)
补充知识:keras中loss与val_loss的关系
loss是训练集的损失值,val_loss是测试集的损失值
以下是loss与val_loss的变化反映出训练走向的规律总结:
train loss 不断下降,test loss不断下降,说明网络仍在学习;(最好的)
train loss 不断下降,test loss趋于不变,说明网络过拟合;(max pool或者正则化)
train loss 趋于不变,test loss不断下降,说明数据集100%有问题;(检查dataset)
train loss 趋于不变,test loss趋于不变,说明学习遇到瓶颈,需要减小学习率或批量数目;(减少学习率)
train loss 不断上升,test loss不断上升,说明网络结构设计不当,训练超参数设置不当,数据集经过清洗等问题。(最不好的情况)
来源:https://blog.csdn.net/fitzgerald0/article/details/79002018


猜你喜欢
- 我们常常会用到PHP过滤一些标签的功能,比如过滤链接标签、过滤script标签等等,下面就介绍一下PHP过滤常用标签的正则表达式代码:$st
- 如果你是一位ASP爱好者,你一定想过ASP的执行效率如何?大家都知道ASP效率和CGI的比,在访问量少的时候,它们是不相上下的,有时可能CG
- keras非常方便。不解释,直接上实例。数据格式如下:序号 天气 是否周末 是否有促销 销量1 坏 &n
- 本文实例讲述了Python连接phoenix的方法。分享给大家供大家参考,具体如下:phoenix是由saleforce.com开源的一个项
- 1.先检查系统是否装有mysqlrpm -qa | grep mysql2.下载mysql的repo源(5.7)wget -i -c htt
- 本文实例讲述了js实现的全国省市二级联动下拉选择菜单。分享给大家供大家参考。具体如下:运行效果截图如下:具体代码如下:<!DOCTYP
- 优先级队列概述队列,是数据结构中实现先进先出策略的一种数据结构。而优先队列则是带有优先级的队列,即先按优先级分类,然后相同优先级的再 进行排
- 通过navicat客户端修改datetime默认值时,遇到了问题。数据库表字段类型datetime,原来默认为NULL,当通过界面将默认值设
- 先看Pytorch中的卷积class torch.nn.Conv2d(in_channels, out_channels, kernel_s
- 用于绘制直线的line函数;用于绘制椭圆的ellipse函数;用于绘制矩形的rectangle函数;用于绘制圆的circle函数;用于绘制填
- 前言最近又多了不少朋友关注,先在这里谢谢大家。关注我的朋友大多数都是大学生,而且我简单看了一下,低年级的大学生居多,大多数都是为了完成课程设
- 本文实例讲述了Python学习笔记之Break和Continue用法。分享给大家供大家参考,具体如下:Python 中的Break 和 Co
- ASPError Object 这个新增的,内置与ASP 3.0中的对象提供了一个以往版本中没有的专门用来处理错误的对象,这样,我们来操纵错
- 很早很早的时候,computer这个东西习惯于被称之为计算机,因为它的主要功能是完成一些科学计算的东西,我记得自己鼓捣它的时候,就是计算,根
- TensorFlow是一个采用数据流图(data flow graphs),用于数值计算的开源软件库。节点(Nodes)在图中表示数学操作,
- 以下函数可用于替换php内置的is_writable函数//可用于替换php内置的is_writable函数function isWrita
- 每个 ndarray 都有一个关联的数据类型 (dtype) 对象。这个数据类型对象(dtype)告诉我们数组的布局。这意味着它为我们提供了
- 目录简述:实战案例:简述:关于敏感词过滤可以看成是一种文本反垃圾算法,例如 题目:敏感词文本文件 filtered_words.t
- 楔子由于之前电脑上安装的MySQL版本是比较老的了,大概是5.1的版本,不支持JSON字段功能。而最新开发部门开发的的编辑器产品,使用到了J
- 基于 Snapchat 的增强现实胡子挂件融合第一个项目中,我们将在检测到的脸上覆盖了一个小胡子。我们可以使用从摄像头捕获的连续视频帧,也可