批标准化层 tf.keras.layers.Batchnormalization()解析
作者:壮壮不太胖^QwQ 发布时间:2023-06-18 23:35:17
批标准化层 tf.keras.layers.Batchnormalization()
tf.keras.layers.Batchnormalization()
重要参数:
training
:布尔值,指示图层应在训练模式还是在推理模式下运行。training=True
:该图层将使用当前批输入的均值和方差对其输入进行标准化。training=False
:该层将使用在训练期间学习的移动统计数据的均值和方差来标准化其输入。
BatchNormalization 广泛用于 Keras 内置的许多高级卷积神经网络架构,比如 ResNet50、Inception V3 和 Xception。
BatchNormalization 层通常在卷积层或密集连接层之后使用。
批标准化的实现过程
求每一个训练批次数据的均值
求每一个训练批次数据的方差
数据进行标准化
训练参数γ,β
输出y通过γ与β的线性变换得到原来的数值
在训练的正向传播中,不会改变当前输出,只记录下γ与β。在反向传播的时候,根据求得的γ与β通过链式求导方式,求出学习速率以至改变权值。
对于预测阶段时所使用的均值和方差,其实也是来源于训练集。比如我们在模型训练时我们就记录下每个batch下的均值和方差,待训练完毕后,我们求整个训练样本的均值和方差期望值,作为我们进行预测时进行BN的的均值和方差。
批标准化的使用位置
原始论文讲在CNN中一般应作用与非线性激活函数之前,但是,实际上放在激活函数之后效果可能会更好。
# 放在非线性激活函数之前
model.add(tf.keras.layers.Conv2D(64, (3, 3)))
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.Activation('relu'))
# 放在激活函数之后
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.BatchNormalization())
tf.keras.layers.BatchNormalization使用细节
关于keras中的BatchNormalization使用,官方文档说的足够详细。本文的目的旨在说明在BatchNormalization的使用过程中容易被忽略的细节。
在BatchNormalization的Arguments参数中有trainable属性;以及在Call arguments参数中有training。两个都是bool类型。第一次看到有两个参数的时候,我有点懵,为什么需要两个?
后来在查阅资料后发现了两者的不同作用。
1,trainable是Argument参数,类似于c++中构造函数的参数一样,是构建一个BatchNormalization层时就需要传入的,至于它的作用在下面会讲到。
2,training参数时Call argument(调用参数),是运行过程中需要传入的,用来控制模型在那个模式(train还是interfere)下运行。关于这个参数,如果使用模型调用fit()的话,是可以不给的(官方推荐是不给),因为在fit()的时候,模型会自己根据相应的阶段(是train阶段还是inference阶段)决定training值,这是由learning——phase机制实现的。
重点
关于trainable=False:如果设置trainable=False,那么这一层的BatchNormalization层就会被冻结(freeze),它的trainable weights(可训练参数)(就是gamma和beta)就不会被更新。
注意:freeze mode和inference mode是两个概念。
但是,在BatchNormalization层中,如果把某一层BatchNormalization层设置为trainable=False,那么这一层BatchNormalization层将一inference mode运行,也就是说(meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).
来源:https://blog.csdn.net/weixin_46072771/article/details/108591263
猜你喜欢
- 一、引用计数基础知识每个php变量存在一个叫 zval 的变量容器中。一个 zval 变量容器,除了包含变量的类型和值,还包括两个字节的额外
- 代码如下:<% '=================================================
- 在使用django restframework serializer 序列化在django中定义的model时,有时候我们需要额外在seri
- 刚开始学习Python的类写法的时候觉得很是麻烦,为什么定义时需要而调用时又不需要,为什么不能内部简化从而减少我们敲击键盘的次数?你看完这篇
- 首先以只读方式打开单词文件,利用列表推导式创建两个列表列表sta记录各单词出现的次数,列表freq记录各单词出现的频率f = open(
- 目录1. 字符串的翻转2. 判断字符串是不是回文串3. 单词大小写4. 字符串的拆分5. 字符串的合并6. 将元素进行重复7. 列表的拓展8
- 原文: gradio.app/interface-s…1.全局状态例子来解释import gradio as grsc
- 具体代码如下所示:import smtplib, email, os, timefrom email.mime.multipart impo
- 前言from collections import namedtuple()命名元祖的工厂函数:在python中,collections 包
- 本文实例介绍了使用javascript来经验表单数据的方法,如:校验是否为英文,校验是否为数字及校验IP地址等: &l
- 我们主要讲解一下利用Python实现感知机算法。算法一首选,我们利用Python,按照上一节介绍的感知机算法基本思想,实现感知算法的原始形式
- 这个可应用于所有浏览器中.<SCRIPT language=javascript>var leave=true; functio
- 推荐使用 Homebrew 来安装第三方工具。自己安装的python散落在电脑各处,删除起来比较麻烦。今天在此记录一下删除的过程(本人以Py
- 全局变量与局部变量# num1是全局变量num1 = 1# num2是局部变量def func():num2 = 2在函数外(且不在函数里)
- 这篇论坛文章(赛迪网技术社区)主要介绍了一种简单的MySQL数据库安装方法,详细内容请大家参考下文:虽然安装MySQL数据库的文章很多,但是
- 以下是一个类文件,下面的注解是调用类的方法注意:如果系统不支持建立Scripting.FileSystemObject对象,那么数据库压缩功
- 什么是图像平滑处理在尽量保留图像原有信息的情况下,过滤掉图像内部的噪声,这一过程我们称之为图像的平滑处理,所得到的图像称为平滑图像。那么什么
- 阅读之前:在看文章具体内容之前,希望你可以 先打开IE8,打开http://www.taobao.com,然后在地址栏里输入:javascr
- 自己用python写了一个签到脚本,经过测试已经可以成功打卡,于是研究了一下windows定时运行程序1. 创建定时任务1.1 计划任务打开
- <?php function getDerivativeByFormulaAndXDATA($formula, $x_data){ $