BatchNorm2d原理、作用及pytorch中BatchNorm2d函数的参数使用
作者:LS_learner 发布时间:2021-05-28 10:07:19
BN原理、作用
函数参数讲解
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
1.
num_features
:一般输入参数的shape为batch_size*num_features*height*width,即为其 * 征的数量,即为输入BN层的通道数;2.
eps
:分母中添加的一个值,目的是为了计算的稳定性,默认为:1e-5,避免分母为0;3.
momentum
:一个用于运行过程中均值和方差的一个估计参数(我的理解是一个稳定系数,类似于SGD中的momentum的系数);4.
affine
:当设为true时,会给定可以学习的系数矩阵gamma和beta
一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。
通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。
同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。
容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。
其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True。
trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。
相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。
当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。
如果BatchNorm2d的参数track_running_stats设置False,那么加载预训练后每次模型测试测试集的结果时都不一样;track_running_stats设置为True时,每次得到的结果都一样。
running_mean和running_var参数是根据输入的batch的统计特性计算的,严格来说不算是“学习”到的参数,不过对于整个计算是很重要的。
BN层中的running_mean和running_var的更新是在forward操作中进行的,而不是在optimizer.step()中进行的,因此如果处于训练中泰,就算不进行手动step(),BN的统计特性也会变化。
model.train() #处于训练状态
for data , label in self.dataloader:
pred =model(data) #在这里会更新model中的BN统计特性参数,running_mean,running_var
loss=self.loss(pred,label)
#就算不进行下列三行,BN的统计特性参数也会变化
opt.zero_grad()
loss.backward()
opt.step()
这个时候,要用model.eval()转到测试阶段,才能固定住running_mean和running_var,有时候如果是先预训练模型然后加载模型,重新跑测试数据的时候,结果不同,有一点性能上的损失,这个时候基本上是training和track_running_stats设置的不对。
如果使用两个模型进行联合训练,为了收敛更容易控制,先预训练好模型model_A,并且model_A内还有若干BN层,后续需要将model_A作为一个inference推理模型和model_B联合训练,此时希望model_A中的BN的统计特性量running_mean和running_var不会乱变化,因此就需要将model_A.eval()设置到测试模型,否则在trainning模式下,就算是不去更新模型的参数,其BN都会变化,这将导致和预期不同的结果。
来源:https://blog.csdn.net/qq_39777550/article/details/108038677
猜你喜欢
- 本XML系列教程将分三部分发布,到最后一期我们将拥有一个功能全面,更加友好的XML菜单。本教程这个第一期涉及到了一些XML的基础知识。大家都
- 我的长博文不少,比较影响阅读体验,有必要添加一个文章目录功能。相比 Wordpress, Typecho 的插件就比较少了。我想找一个像掘金
- 前言之前工作中主要使用的是 Tensorflow 1.15 版本,但是渐渐跟不上工作中的项目需求了,而且因为 2.x 版本和 1.x 版本差
- 因为有把python程序打包成exe的需求,所以,有了如下的代码import timeclass LoopOver(Exception):
- 本文介绍了redis之django-redis的简单缓存使用,分享给大家,具体如下:自定义连接池这种方式跟普通py文件操作redis一样,代
- pandas创建series方法print("====创建series方法一===")dic={"a"
- 许多网页开发者想从ASP.NET 页面传递一个值到另一个页面(比如从一个框架frame页面到一个弹窗页面)。看了代码就明白了。呵呵。(一)向
- 1、文件编码:指的是页面文件(.html,.php等)本身是以何种编码来保存的。记事本和Dreamweaver在打开页面时候会自动识别文件编
- 一、前言在Python中,除了可以自定义模块外,还可以引用其他模块,主要包括使用标准库和第三方模块。下面分别进行介绍。二、导入和使用标准模块
- DBA_2PC_PENDING Oracle会自动处理分布事务,保证分布事务的一致性,所有站点全部提交或全部回滚。一般情况下,处理过程在很短
- 前言最近在使用Pycharm,在运行或者安装的过程中出现了各种各样的报错,前面已经介绍过安装pygame出现报错的解决方法。文章总结了大部分
- 目录1.引言2.获取目标网站3.爬取目标网站4.解析爬取内容4.1. 解析全国今日总况4.2. 解析全国各省份疫情情况4.3. 解析江苏各地
- 一、业务背景在金融风控领域,常常使用KS指标来衡量评估模型的区分度(discrimination),这也是风控模型最为追求的指标之一。下面将
- 先来看个用Python实现的二分查找算法实例import sys def search2(a,m): low = 0 high = le
- identity-card验证身份证号码的正确性,不能仅仅通过正则表达式来验证,我们都知道我国的身份证一共是18位,由十七位数字本体码和一位
- argparse是python标准库里面用来处理命令行参数的库命令行参数分为位置参数和选项参数:位置参数就是程序根据该参数出现的位置来确定的
- 原文:http://www.htmldog.com/guides/htmlintermediate/badtags/十六 有害的标签 Bad
- 本文实例讲述了python获取mp3文件信息的方法。分享给大家供大家参考。具体如下:将代码生成.py文件放在目录下运行,可以获取该目录的所有
- 一、利用外键约束更新MySQL中的数据现在,最流行的开源关系型数据库管理系统非MySQL莫属,而MySQL又支持多个存储引擎,其中默认的也是
- 有四个变量影响磁带备份设备的性能,并使 SQL Server 备份及还原性能操作得以在大体上随添加更多磁带设备而提高线性比例。◆软件数据块大