Pytorch中的model.train() 和 model.eval() 原理与用法解析
作者:想变厉害的大白菜 发布时间:2022-06-06 20:51:04
Pytorch中的model.train() 和 model.eval() 原理与用法
一、两种模式
pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train()
和 model.eval()
。
一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。
二、功能
1. model.train()
在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。
如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。
model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。
2. model.eval()
model.eval()的作用是 不启用 Batch Normalization 和 Dropout。
如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。
model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。
为什么测试时要用 model.eval() ?
训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。
eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。
也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout。
3. 总结与对比
如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。
其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;
而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。
三、Dropout 简介
dropout 常常用于抑制过拟合。
设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。
参考链接
PyTorch中train()方法的作用是什么
【pytorch】model.train()和model.evel()的用法
pytorch中net.eval() 和net.train()的使用
Pytorch学习笔记11----model.train()与model.eval()的用法、Dropout原理、relu,sigmiod,tanh激活函数、nn.Linear浅析、输出整个tensor的方法
好文:Pytorch:model.train()和model.eval()用法和区别,以及model.eval()和torch.no_grad()的区别
补充:pytroch:model.train()、model.eval()的使用
前言:最近在把两个模型的代码整合到一起,发现有一个模型的代码整合后性能大不如前,但基本上是源码迁移,找了一天原因才发现是因为model.eval()和model.train()放错了位置!!!故在此介绍一下pytroch框架下model.train()、model.eval()的作用和不同点。
一、model.train、model.eval
1.model.train和model.eval放在代码什么位置
简单的说:
model.train
放在网络训练前,model.eval
放在网络测试前。
常见的位置摆放错误(也是我犯的错误)有把model.train()
放在for epoch in range(epoch):
前面,同时在test或者val(测试或者评估函数)中只放置model.eval
,这就导致了只有第一个epoch模型训练是使用了model.train()
,之后的epoch模型训练时都采用model.eval()
.可能会影响训练好模型的性能。
修改方式:可以在test函数里return前面添加model.train()
或者把model.train()
放到for epoch in range(epoch):
语句下面。
model.train()
for epoch in range(epoch):
for train_batch in train_loader:
...
zhibiao = test(epoch, test_loader, model)
def test(epoch, test_loader, model):
model.eval()
for test_batch in test_loader:
...
return zhibiao
2.model.train和model.eval有什么作用
model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层。
如果模型中有BN层(Batch Normalization)和Dropout,在测试时添加model.eval()。model.eval()是保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。
下面是model.train 和model.eval的源码,可以看到是利用self.training = mode
来判断是使用train还是eval。这个参数将传递到一些常用层,比如dropout、BN层等。
def train(self: T, mode: bool = True) -> T:
r"""Sets the module in training mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
Args:
mode (bool): whether to set training mode (``True``) or evaluation
mode (``False``). Default: ``True``.
Returns:
Module: self
"""
self.training = mode
for module in self.children():
module.train(mode)
return self
def eval(self: T) -> T:
r"""Sets the module in evaluation mode.
This has any effect only on certain modules. See documentations of
particular modules for details of their behaviors in training/evaluation
mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
etc.
This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.
Returns:
Module: self
"""
return self.train(False)
拿dropout层的源码举例,可以看到传递了self.training这个参数。
class Dropout(_DropoutNd):
r"""During training, randomly zeroes some of the elements of the input
tensor with probability :attr:`p` using samples from a Bernoulli
distribution. Each channel will be zeroed out independently on every forward
call.
This has proven to be an effective technique for regularization and
preventing the co-adaptation of neurons as described in the paper
`Improving neural networks by preventing co-adaptation of feature
detectors`_ .
Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
training. This means that during evaluation the module simply computes an
identity function.
Args:
p: probability of an element to be zeroed. Default: 0.5
inplace: If set to ``True``, will do this operation in-place. Default: ``False``
Shape:
- Input: :math:`(*)`. Input can be of any shape
- Output: :math:`(*)`. Output is of the same shape as input
Examples::
>>> m = nn.Dropout(p=0.2)
>>> input = torch.randn(20, 16)
>>> output = m(input)
.. _Improving neural networks by preventing co-adaptation of feature
detectors: https://arxiv.org/abs/1207.0580
"""
def forward(self, input: Tensor) -> Tensor:
return F.dropout(input, self.p, self.training, self.inplace)
3.为什么主要区别在于BN层和dropout层
在BN层中,主要涉及到四个需要更新的参数,分别是running_mean,running_var,weight,bias。这里的weight,bias是Pytorch官方实现中的叫法,有点误导人,其实weight就是gamma,bias就是beta。当然它这样的叫法也符合实际的应用场景。其实gamma,beta就是对规范化后的值进行一个加权求和操作running_mean,running_var是当前所求得的所有batch_size下的均值和方差,每经过一个mini_batch我们都会更新running_mean,running_var.为什么要更新它?因为测试的时候,往往是一个一个的图像feed至网络的,如果你在这里对其进行计算均值方差显然是不合理的,所以model.eval()这个语句就是控制BN层中的running_mean,running_std不更新。采用训练结束后的running_mean,running_std来规范化该张图像。
dropout层在训练过程中会随机舍弃一些神经元用来提高性能,但测试过程中如果还是测试的模型还是和训练时一样随机舍弃了一些神经元(不是原模型)这就和测试的本意相违背。因为测试的模型应该是我们最终得到的模型,而这个模型应该是一个完整的模型。
4.BN层和dropout层的作用
既然都讲到这了,不了解一些BN层和dropout层的作用就说不过去了。
BN层的原理和作用建议读一下这篇博客:神经网络中BN层的原理与作用
dropout是指在深度学习网络的训练过程中,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃。注意是暂时,对于随机梯度下降来说,由于是随机丢弃,故而每一个mini-batch都在训练不同的网络。
大规模的神经网络有两个缺点:费时、容易过拟合
Dropout的出现很好的可以解决这个问题,每次做完dropout,相当于从原始的网络中找到一个更瘦的网络。因而,对于一个有N个节点的神经网络,有了dropout后,就可以看做是2^n个模型的集合了,但此时要训练的参数数目却是不变的,这就解决了费时的问题。
将dropout比作是有性繁殖,将基因随机进行拆分,可以将优秀的基因传下来,并且降低基因之间的联合适应性,使得复杂的大段大段基因联合适应性变成比较小的一个一个小段基因的联合适应性。
dropout也能达到同样的效果,它强迫一个神经单元,和随机挑选出来的其他神经单元共同工作,达到好的效果。消除减弱了神经元节点间的联合适应性,增强了泛化能力。
参考链接
pytorch中model.train()和model.eval()的区别
BN层(Pytorch)
神经网络中BN层的原理与作用————这篇博客写的贼棒
深度学习中Dropout的作用和原理
来源:https://blog.csdn.net/weixin_44211968/article/details/123774649


猜你喜欢
- 问题描述现有一个有向赋权图。如下图所示:问题:根据每条边的权值,求出从起点s到其他每个顶点的最短路径和最短路径的长度。说明:不考虑权值为负的
- torch.matmul()语法torch.matmul(input, other, *, out=None) → Ten
- 一.克隆表法一mysql> create table info1 like info;复制格式,通过LIKE方法,复制info表结构生
- asp之家注:那么为什么要使用分页呢?当记录不多的时候,如10个或20个,我们可以也没必要使用分页来显示数据,但是数据是在不断增加的,当到了
- vue2.0里,不再有自带的过滤器,需要自己定义过滤器。定义的方法如下: 注册一个自定义过滤器,它接收两个参数:过滤器 ID 和过滤器函数。
- 在待测试的私有函数所在的包内,新建一个xx_test.go文件书写方式如下:import ( "github.com/stretc
- 配置环境: 1、数 据 库:Oracle 8i R2 (8.1.7) for NT 企业版 2、安装路径:C:ORACLE 实现方法: 1.
- 使用命令行时,如果要添加选项的话,python 2.3里新增加了一个模块叫optparse,也是专门来处理命令行选项的。from optpa
- 目录合理的创建索引设置数据库持久连接减少SQL的执行次数仅获取需要的字段数据使用批量创建、更新和删除,不随意对结果排序参考网址:Django
- 目录四种参数仅限关键字参数内省中的函数参数函数注解四种参数Python函数func定义如下:def func(first, *args, s
- 图像标注在计算机视觉中很重要,计算机视觉是一种技术,它允许计算机从数字图像或视频中获得高水平的理解力,并以人类的方式观察和解释视觉信息。注释
- Python时间处理Python在处理与时间相关的操作时有两个重要模块:time和datetime。在本文中,我们介绍这两个模块并为每个场景
- MVC代表: 模型-视图-控制器 。MVC是一个架构良好并且易于测试和易于维护的开发模式。基于MVC模式的应用程序包含:· Models:
- Java开发者对于面向对象编程思维与命令行编程思维的协调程度,取决于他们如下几种能力的水平:技巧(任何人都可以编写命令行形式的代码)教条(有
- 虽然我只是把豆瓣当作一个纪录工具来用,纪录下自己看过的电影、听过的音乐、读过的书籍,我几乎不关注豆瓣上的任何影评、乐评、音衣服之类的内容,但
- dataclass语法一、 简介官方文档的地址为:https://docs.python.org/3.9/library/dataclass
- Python被誉为全世界高效的编程语言,同时也被称作是“胶水语言”,那它为何能如此受欢迎,下面我们就来说说Python入门学习的必备11个知
- 一直以来都对编译器和解析器有着很大的兴趣,也很清楚一个编译器的概念和整体的框架,但是对于细节部分却不是很了解。我们编写的程序源代码实际上就是
- 摘要:本文主要学习了如何使用DBUtils在Java代码中更方便的操作数据库。概述DBUtils是Java编程中的数据库操作实用工具,小巧简
- Fiddler,这个是所有软件开发者必备神器!这款工具不仅可以抓取PC上开发web时候的数据包,而且可以抓取移动端(Android,Ipho