PyTorch Distributed Data Parallel使用详解
作者:光火 发布时间:2023-10-26 16:33:45
DDP
Distributed Data Parallel 简称 DDP
,是 PyTorch 框架下一种适用于单机多卡、多机多卡任务的数据并行方式。由于其良好的执行效率及广泛的显卡支持,熟练掌握 DDP
已经成为深度学习从业者所必备的技能之一。本文结合具体代码,详细地说明了 DDP
在项目中的使用方式。读者按照本文所给的范例,只需稍经调试,即可实现 DDP
的整套流程。
概念辨析
具体讲解 DDP
之前,我们先了解了解它和 Data Parallel (DP
) 之间的区别。DP
同样是 PyTorch 常见的多 GPU 并行方式之一,且它的实现非常简洁:
# 函数定义
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
'''
module : 模型
device_ids : 参与训练的 GPU 列表
output_device : 指定输出的 GPU, 通常省略, 即默认使用索引为 0 的显卡
'''
# 程序模板
device_ids = [0, 1]
net = torch.nn.DataParallel(net, device_ids=device_ids)
基本原理及固有缺陷:在 Data Parallel 模式下,数据会被自动切分,加载到 GPU。同时,模型也将拷贝至各个 GPU 进行正向传播。在多个进程之间,会有一个进程充当 master 节点,负责收集各张显卡积累的梯度,并据此更新参数,再统一发送至其他显卡。因此整体而言,master 节点承担了更多的计算与通信任务,容易造成网络堵塞,影响训练速度。
常见问题及解决方案:Data Parallel 要求模型必须在 device_ids[0] 拥有参数及缓冲区,因此当卡 0 被占用时,可以在 nn.DataParallel
之前添加如下代码:
# 按照 PIC_BUS_ID 顺序自 0 开始排列 GPU 设备
os.environ['CUDA_DEVICE_ORDER'] = 'PIC_BUS_ID'
# 设置当前使用的 GPU 为 2、3 号设备
os.environ['CUDA_VISIBLE_DEVICES'] = '2, 3'
如此,device_ids[0] 将被默认为 2 号卡,device_ids[1] 则对应 3 号卡
相较于 DP
, Distributed Data Parallel 的实现要复杂得多,但是它的优势也非常明显:
DDP
速度更快,可以达到略低于显卡数量的加速比;DDP
可以实现负载的均匀分配,克服了DP
需要一个进程充当 master 节点的固有缺陷;采用
DDP
通常可以支持更大的 batch size,不会像DP
那样出现其他显卡尚有余力,而卡 0 直接 out of memory 的情况;另外,在
DDP
模式下,输入到 data loader 的 bacth size 不再代表总数,而是每块 GPU 各自负责的 sample 数量。比方说,batch_size = 30,有两块 GPU。在DP
模式下,每块 GPU 会负责 15 个样本。而在DDP
模式下,每块 GPU 会各自负责 30 个样本;DDP
基本原理:倘若我们拥有 N 张显卡,则在 Distributed Data Parallel 模式下,就会启动 N 个进程。每个进程在各自的卡上加载模型,且模型的参数完全相同。训练过程中,各个进程通过一种名为 Ring-Reduce 的方式与其他进程通信,交换彼此的梯度,从而获得所有的梯度信息。随后,各个进程利用梯度的平均值更新参数。由于初始值和更新量完全相同,所以各个进程更新后的参数仍保持一致。
常用术语
rank
进程号
多进程上下文中,通常假定 rank = 0 为主进程或第一个进程
node
物理节点,表示一个容器或一台机器
节点内部可以包含多个 GPU
local_rank
一个 node 中,进程的相对序号
local_rank 在 node 之间独立
world_size
全局进程数
一个分布式任务中 rank 的数量
group
进程组
一个分布式任务就对应一个进程组
只有当用户创立多个进程组时,才会用到
代码实现
Distributed Data Parallel 可以通过 Python 的 torch.distributed.launch
启动器,在命令行分布式地执行 Python 文件。执行过程中,启动器会将当前进程(其实就是 GPU)的 index 通过参数传递给 Python,而我们可以利用如下方式获取这个 index:
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=-1, type=int,
metavar='N', help='Local process rank.')
args = parser.parse_args()
# print(args.local_rank)
# local_rank 表示本地进程序号
随后,初始化进程组。对于在 GPU 执行的任务,建议选择 nccl
(由 NVIDIA 推出) 作为通信后端。对于在 CPU 执行的任务,建议选择 gloo
(由 Facebook 推出) 作为通信后端。倘若不传入 init_method
,则默认为 env://
,表示自环境变量读取分布式信息
dist.init_process_group(backend='nccl', init_method='env://')
# 初始化进程组之后, 通常会执行这两行代码
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda', args.local_rank)
# 后续的 model = model.to(device), tensor.cuda(device)
# 对应的都是这里由 args.local_rank 初始化得到的 device
数据部分,使用 Distributed Sampler 划分数据集,并将 sampler 传入 data loader。需要注意的是,此时在 data loader 中不能指定 shuffle 为 True,否则会报错 (sampler 已具备随机打乱功能)
dev_sampler = data.DistributedSampler(dev_data_set)
train_sampler = data.DistributedSampler(train_data_set)
dev_loader = data.DataLoader(dev_data_set, batch_size=dev_batch_size,
shuffle=False, sampler=dev_sampler)
train_loader = data.DataLoader(train_data_set, batch_size=train_batch_size,
shuffle=False, sampler=train_sampler)
模型部分,首先将将模型送至 device,即对应的 GPU 上,再使用 Distributed Data Parallel 包装模型(顺序颠倒会报错)
model = model.to(device)
model = nn.parallel.DistributedDataParallel(
model, device_ids=[args.local_rank], output_device=args.local_rank
)
Distributed Data Parallel 模式下,保存模型应使用 net.module.state_dict()
,而非 net.state_dict()
。且无论是保存模型,还是 LOGGER 打印,只对 local_rank 为 0 的进程操作即可,因此代码中会有很多 args.local_rank == 0
的判断
if args.local_rank == 0:
LOGGER.info(f'saving latest model: {output_path}')
torch.save({'model': model.module.state_dict(),
'optimizer': None, 'epoch': epoch, 'best-f1': best_f1},
open(os.path.join(output_path, 'latest_model_{}.pth'.format(fold)), 'wb'))
利用 torch.load
加载模型时,设置 map_location=device
,否则卡 0 会承担更多的开销
load_model = torch.load(best_path, map_location=device)
model.load_state_dict(load_model['model'])
dist.barrier()
可用于同步多个进程,建议只在必要的位置使用,如初始化DDP
模型之前、权重更新之后、开启新一轮 epoch 之前计算 accuracy 时,可以使用
dist.all_reduce(score, op=dist.ReduceOp.SUM)
,将各个进程计算的准确率求平均计算 f1-score 时,可以使用
dist.all_gather(all_prediction_list, prediction_list)
,将各个进程获得的预测值和真实值汇总到 all_list,再统一代入公式
启动方式
torch.distributed.launch
# 此处 --nproc_per_node 4 的含义是 server 有 4 张显卡
python torch.distributed.launch --nproc_per_node 4 train.py
# 倘若使用 nohup, 则注意输入命令后 exit 当前终端
python torch.distributed.launch --nproc_per_node 4 train.py
torchrun
,推荐使用这种方式,因为torch.distributed.launch
即将弃用
代码中,只需将 Argument Parser 相关的部分替换为
local_rank = int(os.environ['LOCAL_RANK'])
然后将 args.local_rank
全部改为 local_rank
即可
启动命令
# 单机多卡训练时, 可以不指定 nnodes
torchrun --nnodes=1 --nproc_per_node=4 train.py
# 倘若使用 nohup, 则注意输入命令后 exit 当前终端
nohup torchrun --nnodes=1 --nproc_per_node=4 train.py > nohup.out &
来源:https://juejin.cn/post/7211775098310950973


猜你喜欢
- 背景需要遍历结构体的所有field对于exported的field, 动态set这个field的value对于unexported的fiel
- 近日在项目中遇到一个问题: 如何在报表中统计JSON格式存储的数据?例如有个调查问卷记录表,记录每个问题的答案。 其结构示意如下(横表设计)
- 在批评Python的讨论中,常常说起Python多线程是多么的难用。还有人对 global interpreter lock(也被亲切的称为
- SQLServer中建立与服务器的连接时出错的解决方案如下:步骤1:在SQLServer 实例上启用远程连接1.指向“开始->程序-&
- 单个表的删除:DELETE FROM tableName WHERE columnName = value;删除表内的所有行:即:保留表的结
- pyqtgraph官方给的示例居然会报错2333官方文档传送门:#####pyqtgraph exportpyqtgraph支持在可视化窗口
- 参照资料:selenium webdriver添加cookie: https://www.jb51.net/article/193102.h
- 前言异步函数也是有执行顺序的。本质上来说,JavaScript是单线程语言,不管是在浏览器中还是nodejs环境下。浏览器在执行js代码和渲
- 作为设计师,我们都知道,一个极简的设计可以实现漂亮的效果。然而,很多设计师在实现上有些麻烦:要么是没有时间让使用如此少的元素制作的页面看起来
- 在JAVA WEB应用中,如何获取servlet请求中的参数,并传递给跳转的JSP页面?例如访问http://localhost:8088/
- 一、增强的可扩展性 Oracle9i Real Application Clusters是Oracle的下一代并行服务器系列产品。Oracl
- 本文实例讲述了python生成IP段的方法。分享给大家供大家参考。具体实现方法如下:#!/usr/local/bin/python#-*-
- 在《多线程与同步》中介绍了多线程及存在的问题,而通过使用多进程而非线程可有效地绕过全局解释器锁。 因此,通过multiprocessing模
- 本文实例讲述了Python中bisect的用法,是一个比较常见的实用技巧。分享给大家供大家参考。具体分析如下:一般来说,Python中的bi
- 最早大家都没有给链接加title的习惯,后来因为w3c标准普及,又集体加上了title。从一个极端走到另个极端,于是出现很多怪异现象。两方面
- 本文实例讲述了Flask框架使用DBUtils模块连接数据库的操作方法。分享给大家供大家参考,具体如下:Flask连接数据库数据库连接池:D
- 前言gif图就是动态图,它的原理和视频有点类似,也是通过很多静态图片合成的.本篇文章主要介绍,如何利用Python快速合成gif图,主要利用
- 利用python+ffmpeg合并B站视频及格式转换 B站客户端下载的视频一般有两种格式:早期的多为blv格式(由flv格式转换而来,音视频
- 众所周知,如果py文件不在当前路径,那么就不能import,因此,本文介绍如下两种有效的方法:方法1:修改环境变量,在~/.bashrc里面
- /* 建立数据表 */ create table td_base_data( id int(10) not null auto_increm