pytorch中的model.eval()和BN层的使用
作者:那抹阳光1994 发布时间:2023-09-21 17:06:10
标签:pytorch,model.eval,BN层
看代码吧~
class ConvNet(nn.module):
def __init__(self, num_class=10):
super(ConvNet, self).__init__()
self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.fc = nn.Linear(7*7*32, num_classes)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
print(out.size())
out = out.reshape(out.size(0), -1)
out = self.fc(out)
return out
# Test the model
model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
如果网络模型model中含有BN层,则在预测时应当将模式切换为评估模式,即model.eval()。
评估模拟下BN层的均值和方差应该是整个训练集的均值和方差,即 moving mean/variance。
训练模式下BN层的均值和方差为mini-batch的均值和方差,因此应当特别注意。
补充:Pytorch 模型训练模式和eval模型下差别巨大(Pytorch train and eval)附解决方案
当pytorch模型写明是eval()时有时表现的结果相对于train(True)差别非常巨大,这种差别经过逐层查看,主要来源于使用了BN,在eval下,使用的BN是一个固定的running rate,而在train下这个running rate会根据输入发生改变。
解决方案是冻住bn
def freeze_bn(m):
if isinstance(m, nn.BatchNorm2d):
m.eval()
model.apply(freeze_bn)
这样可以获得稳定输出的结果。
来源:https://www.cnblogs.com/jiangkejie/p/9983451.html


猜你喜欢
- 一、下载xlsx插件npm i xlsx二、通过element-ui组件的upload组件上传文件<el-upload
- 一、前言在项目开发中,数据库应用必不可少。虽然数据库的种类有很多,如SQLite、MySQL、Oracle等,但是它们的功能基本是一样都是一
- 阅读对象:知道什么是restful,有了解swagger或者openAPI更佳。1.什么是restfulRepresentional Sta
- matplotlib及相关cmap参数的取值在matplotlib中对于图片的显示有如下方法(这不是重点), 其中有cmap=&ls
- 做图像识别的时候需要在图片中画出特定大小和角度的矩形框,自己写了一个函数,给定的输入是图片名称,矩形框的位置坐标,长宽和角度,直接输出画好矩
- 这几天不是很忙,就找了些拖动布局方面的资料看看,也学着写了个拖动布局的效果,没想到花了好多时间,七拼八凑,总算是把这个效果写出来了。哎!还是
- PEP 3107引入了功能注释的语法,PEP 484 加入了类型检查标准库 typing 为类型提示指定的运行时提供支持。示例:def f(
- 本例使用登录页面演示,session的状态保持功能。说明:因为http是无状态的,客户端请求一次页面后,就结束了,当再次访问时,服务器端并不
- Python中的三引号,3个单引号及3个双引号实际上3个单引号和3个双引号不经常用,但是在某些特殊格式的字符串下却有大用处。通常情况下我们用
- 1、建表语句:CREATE TABLE `employees` ( `emp_no` int(11) NOT NULL, `birth_da
- 找遍资料得出结果:不能 不过同时也找到了解决办法,就是用iframe的方式来提交表单,即实现无刷新提交表单又可以上传文件! 一、HTML代码
- 前言在AI领域,来快速实现一个idea:前后端开发+部署+展现,如果走传统的前后端分离开发+服务器docker部署等方式,会很重且入门成本很
- 本文实例讲述了redis数据库及与python交互用法。分享给大家供大家参考,具体如下:redis数据操作1.string类型:主要存储字符
- 分布式锁一般有三种实现方式:1. 数据库乐观锁;2. 基于Redis的分布式锁;3. 基于ZooKeeper的分布式锁。本篇博客将介绍第二种
- 本文实例讲述了python清除字符串里非字母字符的方法。分享给大家供大家参考。具体如下:s = "hello world! how
- 前言上篇文章 一文了解 Go 标准库 strings 常用函数和方法 介绍了 strings 标注库里的一些常用的函数和方法,本文也是以 s
- 本文实例为大家分享了python字典操作实例的具体代码,供大家参考,具体内容如下#!/usr/bin/env python3# -*- co
- onactivate
- PyQt5中信号与槽可以说是对事件处理机制的高级封装,如果说事件是用来创建窗口控件的,那么信号与槽就是用来对这个控件进行使用的,比如一个按钮
- 这篇博客将介绍如何通过OpenCV和Python使用模板匹配执行光学字符识别(OCR)。具体来说,将使用Python+OpenCV实现模板匹