pytorch中关于distributedsampler函数的使用
作者:DRACO于 发布时间:2023-01-18 01:10:01
关于distributedsampler函数的使用
1.如何使用这个分布式采样器
在使用distributedsampler函数时,观察loss发现loss收敛有规律,发现是按顺序读取数据,未进行shuffle。
问题的解决方式就是怀疑 seed 有问题,参考源码 DistributedSampler,发现 shuffle 的结果依赖 g.manual_seed(self.epoch) 中的 self.epoch。
def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
而 self.epoch 初始默认是 0
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
但是 DistributedSampler 也提供了一个 set 函数来改变 self.epoch
def set_epoch(self, epoch):
self.epoch = epoch
所以在运行的时候要不断调用这个 set_epoch 函数。只要把我的代码中的
# sampler.set_epoch(e)
全部代码如下:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
torch.distributed.init_process_group(backend="nccl")
input_size = 5
output_size = 2
batch_size = 2
data_size = 16
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
class RandomDataset(Dataset):
def __init__(self, size, length, local_rank):
self.len = length
self.data = torch.stack([torch.ones(5), torch.ones(5)*2,
torch.ones(5)*3,torch.ones(5)*4,
torch.ones(5)*5,torch.ones(5)*6,
torch.ones(5)*7,torch.ones(5)*8,
torch.ones(5)*9, torch.ones(5)*10,
torch.ones(5)*11,torch.ones(5)*12,
torch.ones(5)*13,torch.ones(5)*14,
torch.ones(5)*15,torch.ones(5)*16]).to('cuda')
self.local_rank = local_rank
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
dataset = RandomDataset(input_size, data_size, local_rank)
sampler = DistributedSampler(dataset)
rand_loader = DataLoader(dataset=dataset,
batch_size=batch_size,
sampler=sampler)
e = 0
while e < 2:
t = 0
# sampler.set_epoch(e)
for data in rand_loader:
print(data)
e+=1
运行:
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 test.py
2.关于用不用这个采样器的区别
多卡去训模型,尝试着用DDP模式,而不是DP模式去加速训练(很容易出现负载不均衡的情况)。
遇到了一点关于DistributedSampler这个采样器的一点疑惑,想试验下在DDP模式下,使用这个采样器和不使用这个采样器有什么区别。
实验代码:
整个数据集大小为8,batch_size 为4,总共跑2个epoch。
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
torch.distributed.init_process_group(backend="nccl")
batch_size = 4
data_size = 8
local_rank = torch.distributed.get_rank()
print(local_rank)
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
class RandomDataset(Dataset):
def __init__(self, length, local_rank):
self.len = length
self.data = torch.stack([torch.ones(1), torch.ones(1)*2,torch.ones(1)*3,torch.ones(1)*4,torch.ones(1)*5,torch.ones(1)*6,torch.ones(1)*7,torch.ones(1)*8]).to('cuda')
self.local_rank = local_rank
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return self.len
dataset = RandomDataset(data_size, local_rank)
sampler = DistributedSampler(dataset)
#rand_loader =DataLoader(dataset=dataset,batch_size=batch_size,sampler=None,shuffle=True)
rand_loader = DataLoader(dataset=dataset,batch_size=batch_size,sampler=sampler)
epoch = 0
while epoch < 2:
sampler.set_epoch(epoch)
for data in rand_loader:
print(data)
epoch+=1
运行命令:
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 test.py
实验结果:
结论分析:上面的运行结果来看,在一个epoch中,sampler相当于把整个数据集 划分成了nproc_per_node份,每个GPU每次得到batch_size的数量,也就是nproc_per_node 个GPU分一整份数据集,总数据量大小就为1个dataset。
如果不用它里面自带的sampler,单纯的还是按照我们一般的形式。Sampler=None,shuffle=True这种,那么结果将会是下面这样的:
结果分析:没用sampler的话,在一个epoch中,每个GPU各自维护着一份数据,每个GPU每次得到的batch_size的数据,总的数据量为2个dataset,
来源:https://blog.csdn.net/chanbo8205/article/details/115242635
猜你喜欢
- 对于任何一个开发项目来说最大的错误可能就是没有计划。最近,有些人认为开始前无需计划,一个优秀的开发者需要的是随机应变。我敢肯定这样的做法最后
- 我简单的绘制了一下排序算法的分类,蓝色字体的排序算法是我们用python3实现的,也是比较常用的排序算法。Python3常用排序算法1、Py
- 优化糟糕设计的表结果或者索引能很大程度改进mysql的性能。 如果需要高性能, 那么就需要根据不同的操作需求精心设计表结构和索引, 这当然需
- 生活中经常会碰到多个excel表格汇总成一个表格的情况,比如你发放了一份表格让班级所有同学填写,而你负责将大家的结果合并成一个。诸如此类的问
- 本文实例讲述了python判断字符串是否包含子字符串的方法。分享给大家供大家参考。具体如下:python的string对象没有contain
- 使用命令行时,如果要添加选项的话,python 2.3里新增加了一个模块叫optparse,也是专门来处理命令行选项的。from optpa
- 在python处理数据时,经常用到DataFrame和set。train=pd.read_csv('XXX.csv')#读取
- 前言:交换机模式主要包括:交换机之发布订阅、交换机之关键字和交换机之通配符。1、交换机之发布订阅 发布订阅和简单的消息队列区别在于
- try 块允许您测试代码块以查找错误。except 块允许您处理错误。finally 块允许您执行代码,无论 try 和 except 块的
- 很多设计师都会遇到这样的问题。一个产品会有很多种方式去包装,其中包括很多功能和很多体验。功能越多会被认为越实用,体验越好会被认为越方便。方便
- 本文实例讲述了Yii2中SqlDataProvider用法。分享给大家供大家参考,具体如下:第一种方法:$totalCount = Yii:
- 一、包的导入Golang 当导入多个包时,一般按照字母顺序排列包名称,像Goland 等IDE 会在保存文件时自动完成这个动作。Golang
- 教育信息化时代,考试成绩也要求上网公布。一次我将考试成绩制作成一个HTML文件,如图1所示,领导审查的意见是“将成绩按名次排列”,可是所有的
- 使用pandas读取xml文件报错“ Unsupported format, or corrupt file: Expected BOF r
- PHP版: $o = array('x'=>1, 'y'=>2, 'z'=>
- 保存Python程序,可以使用以下方法:使用编辑器编写代码并保存1、打开Notepad++2、需要新建文本时,点击左上角”文本“,在弹出的菜
- 开始刷leetcode算法题 今天做的是“买卖股票的最佳时机”题目要求 给定一个数组,它的第 i 个元素是一支给定股票第 i 天的价格。设计
- 现在不写asp了这次我将我以前沉淀下的一些函数库共享给大家,希望能给初学者启示,给老手也有所帮助吧.先谢谢大家支持! <%@
- python实现情感分析(Word2Vec)** 前几天跟着老师做了几个项目,老师写的时候劈里啪啦一顿敲,写了个啥咱也布吉岛,线下自己就瞎琢
- 建造者模式的适用范围:想要创建一个由多个部分组成的对象,而且它的构成需要一步接一步的完成。只有当各个部分都完成了,这个对象才完整。建造者模式