浅谈PyTorch的可重复性问题(如何使实验结果可复现)
作者:hyk_1996 发布时间:2021-07-16 06:34:33
标签:PyTorch,重复性,结果,复现
由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。
许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。
全部设置可以分为三部分:
1. CUDNN
cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:
from torch.backends import cudnn
cudnn.benchmark = False # if benchmark=True, deterministic will be False
cudnn.deterministic = True
不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。
2. Pytorch
torch.manual_seed(seed) # 为CPU设置随机种子
torch.cuda.manual_seed(seed) # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed) # 为所有GPU设置随机种子
3. Python & Numpy
如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。
import random
import numpy as np
random.seed(seed)
np.random.seed(seed)
最后,关于dataloader:
注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。
对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:
GLOBAL_SEED = 1
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
global GLOBAL_WORKER_ID
GLOBAL_WORKER_ID = worker_id
set_seed(GLOBAL_SEED + worker_id)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)
来源:https://blog.csdn.net/hyk_1996/article/details/84307108
0
投稿
猜你喜欢
- 前言闲暇时间抽个空写了个三国杀武将手册的小程序,中间有个需求设计的是合成武将皮肤图、竖排的武将姓名、以及小程序码,然后提供保存图片到相册,最
- 1、简介pyqt 列表 单元格中 不仅可以添加数据,还可以添加控件。我们尝试添加下拉列表、一个按钮试试。setItem:将文本放到单元格中s
- 在利用QT编写GUI程序时经常需要一些交互操作,常见的有鼠标事件、键盘事件等。今天我们要实现的是在label中已经显示的图像中绘制矩形框,以
- aspjpeg组件官方下载地址:http://www.persits.com/说明: 1、aspjpeg能对图片水印进行透明度调整
- 从小的方面讲,帮助一般是指:手册、说明书、文档、FAQ 等等。从大的方面讲,可以是交互过程中的提示、指引、演示等信息,帮助无处不在!这一切,
- 1. 引言Python程序有许多模块和第三方包,这非常有助于高效编程。了解这些模块的正确使用方法是很重要的,在本文中,主要介绍一些非常实用的
- 研究(2)中讨论了栅格系统的基础知识。这一篇将集中探讨栅格系统的粒度问题。(注:如非特别指明,栅格系统均指24列960栅格系统)淘宝的首页(
- 学习JQUERY就应该从最基本的学起,基本的就应该是语法了,在这里,我们有必要先温习一下JAVASCRIPT的一些知识。语法就不用说了,都是
- python控制鼠标键盘其实很容易,我们在写程序的时候很多时候会用的到!python控制鼠标键盘步骤及代码1、安装类库pip install
- 有时候需要比较大的计算量,这个时候Python的效率就很让人捉急了,此时可以考虑使用numba 进行加速,效果提升明显~(numba 安装貌
- 什么是拼音转换在我们学习语言之前,我们一般会学习拼音来认识汉字,并学会如何读汉字。所以,拼音在对于我们语言的重要性不言而喻。而拼音转换指的是
- 一、Session 的概念cookie 是在浏览器端保存键值对数据,而 session 是在服务器端保存键值对数据 session 的使用依
- 一、安装我们知道selenium是桌面浏览器自动化操作工具(Web Browser Automation)appium是继承selenium
- Dreamweaver MX 2004的强大功能以及更加完善的人性化设置已经深受大家喜爱。在此笔者就谈
- 1.如何用函数先定义后调用,定义阶段只检测语法,不执行代码调用阶段,开始执行代码函数都有返回值定义时无参,调用时也是无参定义时有参,调用时也
- /* --注意:准备数据(可略过,非常耗时) CREATE TABLE CHECK1_T1 ( ID INT, C1 CHAR(8000)
- 在/etc/profile.d/简历oracle.sh内容如下在NLS_LANG设置编码ORACLE_HOME=/usr/lib/oracl
- JWT是一种JSON的行业标准,广泛应用在系统的用户认证方面。JWT认证简介JWT(JSON Web Tokens),是为了在网络应用环境间
- python下redis安装用python操作redis数据库,先下载redis-py模块下载地址https://github.com/an
- 一、相关知识点讲解1.1 需要使用的相关库import numpy as npimport pand