一小时学会TensorFlow2之大幅提高模型准确率
作者:我是小白呀 发布时间:2021-07-25 16:25:20
过拟合
当训练集的的准确率很高, 但是测试集的准确率很差的时候就, 我们就遇到了过拟合 (Overfitting) 的问题. 如图:
过拟合产生的一大原因是因为模型过于复杂. 下面我们将通过讲述 5 种不同的方法来解决过拟合的问题, 从而提高模型准确度.
Regulation
Regulation 可以帮助我们通过约束要优化的参数来防止过拟合.
公式
未加入 regulation 的损失:
加入 regulation 的损失:
λ 和 lr (learning rate) 类似. 如果 λ 的值越大, regularion 的力度也就越强, 权重的值也就越小.
例子
添加了 l2 regulation 的网络:
network = tf.keras.Sequential([
tf.keras.layers.Dense(256, kernel_regularizer=tf.keras.regularizers.l2(0.001), activation=tf.nn.relu),
tf.keras.layers.Dense(128, kernel_regularizer=tf.keras.regularizers.l2(0.001), activation=tf.nn.relu),
tf.keras.layers.Dense(64, kernel_regularizer=tf.keras.regularizers.l2(0.001), activation=tf.nn.relu),
tf.keras.layers.Dense(32, kernel_regularizer=tf.keras.regularizers.l2(0.001), activation=tf.nn.relu),
tf.keras.layers.Dense(10)
])
动量
动量 (Momentum) 是指运动物体的租用效果. 在梯度下降的过程中, 通过在优化器中加入动量, 我们可以减少摆动从而达到更优的效果.
未添加动量:
添加动量:
公式
未加动量的权重更新:
w: 权重 (weight)
k: 迭代的次数
α: 学习率 (learning rate)
∇f(): 微分
添加动量的权重更新:
β: 动量权重
z: 历史微分
例子
添加了动量的优化器:
optimizer = tf.keras.optimizers.SGD(learning_rate=0.02, momentum=0.9)
optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.02, momentum=0.9)
注: Adam 优化器默认已经添加动量, 所以无需自行添加.
学习率递减
简单的来说, 如果学习率越大, 我们训练的速度就越大, 但找到最优解的概率也就越小. 反之, 学习率越小, 训练的速度就越慢, 但找到最优解的概率就越大.
过程
我们可以在训练初期把学习率调的稍大一些, 使得网络迅速收敛. 在训练后期学习率小一些, 使得我们能得到更好的收敛以获得最优解. 如图:
例子
learning_rate = 0.2 # 学习率
optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9) # 优化器
# 迭代
for epoch in range(iteration_num):
optimizer.learninig_rate = learning_rate * (100 - epoch) / 100 # 学习率递减
Early Stopping
之前我们提到过, 当训练集的准确率仍在提升, 但是测试集的准确率反而下降的时候, 我们就遇到了过拟合 (overfitting) 的问题.
Early Stopping 可以帮助我们在测试集的准确率下降的时候停止训练, 从而避免继续训练导致的过拟合问题.
Dropout
Learning less to learn better
Dropout 会在每个训练批次中忽略掉一部分的特征, 从而减少过拟合的现象.
dropout, 通过强迫神经元, 和随机跳出来的其他神经元共同工作, 达到好的效果. 消除减弱神经元节点间的联合适应性, 增强了泛化能力.
例子:
network = tf.keras.Sequential([
tf.keras.layers.Dense(256, activation=tf.nn.relu),
tf.keras.layers.Dropout(0.5), # 忽略一半
tf.keras.layers.Dense(128, activation=tf.nn.relu),
tf.keras.layers.Dropout(0.5), # 忽略一半
tf.keras.layers.Dense(64, activation=tf.nn.relu),
tf.keras.layers.Dropout(0.5), # 忽略一半
tf.keras.layers.Dense(32, activation=tf.nn.relu),
tf.keras.layers.Dense(10)
])
来源:https://blog.csdn.net/weixin_46274168/article/details/117986352


猜你喜欢
- 升级背景:为了解决mysql低版本的漏洞,从mysql5.5升级到了8.0.11版本,再次升级到了8.0.17版本(从版本是2019.7.2
- 在使用mysql运行某些语句时,会因数据量太大而导致死锁,没有反映。这个时候,就需要kill掉某个正在消耗资源的query语句即可,KILL
- 最近需要各种转格式,这里对相关代码作一个记录,方便日后查询。xlsx文件转csv文件import xlrdimport csvdef xls
- 在Python所有的数据结构中,list具有重要地位,并且非常的方便,这篇文章主要是讲解list列表的高级应用,基础知识可以查看博客。 此文
- 网页上搜索 “python绘制国际象棋棋盘”,索引结果均为调用 turtle 库绘制棋盘结果;为了填充使用 python PIL 图像处理库
- 数据类型是一种值的集合以及定义在这种值上的一组操作。一切语言的基础都是数据结构,所以打好基础对于后面的学习会有百利而无一害的作用。pytho
- 为了方便各位朋友,本文收集了一些对Web开发人员非常有用的手册,记得推荐一下哦。HTML 速查手册HTML/XTML in one page
- 解决python -v报错问题的方法:在cmd命令行中输入“python -v”报错是因为没有将python的安装路径添加到系统环境变量pa
- 前言JS 中 GBK 编码转字符串是非常简单的,直接调用 TextDecoder 即可:const gbkBuf = n
- 原始数据在这里1.观察数据首先,用Pandas打开数据,并进行观察。import numpy import pandas as pdimpo
- 如下所示:#coding utf-8a=0.001 #定义收敛步长xd=1 #定义寻找步
- 以下以CentOS 7.2为例,安装php的运行环境,首先打开php官网http://php.net/点击导航栏的Downloads进入下载
- 一、序言本文承接[Mybatis缓存体系探究],提供基于MybatisPlus技术可用于生产环境下的二级缓存解决方案。1、前置条件掌握MyB
- 表结构很简单CREATE TABLE `oplogs` (`id` int(10) unsigned NOT NULL AUTO_INCRE
- 什么是 Goroutinegoroutine 是 Go 并行设计的核心。goroutine 说到底其实就是协程,它比线程更小,十几个 gor
- 引言“ 这是MySQL系列笔记的第二篇,文章内容均为本人通过实践及查阅资料相关整理所得,可用作新手入门指南,或
- 在数据库中,字符型的数据是最多的,可以占到整个数据库的80%以上。为此正确处理字符型的数据,对于提高数据库的性能有很大的作用。在字符型数据中
- 前面提到了银行转账这个场景,展示了一个比较耗时的转账操作。这篇继续转帐,下面展示一段程序,多个线程的操作都更改了amount变量导致运行结果
- 1. sys 模块Python 中的 sys 模块具有 argv 功能。当通过终端触发 main.py 的执行时,此功能将返回提供给 mai
- 方法1: 代码如下:truncate table TableName 删除表中的所有的数据的同时,将自动增长清零。 如果有外键参考这个表,这