浅谈PyTorch的可重复性问题(如何使实验结果可复现)
作者:hyk_1996 发布时间:2021-07-16 06:34:33
标签:PyTorch,重复性,结果,复现
由于在模型训练的过程中存在大量的随机操作,使得对于同一份代码,重复运行后得到的结果不一致。因此,为了得到可重复的实验结果,我们需要对随机数生成器设置一个固定的种子。
许多博客都有介绍如何解决这个问题,但是很多都不够全面,往往不能保证结果精确一致。我经过许多调研和实验,总结了以下方法,记录下来。
全部设置可以分为三部分:
1. CUDNN
cudnn中对卷积操作进行了优化,牺牲了精度来换取计算效率。如果需要保证可重复性,可以使用如下设置:
from torch.backends import cudnn
cudnn.benchmark = False # if benchmark=True, deterministic will be False
cudnn.deterministic = True
不过实际上这个设置对精度影响不大,仅仅是小数点后几位的差别。所以如果不是对精度要求极高,其实不太建议修改,因为会使计算效率降低。
2. Pytorch
torch.manual_seed(seed) # 为CPU设置随机种子
torch.cuda.manual_seed(seed) # 为当前GPU设置随机种子
torch.cuda.manual_seed_all(seed) # 为所有GPU设置随机种子
3. Python & Numpy
如果读取数据的过程采用了随机预处理(如RandomCrop、RandomHorizontalFlip等),那么对python、numpy的随机数生成器也需要设置种子。
import random
import numpy as np
random.seed(seed)
np.random.seed(seed)
最后,关于dataloader:
注意,如果dataloader采用了多线程(num_workers > 1), 那么由于读取数据的顺序不同,最终运行结果也会有差异。也就是说,改变num_workers参数,也会对实验结果产生影响。目前暂时没有发现解决这个问题的方法,但是只要固定num_workers数目(线程数)不变,基本上也能够重复实验结果。
对于不同线程的随机数种子设置,主要通过DataLoader的worker_init_fn参数来实现。默认情况下使用线程ID作为随机数种子。如果需要自己设定,可以参考以下代码:
GLOBAL_SEED = 1
def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
GLOBAL_WORKER_ID = None
def worker_init_fn(worker_id):
global GLOBAL_WORKER_ID
GLOBAL_WORKER_ID = worker_id
set_seed(GLOBAL_SEED + worker_id)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2, worker_init_fn=worker_init_fn)
来源:https://blog.csdn.net/hyk_1996/article/details/84307108


猜你喜欢
- 在平时的工作中,我们的目录有很多的视频文件,如果你没有一个好的视频分类习惯,在找视频素材的时候会很费时,通过对视频的分辨路进行分类可以在需要
- 今天开始学习python,首先环境安装1.在https://www.python.org/downloads/下载python2.X或者3.
- python的数据类型有:数字(int)、浮点(float)、字符串(str),列表(list)、元组(tuple)、字典(dict)、集合
- python 循环while和for in简单实例#!/uer/bin/env python# _*_ coding: utf-8 _*_l
- 题目1、 请输入一个整数 , 若该数是偶数 , 输出 “ 是偶数” ”
- 一、连接MYSQL格式: mysql -h主机地址 -u用户名 -p用户密码1、 连接到本机上的MYSQL。首先打开DOS窗口,然后进入目录
- 最近想学习一些python数据分析的内容,就弄了个爬虫爬取了一些数据,并打算用Anaconda一套的工具(pandas, numpy, sc
- 引言“ 这是MySQL系列笔记的第二篇,文章内容均为本人通过实践及查阅资料相关整理所得,可用作新手入门指南,或
- --利用T-SQL语句,实现数据库的备份与还原的功能 ----体现了SQL Server中的四个知识点: ----1. 获取SQL Serv
- curl 和 Python requests 都是发送 HTTP 请求的强大工具。 虽然 curl 是一种命令行工具,可让您直接从终端发送请
- 前言目前在做vue的项目,用到了子组件依赖其父组件的数据,进行子组件的相关请求和页面数据展示,父组件渲染需要子组件通知更新父组件的state
- 本文实例讲述了Go语言中的匿名结构体用法。分享给大家供大家参考。具体实现方法如下:package main  
- 本文实例总结了Python列表list常用内建函数。分享给大家供大家参考,具体如下:>>> x = list(range(
- 1.使用测量工具,量化性能才能改进性能,常用的timeit和memory_profiler,此外还有profile、cProfile、hot
- python里使用正则表达式的组嵌套实例详解由于组本身是一个完整的正则表达式,所以可以将组嵌套在其他组中,以构建更复杂的表达式。下面的例子,
- 说明:几个简单的基本的sql语句 选择:select * from table1 where 范围 插入:insert into table
- 从今天开始,我将全面的共享出我所能理解的所有WEB标准方面的知识放在这个“WEB标准能有多难?”的专栏里。当然由于振之的水平有限,所讲并非是
- 前言:学过C语言肯定接触过排序问题,我们最常用的也就是冒泡排序、选择排序、插入排序……等等,同样
- 我就废话不多说了,大家还是直接看代码吧!file1 = 'C:\\Users\\Administrator\\Desktop\\te
- 在python中操作文件算是一个基本操作,但是选对了模块会让我们的效率大大提升。本篇整理了两种模块的常用方法,分别是os模块和shutil模