Pytorch 中net.train 和 net.eval的使用说明
作者:Never-Giveup 发布时间:2021-11-15 11:40:37
在训练模型时会在前面加上:
model.train()
在测试模型时在前面使用:
model.eval()
同时发现,如果不写这两个程序也可以运行,这是因为这两个方法是针对在网络训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout。
训练时是正对每个min-batch的,但是在测试中往往是针对单张图片,即不存在min-batch的概念。
由于网络训练完毕后参数都是固定的,因此每个批次的均值和方差都是不变的,因此直接结算所有batch的均值和方差。
所有Batch Normalization的训练和测试时的操作不同
在训练中,每个隐层的神经元先乘概率P,然后在进行激活,在测试中,所有的神经元先进行激活,然后每个隐层神经元的输出乘P。
补充:Pytorch踩坑记录——model.eval()
最近在写代码时遇到一个问题,原本训练好的模型,加载进来进行inference准确率直接掉了5个点,尼玛,这简直不能忍啊~本菜鸡下意识地感知到我肯定又在哪里写了bug了~~~于是开始到处排查,从model load到data load,最终在一个被我封装好的module的犄角旮旯里找到了问题,于是顺便就在这里总结一下,避免以后再犯。
对于训练好的模型加载进来准确率和原先的不符,比较常见的有两方面的原因:
1)data
2)model.state_dict()
1) data
数据方面,检查前后两次加载的data有没有发生变化。首先检查 transforms.Normalize 使用的均值和方差是否和训练时相同;另外检查在这个过程中数据是否经过了存储形式的改变,这有可能会带来数据精度的变化导致一定的信息丢失。
比如我过用的其中一个数据集,原先将图片存储成向量形式,但其对应的是“png”格式的数据(后来在原始文件中发现了相应的描述。),而我进行了一次data-to-img操作,将向量转换成了“jpg”形式,这时加载进来便造成了掉点。
2)model.state_dict()
第一方面造成的掉点一般不会太严重,第二方面造成的掉点就比较严重了,一旦模型的参数加载错了,那就误差大了。
如果是参数没有正确加载进来则比较容易发现,这时准确率非常低,几乎等于瞎猜。
而我这次遇到的情况是,准确率并不是特别低,只掉了几个点,检查了多次,均显示模型参数已经成功加载了。后来仔细查看后发现在其中一次调用模型进行inference时,忘了写 ‘model.eval()',造成了模型的参数发生变化,再次调用则出现了掉点。
于是又回顾了一下model.eval()和model.train()的具体作用。如下:
model.train() 和 model.eval() 一般在模型训练和评价的时候会加上这两句,主要是针对由于model 在训练时和评价时 Batch
Normalization 和 Dropout 方法模式不同:
a) model.eval(),不启用 BatchNormalization 和 Dropout。此时pytorch会自动把BN和DropOut固定住,不会取平均,而是用训练好的值。不然的话,一旦test的batch_size过小,很容易就会因BN层导致模型performance损失较大;
b) model.train() :启用 BatchNormalization 和 Dropout。 在模型测试阶段使用model.train() 让model变成训练模式,此时 dropout和batch normalization的操作在训练q起到防止网络过拟合的问题。
因此,在使用PyTorch进行训练和测试时一定要记得把实例化的model指定train/eval。
model.eval() vs torch.no_grad()
虽然二者都是eval的时候使用,但其作用并不相同:
model.eval() 负责改变batchnorm、dropout的工作方式,如在eval()模式下,dropout是不工作的。 见下方代码:
import torch
import torch.nn as nn
drop = nn.Dropout()
x = torch.ones(10)
# Train mode
drop.train()
print(drop(x)) # tensor([2., 2., 0., 2., 2., 2., 2., 0., 0., 2.])
# Eval mode
drop.eval()
print(drop(x)) # tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
torch.no_grad() 负责关掉梯度计算,节省eval的时间。
只进行inference时,model.eval()是必须使用的,否则会影响结果准确性。 而torch.no_grad()并不是强制的,只影响运行效率。
来源:https://blog.csdn.net/qq_36653505/article/details/84728489
猜你喜欢
- 这篇文章主要介绍了python智联招聘爬虫并导入到excel代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习
- requests库安装和导入第一步:cmd打开命令行,使用如下命令安装requests库。pip install requests由于我的安
- 为了实现项目中的搜索功能,我们使用的是全文检索框架haystack+搜索引擎whoosh+中文分词包jieba安装和配置安装所需包pip i
- Laravel 中间件提供了一种方便的机制来过滤进入应用的 HTTP 请求。例如,Laravel 内置了一个中间件来验证用户的身份认证。如果
- Selenium 是一个可以让浏览器自动化地执行一系列任务的工具,常用于自动化测试。不过,也可以用来给网页截图。目前,它支持 Java、C#
- 修改MySQL下的默认mysql数据库的user表,删除所有host为localhost记录,另外添加一些其他记录,重新启动MySQL服务器
- 最近要做个从 pdf 文件中抽取文本内容的工具,大概查了一下 python 里可以使用 pdfminer 来实现。下面就看看怎样使用吧。PD
- Dreamweaver(以下简称DW)提供了一种称为“Behavior”(行为)的机制,帮助你构建页面
- 本文介绍基于Python中ArcPy模块,对大量栅格遥感影像文件进行批量掩膜与批量重采样的操作。首先,我们来明确一下本文的具体需求。现有一个
- 我的数据库如图结构我取了其中的name age nr,做成array,只要所取数据存在str型,那么取出的数据,全部转化为str型,也就是a
- #!/usr/bin/env python3# -*- coding: utf-8 -*-import globfrom os import
- python中format的使用format函数这是一种字符串格式化的方法,用法如str.format()。基本语法是通过 {} 和 : 来
- 前言使用Python写过面向对象的代码的同学,可能对 __init__ 方法已经非常熟悉了,__init__方法在类的一个对象被建立时,马上
- <% a="福建是中国的一个省|我们美丽中国的武夷山!" b="中国,我们,武夷山,福建,美国,苹果&q
- 本文实例讲述了Python基于列表list实现的CRUD操作功能。分享给大家供大家参考,具体如下:本篇文章看之前你的先了解python 基础
- python PIL 将数组值转成图片安装 PIL 包pip install pillow将二维数据转换成单通道图片from PIL imp
- 导语大家以前应该都听说过一个游戏:叫做走四棋儿这款游戏出来到现在时间挺长了,小时候的家乡农村条件有限,附近也没有正式的玩具店能买到玩具,因此
- 本文实例讲述了python实现根据窗口标题调用窗口的方法。分享给大家供大家参考。具体分析如下:当你知道一个windows窗口的标题后,可以用
- Mysql的安装方法 安装mysql的步骤如下:请注意按图中所示,有些选项和默认是不一样的。同时,如果您是重新安装mysql的话,要注意先备
- Python是跨平台的,免费开源的一门计算机编程语言。是一种面向对象的动态类型语言,最初被设计用于编写自动化脚本(shell),随着版本的不