批标准化层 tf.keras.layers.Batchnormalization()解析
作者:壮壮不太胖^QwQ 发布时间:2023-06-18 23:35:17
批标准化层 tf.keras.layers.Batchnormalization()
tf.keras.layers.Batchnormalization()
重要参数:
training
:布尔值,指示图层应在训练模式还是在推理模式下运行。training=True
:该图层将使用当前批输入的均值和方差对其输入进行标准化。training=False
:该层将使用在训练期间学习的移动统计数据的均值和方差来标准化其输入。
BatchNormalization 广泛用于 Keras 内置的许多高级卷积神经网络架构,比如 ResNet50、Inception V3 和 Xception。
BatchNormalization 层通常在卷积层或密集连接层之后使用。
批标准化的实现过程
求每一个训练批次数据的均值
求每一个训练批次数据的方差
数据进行标准化
训练参数γ,β
输出y通过γ与β的线性变换得到原来的数值
在训练的正向传播中,不会改变当前输出,只记录下γ与β。在反向传播的时候,根据求得的γ与β通过链式求导方式,求出学习速率以至改变权值。
对于预测阶段时所使用的均值和方差,其实也是来源于训练集。比如我们在模型训练时我们就记录下每个batch下的均值和方差,待训练完毕后,我们求整个训练样本的均值和方差期望值,作为我们进行预测时进行BN的的均值和方差。
批标准化的使用位置
原始论文讲在CNN中一般应作用与非线性激活函数之前,但是,实际上放在激活函数之后效果可能会更好。
# 放在非线性激活函数之前
model.add(tf.keras.layers.Conv2D(64, (3, 3)))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Activation('relu'))
# 放在激活函数之后
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
tf.keras.layers.BatchNormalization使用细节
关于keras中的BatchNormalization使用,官方文档说的足够详细。本文的目的旨在说明在BatchNormalization的使用过程中容易被忽略的细节。
在BatchNormalization的Arguments参数中有trainable属性;以及在Call arguments参数中有training。两个都是bool类型。第一次看到有两个参数的时候,我有点懵,为什么需要两个?
后来在查阅资料后发现了两者的不同作用。
1,trainable是Argument参数,类似于c++中构造函数的参数一样,是构建一个BatchNormalization层时就需要传入的,至于它的作用在下面会讲到。
2,training参数时Call argument(调用参数),是运行过程中需要传入的,用来控制模型在那个模式(train还是interfere)下运行。关于这个参数,如果使用模型调用fit()的话,是可以不给的(官方推荐是不给),因为在fit()的时候,模型会自己根据相应的阶段(是train阶段还是inference阶段)决定training值,这是由learning——phase机制实现的。
重点
关于trainable=False:如果设置trainable=False,那么这一层的BatchNormalization层就会被冻结(freeze),它的trainable weights(可训练参数)(就是gamma和beta)就不会被更新。
注意:freeze mode和inference mode是两个概念。
但是,在BatchNormalization层中,如果把某一层BatchNormalization层设置为trainable=False,那么这一层BatchNormalization层将一inference mode运行,也就是说(meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).
来源:https://blog.csdn.net/weixin_46072771/article/details/108591263


猜你喜欢
- Function getIpvalue(clientIP)'得到客户端的IP转换成长整型,返回值getIpvalue&nb
- 准备1.电脑系统:win102.手机:安卓(没钱买苹果)3.需要的工具可以从官网下载https://appium.io/https://ww
- django {% url %} 模板标签使用inclusions/_archives.html...{% for date in date
- 本文实例讲述了python在windows下创建隐藏窗口子进程的方法。分享给大家供大家参考。具体实现方法如下:import subproce
- 目录1. 前言2. 实战一下2-1 进入虚拟环境,创建一个项目及 App2-2 创建模板目录并配置 set
- 学习了一天的深度学习,略有疲惫,我们用pygame搞个小游戏放松放松吧。今天我们的游戏主体是烟雨蒙蒙下彩虹雨,仿佛置身江南水乡。游戏描述我们
- 普通爬虫正常流程:数据来源分析发送请求获取数据解析数据保存数据环境介绍python 3.8pycharm 2021专业版【付费VIP完整版】
- Python 提供了多个图形开发界面的库。Tkinter就是其中之一。 Tkinter 模块(Tk 接口)是 Python 的标准 Tk G
- 重复性任务总是耗时且无聊,想一想你想要一张一张地裁剪 100 张照片或 Fetch API、纠正拼写和语法等工作,所有这些任务都很耗时,为什
- 清除视图缓存,就是清除D:\phpStudy\WWW\BCCKidV1.0\storage\framework\views\002f30b1
- 简单介绍下功能吧:使用了ASP的一个对象ServerVariables(服务器环境变量),通过这个环境变量可以获取到真正的下载地址再通过一些
- 本文实例讲述了php逐行读取txt文件写入数组的方法。分享给大家供大家参考。具体如下:假设有user.txt文件如下:user01user0
- 需要建立2个文件,一个作为客户端,一个作为服务端文件一 作为客户端client,文件二作为服务端serverudp的特点是不需要建立连接文件
- 作者:HelloGitHub-追梦人物文中所涉及的示例代码,已同步更新到 HelloGitHub-Team 仓库当博客上发布的文章越来越多时
- //设置已存在表中字段为auto_incrementALTER TABLE tablename change id id int(2) no
- asp之家注:学习javascript(js),免不了要用到打开新窗口,方法很多,总的来说是使用window.open。不同与HTML中的t
- Line plotsAxes3D.plot(xs, ys, *args, **kwargs)绘制2D或3D数据参数描述xs, ysX轴,Y轴
- 前言最近用Django写项目的时候用到了数据的传递,一窍不通,查了点资料。记录一下。水平不高,瓜不保熟。 从两方面来说:从后端传递
- 一、说明早上看到Python使用pickle进行序列化和反序列化,然后发现面临的一个获取不到返回值的框架,似乎可以通过在框架中先序列化,然后
- 需要写个js滑动展开折叠(收缩)的效果,搜索到无忧脚本的一篇贴子,稍加修改了下使其在FF也可应用,代码如下: <