pytorch中关于distributedsampler函数的使用
作者:DRACO于 发布时间:2023-01-18 01:10:01
关于distributedsampler函数的使用
1.如何使用这个分布式采样器
在使用distributedsampler函数时,观察loss发现loss收敛有规律,发现是按顺序读取数据,未进行shuffle。
问题的解决方式就是怀疑 seed 有问题,参考源码 DistributedSampler,发现 shuffle 的结果依赖 g.manual_seed(self.epoch) 中的 self.epoch。
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
而 self.epoch 初始默认是 0
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
但是 DistributedSampler 也提供了一个 set 函数来改变 self.epoch
def set_epoch(self, epoch):
self.epoch = epoch
所以在运行的时候要不断调用这个 set_epoch 函数。只要把我的代码中的
# sampler.set_epoch(e)
全部代码如下:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
torch.distributed.init_process_group(backend="nccl")
input_size = 5
output_size = 2
batch_size = 2
data_size = 16
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
class RandomDataset(Dataset):
def __init__(self, size, length, local_rank):
self.len = length
self.data = torch.stack([torch.ones(5), torch.ones(5)*2,
torch.ones(5)*3,torch.ones(5)*4,
torch.ones(5)*5,torch.ones(5)*6,
torch.ones(5)*7,torch.ones(5)*8,
torch.ones(5)*9, torch.ones(5)*10,
torch.ones(5)*11,torch.ones(5)*12,
torch.ones(5)*13,torch.ones(5)*14,
torch.ones(5)*15,torch.ones(5)*16]).to('cuda')
self.local_rank = local_rank
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
dataset = RandomDataset(input_size, data_size, local_rank)
sampler = DistributedSampler(dataset)
rand_loader = DataLoader(dataset=dataset,
batch_size=batch_size,
sampler=sampler)
e = 0
while e < 2:
t = 0
# sampler.set_epoch(e)
for data in rand_loader:
print(data)
e+=1
运行:
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 test.py
2.关于用不用这个采样器的区别
多卡去训模型,尝试着用DDP模式,而不是DP模式去加速训练(很容易出现负载不均衡的情况)。
遇到了一点关于DistributedSampler这个采样器的一点疑惑,想试验下在DDP模式下,使用这个采样器和不使用这个采样器有什么区别。
实验代码:
整个数据集大小为8,batch_size 为4,总共跑2个epoch。
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
torch.distributed.init_process_group(backend="nccl")
batch_size = 4
data_size = 8
local_rank = torch.distributed.get_rank()
print(local_rank)
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
class RandomDataset(Dataset):
def __init__(self, length, local_rank):
self.len = length
self.data = torch.stack([torch.ones(1), torch.ones(1)*2,torch.ones(1)*3,torch.ones(1)*4,torch.ones(1)*5,torch.ones(1)*6,torch.ones(1)*7,torch.ones(1)*8]).to('cuda')
self.local_rank = local_rank
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
dataset = RandomDataset(data_size, local_rank)
sampler = DistributedSampler(dataset)
#rand_loader =DataLoader(dataset=dataset,batch_size=batch_size,sampler=None,shuffle=True)
rand_loader = DataLoader(dataset=dataset,batch_size=batch_size,sampler=sampler)
epoch = 0
while epoch < 2:
sampler.set_epoch(epoch)
for data in rand_loader:
print(data)
epoch+=1
运行命令:
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 test.py
实验结果:
结论分析:上面的运行结果来看,在一个epoch中,sampler相当于把整个数据集 划分成了nproc_per_node份,每个GPU每次得到batch_size的数量,也就是nproc_per_node 个GPU分一整份数据集,总数据量大小就为1个dataset。
如果不用它里面自带的sampler,单纯的还是按照我们一般的形式。Sampler=None,shuffle=True这种,那么结果将会是下面这样的:
结果分析:没用sampler的话,在一个epoch中,每个GPU各自维护着一份数据,每个GPU每次得到的batch_size的数据,总的数据量为2个dataset,
来源:https://blog.csdn.net/chanbo8205/article/details/115242635


猜你喜欢
- /* --注意:准备数据(可略过,非常耗时) CREATE TABLE CHECK1_T1 ( ID INT, C1 CHAR(8000)
- 本文实例讲述了JS简单实现无缝滚动效果。分享给大家供大家参考,具体如下:<!doctype html><title>
- 本文实例讲述了Python实现从订阅源下载图片的方法。分享给大家供大家参考。具体如下:这段代码是基于python 3.4实现的,和pytho
- 一、软件下载MySQL下载安装:官网下载地址:https://www.mysql.com/或者本地下载二、安装须知如果是安装过该软件的卸载重
- 一、Python下载1.进入Python官网:https://www.python.org/2.选择windows版本(Download &
- 创建测试dataframe:>>> import pandas as pd>>> df = pd.Dat
- 1.首先自己直接在cmd中输入 pip3 install openCV是不可行的,即需要自己下载安装包本地安装2.openCV库 下载地址h
- 当你在IE中点击一个Realplayer连接时,系统会自动启动Realplayer软件,不仅占用系统内存,而且在上网时Realplayer容
- 本文实例讲述了python中getaddrinfo()基本用法。分享给大家供大家参考。具体如下:import sys, socketresu
- 先看一下总体效果:上传文件做了大小和类型的限制,在动图中无法展现出来。使用file类型的input实现选择本地文件但是浏览器原生的文件上传按
- jQuery由美国人John Resig创建,至今已吸引了来自世界各地的众多javascript高手加入其team,包括来自德国的J&
- Python 爬虫包含两个重要的部分:正则表达式和Scrapy框架的运用, 正则表达式对于所有语言都是通用的,网络上可以找到各种资源。如下是
- 在mysql中查询5条不重复的数据,使用以下:SELECT * FROM `table` ORDER BY RAND() LIMIT 5就可
- Python是动态语言,在创建对象后,可以动态地绑定属性和方法定义类:class Student: #定义类 &nb
- 虽然python是万能的,但是对于某些特殊功能,需要c语言才能完成。这样,就需要用python来调用c的代码了具体流程:c编写相关函数 ,编
- 目前防采集的方法有很多种,先介绍一下常见防采集策略方法和它的弊端及采集对策: 一、判断一个IP在一定时间内对本站页面的访问次数,如果明显超过
- 安装时建议你为MySQL管理创建一个用户和组。由该组用户运行mysql服务器并执行管理任务。(也可以以root身份运行服务器,但是不推荐)第
- 对于regex库的使用不难,因为本身就是python中自带的库,所以在调用上也是常见的库使用类型,大部分时候都是用于搜索上下文信息的,但是有
- 数据库发生阻塞和死锁的现象:一、数据库阻塞的现象:第一个连接占有资源没有释放,而第二个连接需要获取这个资源。如果第一个连接没有提交或者回滚,
- 根据官网的文档,要在一个html文件下使用layui里面的组件库其实很简单,但是在vue项目中使用该ui库却存在着很多坑,下面我们就详细讲解