对比分析BN和dropout在预测和训练时区别
作者:微笑sun 发布时间:2022-09-05 11:46:55
Batch Normalization和Dropout是深度学习模型中常用的结构。
但BN和dropout在训练和测试时使用却不相同。
Batch Normalization
BN在训练时是在每个batch上计算均值和方差来进行归一化,每个batch的样本量都不大,所以每次计算出来的均值和方差就存在差异。预测时一般传入一个样本,所以不存在归一化,其次哪怕是预测一个batch,但batch计算出来的均值和方差是偏离总体样本的,所以通常是通过滑动平均结合训练时所有batch的均值和方差来得到一个总体均值和方差。
以tensorflow代码实现为例:
def bn_layer(self, inputs, training, name='bn', moving_decay=0.9, eps=1e-5):
# 获取输入维度并判断是否匹配卷积层(4)或者全连接层(2)
shape = inputs.shape
param_shape = shape[-1]
with tf.variable_scope(name):
# 声明BN中唯一需要学习的两个参数,y=gamma*x+beta
gamma = tf.get_variable('gamma', param_shape, initializer=tf.constant_initializer(1))
beta = tf.get_variable('beat', param_shape, initializer=tf.constant_initializer(0))
# 计算当前整个batch的均值与方差
axes = list(range(len(shape)-1))
batch_mean, batch_var = tf.nn.moments(inputs , axes, name='moments')
# 采用滑动平均更新均值与方差
ema = tf.train.ExponentialMovingAverage(moving_decay, name="ema")
def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean, batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)
# 训练时,更新均值与方差,测试时使用之前最后一次保存的均值与方差
mean, var = tf.cond(tf.equal(training,True), mean_var_with_update,
lambda:(ema.average(batch_mean), ema.average(batch_var)))
# 最后执行batch normalization
return tf.nn.batch_normalization(inputs ,mean, var, beta, gamma, eps)
training参数可以通过tf.placeholder传入,这样就可以控制训练和预测时training的值。
self.training = tf.placeholder(tf.bool, name="training")
Dropout
Dropout在训练时会随机丢弃一些神经元,这样会导致输出的结果变小。而预测时往往关闭dropout,保证预测结果的一致性(不关闭dropout可能同一个输入会得到不同的输出,不过输出会服从某一分布。另外有些情况下可以不关闭dropout,比如文本生成下,不关闭会增大输出的多样性)。
为了对齐Dropout训练和预测的结果,通常有两种做法,假设dropout rate = 0.2。一种是训练时不做处理,预测时输出乘以(1 - dropout rate)。另一种是训练时留下的神经元除以(1 - dropout rate),预测时不做处理。以tensorflow为例。
x = tf.nn.dropout(x, self.keep_prob)
self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
tf.nn.dropout就是采用了第二种做法,训练时除以(1 - dropout rate),源码如下:
binary_tensor = math_ops.floor(random_tensor)
ret = math_ops.div(x, keep_prob) * binary_tensor
if not context.executing_eagerly():
ret.set_shape(x.get_shape())
return ret
binary_tensor就是一个mask tensor,即里面的值由0或1组成。keep_prob = 1 - dropout rate。
来源:https://www.cnblogs.com/jiangxinyang/p/14333903.html


猜你喜欢
- 前言Python 中一切皆对象,这些对象的内存都是在运行时动态地在堆中进行分配的,就连 Python 虚拟机使用的栈也是在堆上模拟的。既然一
- 本文实例讲述了Django框架基础模板标签与filter使用方法。分享给大家供大家参考,具体如下:一、基本的模板语言1、变量{{ }}1.1
- 事务隔离级别设置set global transaction isolation level read committed; //全局的se
- mysql_query("set autocommit=0"); $list_one = $db->fetch_f
- 你喜欢在博客文章中使用图片吗?是的,如果不是很麻烦的话,相信大家都不会介意放上几张漂亮的图片来点缀一下内容的,不过你的图片可能会导致下面的两
- 相信大家都用过 jupyter,也用过里面的魔法命令,这些魔法命令都以 % 或者 %% 开
- 配置babel-plugin-import报错的坑用的是antd design vue生成的项目,按着官网的提示一步一步下来,在配置babe
- 程式功能: 用 UI 界面,点击界面上的“开始识别”来录音(调用百度云语音接口),并自动将结果显示在
- 首先我们先引入requests模块import requests一、发送请求r = requests.get('https://ap
- 视频对象提取与其说是视频对象提取,不如说是视频颜色提取,因为其本质还是使用了OpenCV的HSV颜色物体检测。下面话不多说了,来一起看看详细
- 本文介绍了node.js用fs.rename强制重命名或移动文件夹的方法,首先介绍了rename的用法,具体如下:【重命名文件夹】// re
- 看一个例子d={'test':1}d_test=dd_test['test']=2print d如果你在命令
- 这篇文章主要介绍了Python argparse模块使用方法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值
- Excel的最合适列宽(openpyxl)Python的Pandas模块是处理Excel的利器,尤其是加工保存Excel非常方便,但是唯独想
- 先看效果,实现一个图片左右摇动,在一般的H5宣传页,商家活动页面我们会看到这样的动画,小程序的动画效果不同于css3动画效果,是通过js来完
- 设置MySQL数据同步(单向&双向)由于公司的业务需求,需要网通和电信的数据同步,就做了个MySQL的双向同步,记下过程,以后用得到
- CORS出于安全性,浏览器限制脚本内发起的跨源 HTTP 请求。例如,XMLHttpRequest 和 Fetch AP
- 本文实例讲述了Django框架实现分页显示内容的方法。分享给大家供大家参考,具体如下:分页1、作用数据加载优化2、前端引入bootstrap
- 啥也不说了,眼泪哗哗的 –来自怨念深重的不灵狗。【运行环境】1、在ubuntu下使用pip安装flask-mongoengine;2、pip
- mysql> SELECT something FROM tbl_name WHERE TO_DAYS(NOW()) – TO_DAY