pytorch 实现在测试的时候启用dropout
作者:qian99 发布时间:2022-03-16 18:08:22
标签:pytorch,测试,启用,dropout
我们知道,dropout一般都在训练的时候使用,那么测试的时候如何也开启dropout呢?
在pytorch中,网络有train和eval两种模式,在train模式下,dropout和batch normalization会生效,而val模式下,dropout不生效,bn固定参数。
想要在测试的时候使用dropout,可以把dropout单独设为train模式,这里可以使用apply函数:
def apply_dropout(m):
if type(m) == nn.Dropout:
m.train()
下面是完整demo代码:
# coding: utf-8
import torch
import torch.nn as nn
import numpy as np
class SimpleNet(nn.Module):
def __init__(self):
super(SimpleNet, self).__init__()
self.fc = nn.Linear(8, 8)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = self.fc(x)
x = self.dropout(x)
return x
net = SimpleNet()
x = torch.FloatTensor([1]*8)
net.train()
y = net(x)
print('train mode result: ', y)
net.eval()
y = net(x)
print('eval mode result: ', y)
net.eval()
y = net(x)
print('eval2 mode result: ', y)
def apply_dropout(m):
if type(m) == nn.Dropout:
m.train()
net.eval()
net.apply(apply_dropout)
y = net(x)
print('apply eval result:', y)
运行结果:
可以看到,在eval模式下,由于dropout未生效,每次跑的结果不同,利用apply函数,将Dropout单独设为train模式,dropout就生效了。
补充:Pytorch之dropout避免过拟合测试
一.做数据
二.搭建神经网络
三.训练
四.对比测试结果
注意:测试过程中,一定要注意模式切换
来源:https://blog.csdn.net/qian99/article/details/89052262


猜你喜欢
- 上一篇自动在Windows中运行Python脚本并定时触发功能实现传送门链接运行Python脚本:.bat文件在Windows中,.bat文
- 以前在网上看到的最简单的拖动对象的代码,忘记作者叫什么了。原始代码在IE下有些小问题,并且声明了文档类型为xhtml 1.0后,在FF等非I
- 游戏开始前的注意事项1:游戏《外星人入侵》将包含很多文件,请在你的D盘中新建一个空文件夹,并将其命名为alien_invasion.请务必将
- windows系统MySQL安装教程下载1.登录https://dev.mysql.com/downloads/installer/选择Mi
- 上次在blueidea上看到一个元素圆角的实现方法,但是那个太复杂了。于是就自己写了一个函数,可以将元素自动圆角,如div层,表格等。共有四
- Kettle简介Kettle最早是一个开源的ETL(Extract-Transform-Load的缩写)工具,全称为KDE Extracti
- 一、 在数据库排序查询优化上的差异。在讲解这个内容之前,为了读者能够清楚我讲的内容,我要先谈一个概念。命中率,它是指从内存中取得数据而不从磁
- 一、初始化CounterCounter支持3种形式的初始化,比如提供一个数组,一个字典,或单独键值对“=”式赋值。具体初始化的代码如下所示:
- 阅读本文大概需要3分钟关于函数和模块讲了这么久,我一直想用一个好玩有趣的小例子来总结一下,同时也作为实战练习一下。趣味编程其实是最好的学习途
- 1、先看最简单的场景,生产者生产消息,消费者接收消息,下面是生产者的简单代码。#!/usr/bin/env python# -*- codi
- DDP 数据shuffle 的设置使用DDP要给dataloader传入sampler参数(torch.utils.data.distrib
- 本文实例讲述了Python Zip和Enumerate用法。分享给大家供大家参考,具体如下:Python 中的 Zipzip的作用:可以在处
- 先看一个实例这是我用asp写的一个搜索一个字符串里面第一张图片地址的函数(当然你可以将values那里换一个得到所有图片地址)functio
- 本文实例讲述了Python延时操作实现方法。分享给大家供大家参考,具体如下:在日常的开发中,往往会遇到这样的需求,需要某一个函数在一段时间之
- 今天研究了些取access数据库随机记录问题,这是这我自己搜集整理的方法。大家有没有高见,可以告诉我,或者我总结的东东本身有误,也可以帮我修
- 1. JSON简介JSON(JavaScript Object Notation) 是一种轻量级的数据交换格式,它是JavaScript的子
- 简介有些 post 的请求参数是 json 格式的,这个前面发送post 请求里面提到过,需要导入 json模块处理。现在企业公司一般常见的
- 导入同级模块导入sys,一定要将当前包所在路径添加进来。import syssys.path.append(r"directory
- 使用MySQL,安全问题不能不注意。以下是MySQL提示的23个注意事项:1.如果客户端和服务器端的连接需要跨越并通过不可信任的网络,那么就
- 实现功能:删除当前目录下,除保留目录和文件外的所有文件和目录#!bin/env pythonimport osimport os.pathi