Pytorch 中net.train 和 net.eval的使用说明
作者:Never-Giveup 发布时间:2021-11-15 11:40:37
在训练模型时会在前面加上:
model.train()
在测试模型时在前面使用:
model.eval()
同时发现,如果不写这两个程序也可以运行,这是因为这两个方法是针对在网络训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout。
训练时是正对每个min-batch的,但是在测试中往往是针对单张图片,即不存在min-batch的概念。
由于网络训练完毕后参数都是固定的,因此每个批次的均值和方差都是不变的,因此直接结算所有batch的均值和方差。
所有Batch Normalization的训练和测试时的操作不同
在训练中,每个隐层的神经元先乘概率P,然后在进行激活,在测试中,所有的神经元先进行激活,然后每个隐层神经元的输出乘P。
补充:Pytorch踩坑记录——model.eval()
最近在写代码时遇到一个问题,原本训练好的模型,加载进来进行inference准确率直接掉了5个点,尼玛,这简直不能忍啊~本菜鸡下意识地感知到我肯定又在哪里写了bug了~~~于是开始到处排查,从model load到data load,最终在一个被我封装好的module的犄角旮旯里找到了问题,于是顺便就在这里总结一下,避免以后再犯。
对于训练好的模型加载进来准确率和原先的不符,比较常见的有两方面的原因:
1)data
2)model.state_dict()
1) data
数据方面,检查前后两次加载的data有没有发生变化。首先检查 transforms.Normalize 使用的均值和方差是否和训练时相同;另外检查在这个过程中数据是否经过了存储形式的改变,这有可能会带来数据精度的变化导致一定的信息丢失。
比如我过用的其中一个数据集,原先将图片存储成向量形式,但其对应的是“png”格式的数据(后来在原始文件中发现了相应的描述。),而我进行了一次data-to-img操作,将向量转换成了“jpg”形式,这时加载进来便造成了掉点。
2)model.state_dict()
第一方面造成的掉点一般不会太严重,第二方面造成的掉点就比较严重了,一旦模型的参数加载错了,那就误差大了。
如果是参数没有正确加载进来则比较容易发现,这时准确率非常低,几乎等于瞎猜。
而我这次遇到的情况是,准确率并不是特别低,只掉了几个点,检查了多次,均显示模型参数已经成功加载了。后来仔细查看后发现在其中一次调用模型进行inference时,忘了写 ‘model.eval()',造成了模型的参数发生变化,再次调用则出现了掉点。
于是又回顾了一下model.eval()和model.train()的具体作用。如下:
model.train() 和 model.eval() 一般在模型训练和评价的时候会加上这两句,主要是针对由于model 在训练时和评价时 Batch
Normalization 和 Dropout 方法模式不同:
a) model.eval(),不启用 BatchNormalization 和 Dropout。此时pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会因BN层导致模型performance损失较大;
b) model.train() :启用 BatchNormalization 和 Dropout。 在模型测试阶段使用model.train() 让model变成训练模式,此时 dropout和batch normalization的操作在训练q起到防止网络过拟合的问题。
因此,在使用PyTorch进行训练和测试时一定要记得把实例化的model指定train/eval。
model.eval() vs torch.no_grad()
虽然二者都是eval的时候使用,但其作用并不相同:
model.eval() 负责改变batchnorm、dropout的工作方式,如在eval()模式下,dropout是不工作的。 见下方代码:
import torch
import torch.nn as nn
drop = nn.Dropout()
x = torch.ones(10)
# Train mode
drop.train()
print(drop(x)) # tensor([2., 2., 0., 2., 2., 2., 2., 0., 0., 2.])
# Eval mode
drop.eval()
print(drop(x)) # tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
torch.no_grad() 负责关掉梯度计算,节省eval的时间。
只进行inference时,model.eval()是必须使用的,否则会影响结果准确性。 而torch.no_grad()并不是强制的,只影响运行效率。
来源:https://blog.csdn.net/qq_36653505/article/details/84728489


猜你喜欢
- 下载地址:安装包可以从这里下载:http://www.itellyou.cn/SQL Server 2016 Enterprise with
- 内容摘要:本文详细介绍了SQL Server导入导出数据的方法:(1)导出导入SQL Server里某个数据库,(2)导
- 防止Application对象在多线程访问中出现错误asp代码处理代码如下(VB):<%Application.Lock()Appli
- 雪花算法是在一个项目体系中生成全局唯一ID标识的一种方式,偶然间看到了Python使用雪花算法不尽感叹真的是太便捷了。它生成的唯一ID的规则
- 最近有个小项目,需要爬取页面上相应的资源数据后,保存到本地,然后将原始的HTML源文件保存下来,对HTML页面的内容进行修改将某些标签整个给
- 现在电子商务网站的设计,正面临着一系列的挑战,其中最主要的挑战是:我们尝试建立一种用户体验,来提高用户在线购物的可能性。为了对抗网上激烈的竞
- 一、http请求的顺序处理方式在高并发场景下,为了降低系统压力,都会使用一种让请求排队处理的机制。本文就介绍在Go中是如何实现的。首先,我们
- JSON 相关概念:序列化(Serialization):将对象的状态信息转换为可以存储或可以通过网络传输的过程,传输的格式可以是JSON,
- 1、封装的理解封装(Encapsulation):属性和方法的抽象属性的抽象:对类的属性(变量)进行定义、隔离和保护分为私有属性和公开属性:
- 一、MySQL数据库模块的安装和连接1、 PyMySQL模块的安装pip install pymysql2 、python连接数据库impo
- 我们通常情况下要统计数据库的连接数指的是统计总数,没有细分到每个IP上。现在要监控每个IP的连接数,实现方式如下:方法一:select SU
- numpy随机打乱数据方法np.random.shuffleimport numpy as np#实验可得每次shuffle后数据都被打乱,
- 最近vue更新的2.0版本,唉,我是在2.0版本前学习的,现在更新了又要看一遍了,关键是我之前看了3个星期2.0就更新了,vux还没同步更新
- 一、TensorBoardTensorBoard 一般都是作为 TensorFlow 的可视化工具,与 TensorFlow 深度集成,它能
- 前言:目前在研究易信公众号,想给公众号增加一个获取个人交通违章的查询菜单,通过点击返回查询数据。以下是实施过程。一、首先,用火狐浏览器打开X
- 前言:在开发中经常会与时间打交道,如:获取事件戳,时间戳的格式化等,这里简要记录一下python操作时间的方法。python中常见的处理时间
- 一般上电子商务网站买东西的用户分三种:随便看看,就是不买先看看,买不买再说就是来买东西的这样的需求反应到产品页的购买按钮上,我们一般会看到购
- 背景自从把我手上的任务全部转换成docker运行和管理之后,遇到了一系列的坑,这次是mysql备份的问题。原因是启动mysql镜像的时候没有
- 有一组4096长度的数据,需要找到一阶导数从正到负的点,和三阶导数从负到正的点,截取了一小段。394.0 388.0 389.0 388.0
- 首先来看一个例子,正常情况下我们定义并且实例一个类如下class Foo(object):def __init__(self):