踩坑:pytorch中eval模式下结果远差于train模式介绍
作者:yucong96 发布时间:2021-10-06 22:27:49
首先,eval模式和train模式得到不同的结果是正常的。我的模型中,eval模式和train模式不同之处在于Batch Normalization和Dropout。Dropout比较简单,在train时会丢弃一部分连接,在eval时则不会。Batch Normalization,在train时不仅使用了当前batch的均值和方差,也使用了历史batch统计上的均值和方差,并做一个加权平均(momentum参数)。在test时,由于此时batchsize不一定一致,因此不再使用当前batch的均值和方差,仅使用历史训练时的统计值。
我出bug的现象是,train模式下可以收敛,但一旦在测试中切换到了eval模式,结果就很差。如果在测试中仍沿用train模式,反而可以得到不错的结果。为了确保是程序bug而不是算法本身就不适合于预测,我在测试时再次使用了训练集,正常情况下此时应发生过拟合,正确率一定会很高,然而eval模式下正确率仍然很低。参照网上的一些说法(Performance highly degraded when eval() is activated in the test phase
),我调大了batchsize,降低了BN层的momentum,检查了是否存在不同层使用相同BN层的bug,均不见效。有一种方法说应在BN层设置track_running_stats为False,它虽然带来了好的效果,但实际上它只不过是不用eval模式,切回train模式罢了,所以也不对。
学习了在训练过程中,如何将BN层中统计的均值和方差输出。即在forward()中,
# bn是一个BN层,torch.nn.batch_normalization(...)
print(bn.running_mean)
print(bn.running_var)
同时学习了如何输出一个Tensor自身的均值和方差,即
# x是一个Tensor,dims是需要计算的维度
print(x.cpu().detach().numpy().mean(dims)
print(x.cpu().detach().numpy().var(dims)
观察每一层的输出结果,发现出现了很大的方差,才猛然意识到自己的输入数据没有做归一化(事后想想也确实如此,毕竟模型和训练方法都是github上参考别人的,出错概率很小;反而是自己写的DataSet部分,其实是最容易出错的)。给模型加上归一化后,eval和train的结果就没有问题了。
再次验证了我的观点:越是玄学的问题,越是 * 的bug。
补充知识:Pytorch中的train和eval用法注意点
1.介绍
一般情况,model.train()是在训练的时候用到,model.eval()是在测试的时候用到
2.用法
如果模型中没有类似于BN这样的归一化或者Dropout,model.train()和model.eval()可以不要(建议写一下,比较安全),并且model.train()和model.eval()得到的效果是一样
如果模型中有类似于BN这样的归一化或者Dropout,并且程序需要边训练和边测试,最好就是用model.eval()测试完之后,后面补一个model.train()。
其中model.train()是保证BN用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接(结果是取了平均)
来源:https://blog.csdn.net/yucong96/article/details/88652964


猜你喜欢
- javascript的分号代表语句的结束符,但由于javascript具有分号自动插入规则,所以它是一个十分容易让人模糊的东西,在一般情况下
- 环境配置系统:Windows10版本:python 3.8Turtle扫盲1.绘图窗体的设置turtle.setup(width, heig
- 自动发送邮件我们把报表做出来以后一般都是需要发给别人查看,对于一些每天需要发的报表或者是需要一次发送多份的报表,这个时候可以考虑借助Pyth
- 需求:需求简单:但是感觉最后那部分遍历有意思:S型数组赋值,考虑到下标,简单题先实现个差不多的m = 5cols = 9rows = 4nu
- 我们直接先给出输出与预期不同的代码In[28]: a = [1,2,3,4,5,6]In[29]: for i in a: ...: &nb
- Python HTTP客户端自定义Cookie实现实例几乎所有脚本语言都提供了方便的 HTTP 客户端处理的功能,Python 也不例外,使
- 前言最近的一个项目中需要在图片上添加文字,使用了OpenCV,结果发现利用opencv给图像添加文字有局限。可利用的字体类型比较少,需要安装
- 废话不多说,直接上代码Python2.7#!/usr/bin/env python2.7# -*- coding=utf-8 -*-impo
- 官方文档中关于super的定义说的不是很多,大致意思是返回一个代理对象让你能够调用一些继承过来的方法,查找的机制遵循mro规则,最常用的情况
- #-*- coding: UTF-8 -*-'''Created on 2013-12-5@author: good
- PHP 中文工具类,支持汉字转拼音、拼音分词、简繁互转。PHP Chinese Tool class, support Chinese pi
- 问题你的程序获取了一个目录中的文件名列表,但是当它试着去打印文件名的时候程序崩溃, 出现了 UnicodeEncodeError 异常和一条
- OpenCV的imread不能读取中文路径问题import numpy as npimport cv2cv_img = cv2.imdeco
- 前言在上下文管理器协议的过程中,涉及到两个魔术方法__enter__方法 和 __exit__方法在python中所有实现了上下文管理器协议
- 1 什么是嵌套循环所谓嵌套循环就是一个外循环的主体部分是一个内循环。内循环或外循环可以是任何类型,例如 while 循环或 for 循环。
- --使用说明 本代码适用于MsSql2000,对于其它数据库也可用.但没必要 --创建存储过程 CREATE PROCEDURE pagin
- 遍历pd.Series的index和value的方法如下,python built-in list的enumerate方法不管用for i,
- 本文实例讲述了PHP实现无限极分类的两种方式。分享给大家供大家参考,具体如下:面试的时候被问到无限极分类的设计和实现,比较常见的做法是在建表
- 1.因为oracle 10g暂时没有与win7兼容的版本,我们可以通过对安装软件中某些文件的修改达到安装的目地。 a)打开“\ORACLE1
- 直接将 视频的HTML网址存入models ,以字符串的形式#关于我们 CharFieldclass About(models.Model)