对比分析BN和dropout在预测和训练时区别
作者:微笑sun 发布时间:2022-09-05 11:46:55
Batch Normalization和Dropout是深度学习模型中常用的结构。
但BN和dropout在训练和测试时使用却不相同。
Batch Normalization
BN在训练时是在每个batch上计算均值和方差来进行归一化,每个batch的样本量都不大,所以每次计算出来的均值和方差就存在差异。预测时一般传入一个样本,所以不存在归一化,其次哪怕是预测一个batch,但batch计算出来的均值和方差是偏离总体样本的,所以通常是通过滑动平均结合训练时所有batch的均值和方差来得到一个总体均值和方差。
以tensorflow代码实现为例:
def bn_layer(self, inputs, training, name='bn', moving_decay=0.9, eps=1e-5):
# 获取输入维度并判断是否匹配卷积层(4)或者全连接层(2)
shape = inputs.shape
param_shape = shape[-1]
with tf.variable_scope(name):
# 声明BN中唯一需要学习的两个参数,y=gamma*x+beta
gamma = tf.get_variable('gamma', param_shape, initializer=tf.constant_initializer(1))
beta = tf.get_variable('beat', param_shape, initializer=tf.constant_initializer(0))
# 计算当前整个batch的均值与方差
axes = list(range(len(shape)-1))
batch_mean, batch_var = tf.nn.moments(inputs , axes, name='moments')
# 采用滑动平均更新均值与方差
ema = tf.train.ExponentialMovingAverage(moving_decay, name="ema")
def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean, batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
# 训练时,更新均值与方差,测试时使用之前最后一次保存的均值与方差
mean, var = tf.cond(tf.equal(training,True), mean_var_with_update,
lambda:(ema.average(batch_mean), ema.average(batch_var)))
# 最后执行batch normalization
return tf.nn.batch_normalization(inputs ,mean, var, beta, gamma, eps)
training参数可以通过tf.placeholder传入,这样就可以控制训练和预测时training的值。
self.training = tf.placeholder(tf.bool, name="training")
Dropout
Dropout在训练时会随机丢弃一些神经元,这样会导致输出的结果变小。而预测时往往关闭dropout,保证预测结果的一致性(不关闭dropout可能同一个输入会得到不同的输出,不过输出会服从某一分布。另外有些情况下可以不关闭dropout,比如文本生成下,不关闭会增大输出的多样性)。
为了对齐Dropout训练和预测的结果,通常有两种做法,假设dropout rate = 0.2。一种是训练时不做处理,预测时输出乘以(1 - dropout rate)。另一种是训练时留下的神经元除以(1 - dropout rate),预测时不做处理。以tensorflow为例。
x = tf.nn.dropout(x, self.keep_prob)
self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
tf.nn.dropout就是采用了第二种做法,训练时除以(1 - dropout rate),源码如下:
binary_tensor = math_ops.floor(random_tensor)
ret = math_ops.div(x, keep_prob) * binary_tensor
if not context.executing_eagerly():
ret.set_shape(x.get_shape())
return ret
binary_tensor就是一个mask tensor,即里面的值由0或1组成。keep_prob = 1 - dropout rate。
来源:https://www.cnblogs.com/jiangxinyang/p/14333903.html
猜你喜欢
- 安装时建议你为MySQL管理创建一个用户和组。由该组用户运行mysql服务器并执行管理任务。(也可以以root身份运行服务器,但是不推荐)第
- 在python代码编写过程中,养成注释的习惯非常有用,可以让自己或别人后续在阅读代码时,轻松理解代码的含义。如果只是简单的单行注释,可直接用
- python爬虫用mongodb的原因:1、文档结构的存储方式简单讲就是可以直接存json,list2、不要事先定义”表”,随时可以创建3、
- 访问FTP,无非两件事情:upload和download,最近在项目中需要从ftp下载大量文件,然后我就试着去实验自己的ftp操作类,如下(
- selenium+python,使用webdriver的截图函数get_screenshot_as_file()截图,代码如下:from s
- objectobject 是 Python 为所有对象提供的父类,默认提供一些内置的属性、方法;可以使用 dir 方法查看新式类以 obje
- mysql安装目录使用MySQL AB's Linux RPM分发进行安装后,将在以下系统目录产生文件目录目录内容/usr/bin客
- 程序设计是困难的,其核心是管理的复杂性。计算机程序是人类做出的最复杂的东西。质量是不可靠的且隐蔽的。好的体系架构是必需给程序足够的结构使其健
- 本文实例为大家分享了Python人脸识别的具体代码,供大家参考,具体内容如下1.利用opencv库sudo apt-get install
- 引言关键!!!!使用loc函数来查找。话不多说,直接演示:有以下名为try.xlsx表:1.根据index查询条件:首先导入的数据必须的有i
- step1:在file中找到default settingsstep2:找到Project Interpreterstep3:按照如图步骤搜
- 前言我们前面对matplotlib模块底层结构学习,对其pyplot类(脚本层)类提供的绘制折线图、柱状图、饼图、直方图等统计图表的相关方法
- 在asp里通过以下两个函数实现javascript里的escape函数和unescape函数
- 原文地址:30 Days of Mootools 1.2 Tutorials - Day 13 - Regular ExpressionsM
- 前言本文介绍如何使用Python制作一个简单的猜数字游戏。游戏规则玩家将猜测一个数字。如果猜测是正确的,玩家赢。如果不正确,程序会提示玩家所
- 让你的读者能够方便地收藏你的文章到社会化书签(网摘)网站,如 新浪,google,yahoo,Del.icio.us, 365key等添加到
- 我就废话不多说了,直接上代码吧!import datetimeimport timedef get_float_time_stamp():
- 心血来潮写了个多线程抓妹子图,虽然代码还是有一些瑕疵,但是还是记录下来,分享给大家。Pic_downloader.py# -*- codin
- JDBC连接MySQL数据库关键的四个步骤1、查找驱动程序MySQL目前提供的Java驱动程序为Connection/J,可以从MySQL官
- PYTHON 字节码设计在本篇文章当中主要给大家介绍 cpython 虚拟机对于字节码的设计以及在调试过程当中一个比较重要的字段 co_ln