详解model.train()和model.eval()两种模式的原理与用法
作者:想变厉害的大白菜 发布时间:2021-03-20 08:46:56
一、两种模式
pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval()。
一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。
二、功能
1. model.train()
在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。
如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。
model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。
2. model.eval()
model.eval()的作用是 不启用 Batch Normalization 和 Dropout。
如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。
model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。
为什么测试时要用 model.eval() ?
训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。
eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。
也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout。
3. 总结与对比
如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。
其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;
而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。
三、Dropout 简介
dropout 常常用于抑制过拟合。
设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。
来源:https://blog.csdn.net/weixin_44211968/article/details/123774649


猜你喜欢
- 解决的方法:1.在 ueditor\dialogs\internal.js 加入 document.domain = '根域名
- 以下示例显示如何在 XPath 查询中指定轴。这些示例中的 XPath 查询都在 SampleSchema1.xml 中所包含的映射架构上指
- 这本是课程的一个作业研究搜索算法,当时研究了一下Tkinter,然后写了个很简单的机器人走迷宫的界面,并且使用了各种搜索算法来进行搜索,如下
- lambda函数用法lambda非常重要的一个定义。lambda在【运行时】才绑定,【不是】在定义的时候绑定。下面这个列子:本意想:让X分别
- 方法一:#-*- coding:utf-8 -*-from sqlalchemy import create_engineclass mys
- 本文实例讲述了python实现线程池的方法。分享给大家供大家参考。具体如下:原理:建立一个任务队列,然多个线程都从这个任务队列中取出任务然后
- 本文实例为大家分享了python实现大量图片重命名的具体代码,供大家参考,具体内容如下说明在进行深度学习的过程中,需要对图片进行批量的命名处
- 由于网络带宽以及某些WAP服务器DECK传输的限制,所以DECK越小越好,最好不要超过1.2K。如果你的需求很复杂,最好分成几个DECK来完
- 以前写过一个刷校内网的人气的工具,Java的(以后再也不行Java程序了),里面用到了验证码识别,那段代码不是我自己写的:-) 校内的验证是
- 在 Class 块中,成员通过相应的声明语句被声明为 Private(私有成员,只能在类内部调用)
- sql脚本是包含一到多个sql命令的sql语句,我们可以将这些sql脚本放在一个文本文件中(我们称之为“sql脚本文件”),然后通过相关的命
- Golang爬虫框架 colly 简介colly是一个采用Go语言编写的Web爬虫框架,旨在提供一个能够些任何爬虫/采集器/蜘蛛的简介模板,
- QWidget基本介绍基础窗口控件QWidget类是所有用户界面对象的基类,所有的窗口或者控件都直接或者间接的继承自QWidget类。窗口坐
- python去重及数据合并drop_dupicates参数含义:subset:即表示要去重指定参考的列keep : {‘
- MySQL 与 Elasticsearch 数据不对称问题解决办法jdbc-input-plugin 只能实现数据库的追加,对于 elast
- 这篇文章主要介绍了微信小程序顶部导航栏可滑动并选中放大,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋
- 有一个优秀的库可以使用————demjson示范链接http
- 本章我们要制作一个俄罗斯方块游戏。Tetris译注:称呼:方块是由四个小方格组成的俄罗斯方块游戏是世界上最流行的游戏之一。是由一名叫Alex
- Javascript中的变量同样支持自由类型转换成为适用(或者要求)的内容以便于使用。 弱类型的Javascript不会按照程序员的愿望从实
- 引言with 语句是从 Python 2.5 开始引入的一种与异常处理相关的功能(2.5 版本中要通过 from __future__ im