BatchNorm2d原理、作用及pytorch中BatchNorm2d函数的参数使用
作者:LS_learner 发布时间:2021-05-28 10:07:19
BN原理、作用
函数参数讲解
BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
1.
num_features
:一般输入参数的shape为batch_size*num_features*height*width,即为其 * 征的数量,即为输入BN层的通道数;2.
eps
:分母中添加的一个值,目的是为了计算的稳定性,默认为:1e-5,避免分母为0;3.
momentum
:一个用于运行过程中均值和方差的一个估计参数(我的理解是一个稳定系数,类似于SGD中的momentum的系数);4.
affine
:当设为true时,会给定可以学习的系数矩阵gamma和beta
一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。
通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。
同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。
容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。
其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False则γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True。
trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。
相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。
当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。
如果BatchNorm2d的参数track_running_stats设置False,那么加载预训练后每次模型测试测试集的结果时都不一样;track_running_stats设置为True时,每次得到的结果都一样。
running_mean和running_var参数是根据输入的batch的统计特性计算的,严格来说不算是“学习”到的参数,不过对于整个计算是很重要的。
BN层中的running_mean和running_var的更新是在forward操作中进行的,而不是在optimizer.step()中进行的,因此如果处于训练中泰,就算不进行手动step(),BN的统计特性也会变化。
model.train() #处于训练状态
for data , label in self.dataloader:
pred =model(data) #在这里会更新model中的BN统计特性参数,running_mean,running_var
loss=self.loss(pred,label)
#就算不进行下列三行,BN的统计特性参数也会变化
opt.zero_grad()
loss.backward()
opt.step()
这个时候,要用model.eval()转到测试阶段,才能固定住running_mean和running_var,有时候如果是先预训练模型然后加载模型,重新跑测试数据的时候,结果不同,有一点性能上的损失,这个时候基本上是training和track_running_stats设置的不对。
如果使用两个模型进行联合训练,为了收敛更容易控制,先预训练好模型model_A,并且model_A内还有若干BN层,后续需要将model_A作为一个inference推理模型和model_B联合训练,此时希望model_A中的BN的统计特性量running_mean和running_var不会乱变化,因此就需要将model_A.eval()设置到测试模型,否则在trainning模式下,就算是不去更新模型的参数,其BN都会变化,这将导致和预期不同的结果。
来源:https://blog.csdn.net/qq_39777550/article/details/108038677


猜你喜欢
- 本文实例为大家分享了python实现图像边缘检测的具体代码,供大家参考,具体内容如下任务描述背景边缘检测是数字图像处理领域的一个常用技术,被
- MD5消息摘要算法(英语:MD5 Message-Digest Algorithm),一种被广泛使用的密码散列函数,可以产生出一个128位(
- SQL中的单记录函数 1.ASCII 返回与指定的字符对应的十进制数; SQL> select ascii('A')
- 数据库服务器主要用于存储、查询、检索企业内部的信息,因此需要搭配专用的数据库系统,对服务器的兼容性、可靠性和稳定性等方面都有很高的要求。下面
- 解决问题: 不使用for计算两组、多个矩形两两间的iou使用numpy广播的方法,在python程序中并不建议使用for语句,python中
- 数据库和操作系统一样,是一个多用户使用的共享资源。当多个用户并发地存取数据 时,在数据库中就会产生多个事务同时存取同一数据的情况。若对并发操
- 1. 日志输出到屏幕#!/usr/bin/env python# -*- coding: utf-8 -*-from __future__
- 这篇技术贴讲怎样在Django的框架下导出Excel, 最开始打算用ajax post data 过去,但是发现不行,所以改用了get的方式
- 我在使用conda安装虚拟环境的过程中,下载一些包,比如torch等,发现在虚拟环境中有一份以外,pkgs文件夹下同样也会出现一份,大小一样
- 1.安装PDFminer3k使用pip 命令安装pip install pdfminer3k2.编写测试你可以在这里获得官方参考:PDFMi
- 1. 使用 length 属性追加元素使用length属性,可以在数组末尾后面添加一个元素var arr = [1, 2, 3, 4, 5]
- 问题描述由于之前在安装VSCODE的时候,没注意详细阅读提示,而且第一次安装比较随意,只是带着想试一下VSCODE才安装的,所以安装的时候漏
- 前言ThinkPHP出于安全的考虑增加了表单令牌Token,由于通过Ajax异步更新数据仅仅部分页面刷新数据,就导致了令牌Token不能得到
- 内容摘要:本文介绍了使用js来实现下拉伸缩导航菜单的功能,并带有渐显的效果,值得收藏。正好这几天公司不忙,学校又没有事情,所以想抽空架一个个
- python修改FTP服务器上的文件名,具体代码如下所示:#-*- coding:utf-8 -*-#修改ftp服务器上的文件名from f
- 下面是我已经证实可用的自动备份的方法. 1、打开企业管理器->管理->sql server代理 2、新建一个作业,作业名称随便取
- 本文实例讲述了Python基于checksum计算文件是否相同的方法。分享给大家供大家参考。具体如下:假设有2个二进制文件(0.bin, 1
- 可匹配结构:今天~前天, 几天前, 分钟秒前等 | 2017-1-4 12:10 | 2017/1/4 12:10 | 2018年4月2日
- 研究(2)中讨论了栅格系统的基础知识。这一篇将集中探讨栅格系统的粒度问题。(注:如非特别指明,栅格系统均指24列960栅格系统)淘宝的首页(
- 1、如何使用描述符对实例属性做类型检查?实际案例:在某项目中,我们实现了一些类,并希望能像静态类型语言那样(C,C++,Java)对它们的实