网络编程
位置:首页>> 网络编程>> Python编程>> PyTorch Distributed Data Parallel使用详解

PyTorch Distributed Data Parallel使用详解

作者:光火  发布时间:2023-10-26 16:33:45 

标签:PyTorch,Distributed,Data,Parallel,深度学习

DDP

Distributed Data Parallel 简称 DDP,是 PyTorch 框架下一种适用于单机多卡、多机多卡任务的数据并行方式。由于其良好的执行效率及广泛的显卡支持,熟练掌握 DDP 已经成为深度学习从业者所必备的技能之一。本文结合具体代码,详细地说明了 DDP 在项目中的使用方式。读者按照本文所给的范例,只需稍经调试,即可实现 DDP 的整套流程。

概念辨析

具体讲解 DDP 之前,我们先了解了解它和 Data Parallel (DP) 之间的区别。DP 同样是 PyTorch 常见的多 GPU 并行方式之一,且它的实现非常简洁:

# 函数定义
torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
'''
module : 模型
device_ids : 参与训练的 GPU 列表
output_device : 指定输出的 GPU, 通常省略, 即默认使用索引为 0 的显卡
'''
# 程序模板
device_ids = [0, 1]
net = torch.nn.DataParallel(net, device_ids=device_ids)

基本原理及固有缺陷:在 Data Parallel 模式下,数据会被自动切分,加载到 GPU。同时,模型也将拷贝至各个 GPU 进行正向传播。在多个进程之间,会有一个进程充当 master 节点,负责收集各张显卡积累的梯度,并据此更新参数,再统一发送至其他显卡。因此整体而言,master 节点承担了更多的计算与通信任务,容易造成网络堵塞,影响训练速度。

常见问题及解决方案:Data Parallel 要求模型必须在 device_ids[0] 拥有参数及缓冲区,因此当卡 0 被占用时,可以在 nn.DataParallel 之前添加如下代码:

# 按照 PIC_BUS_ID 顺序自 0 开始排列 GPU 设备
os.environ['CUDA_DEVICE_ORDER'] = 'PIC_BUS_ID'
# 设置当前使用的 GPU 为 2、3 号设备
os.environ['CUDA_VISIBLE_DEVICES'] = '2, 3'

如此,device_ids[0] 将被默认为 2 号卡,device_ids[1] 则对应 3 号卡

相较于 DP, Distributed Data Parallel 的实现要复杂得多,但是它的优势也非常明显:

  • DDP 速度更快,可以达到略低于显卡数量的加速比;

  • DDP 可以实现负载的均匀分配,克服了 DP 需要一个进程充当 master 节点的固有缺陷;

  • 采用 DDP 通常可以支持更大的 batch size,不会像 DP 那样出现其他显卡尚有余力,而卡 0 直接 out of memory 的情况;

  • 另外,在 DDP 模式下,输入到 data loader 的 bacth size 不再代表总数,而是每块 GPU 各自负责的 sample 数量。比方说,batch_size = 30,有两块 GPU。在 DP 模式下,每块 GPU 会负责 15 个样本。而在 DDP 模式下,每块 GPU 会各自负责 30 个样本;

  • DDP 基本原理:倘若我们拥有 N 张显卡,则在 Distributed Data Parallel 模式下,就会启动 N 个进程。每个进程在各自的卡上加载模型,且模型的参数完全相同。训练过程中,各个进程通过一种名为 Ring-Reduce 的方式与其他进程通信,交换彼此的梯度,从而获得所有的梯度信息。随后,各个进程利用梯度的平均值更新参数。由于初始值和更新量完全相同,所以各个进程更新后的参数仍保持一致。

常用术语

  • rank

    • 进程号

    • 多进程上下文中,通常假定 rank = 0 为主进程或第一个进程

  • node

    • 物理节点,表示一个容器或一台机器

    • 节点内部可以包含多个 GPU

  • local_rank

    • 一个 node 中,进程的相对序号

    • local_rank 在 node 之间独立

  • world_size

    • 全局进程数

    • 一个分布式任务中 rank 的数量

  • group

    • 进程组

    • 一个分布式任务就对应一个进程组

    • 只有当用户创立多个进程组时,才会用到

PyTorch Distributed Data Parallel使用详解

代码实现

Distributed Data Parallel 可以通过 Python 的 torch.distributed.launch 启动器,在命令行分布式地执行 Python 文件。执行过程中,启动器会将当前进程(其实就是 GPU)的 index 通过参数传递给 Python,而我们可以利用如下方式获取这个 index:

import argparse
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', default=-1, type=int,
                   metavar='N', help='Local process rank.')
args = parser.parse_args()
# print(args.local_rank)
# local_rank 表示本地进程序号

随后,初始化进程组。对于在 GPU 执行的任务,建议选择 nccl (由 NVIDIA 推出) 作为通信后端。对于在 CPU 执行的任务,建议选择 gloo (由 Facebook 推出) 作为通信后端。倘若不传入 init_method,则默认为 env://,表示自环境变量读取分布式信息

dist.init_process_group(backend='nccl', init_method='env://')
# 初始化进程组之后, 通常会执行这两行代码
torch.cuda.set_device(args.local_rank)
device = torch.device('cuda', args.local_rank)
# 后续的 model = model.to(device), tensor.cuda(device)
# 对应的都是这里由 args.local_rank 初始化得到的 device

数据部分,使用 Distributed Sampler 划分数据集,并将 sampler 传入 data loader。需要注意的是,此时在 data loader 中不能指定 shuffle 为 True,否则会报错 (sampler 已具备随机打乱功能)

dev_sampler = data.DistributedSampler(dev_data_set)
train_sampler = data.DistributedSampler(train_data_set)
dev_loader = data.DataLoader(dev_data_set, batch_size=dev_batch_size,
                            shuffle=False, sampler=dev_sampler)
train_loader = data.DataLoader(train_data_set, batch_size=train_batch_size,
                              shuffle=False, sampler=train_sampler)

模型部分,首先将将模型送至 device,即对应的 GPU 上,再使用 Distributed Data Parallel 包装模型(顺序颠倒会报错)

model = model.to(device)
model = nn.parallel.DistributedDataParallel(
   model, device_ids=[args.local_rank], output_device=args.local_rank
)

Distributed Data Parallel 模式下,保存模型应使用 net.module.state_dict(),而非 net.state_dict()。且无论是保存模型,还是 LOGGER 打印,只对 local_rank 为 0 的进程操作即可,因此代码中会有很多 args.local_rank == 0 的判断

if args.local_rank == 0:
   LOGGER.info(f'saving latest model: {output_path}')
   torch.save({'model': model.module.state_dict(),
               'optimizer': None, 'epoch': epoch, 'best-f1': best_f1},
              open(os.path.join(output_path, 'latest_model_{}.pth'.format(fold)), 'wb'))

利用 torch.load 加载模型时,设置 map_location=device,否则卡 0 会承担更多的开销

load_model = torch.load(best_path, map_location=device)
model.load_state_dict(load_model['model'])
  • dist.barrier() 可用于同步多个进程,建议只在必要的位置使用,如初始化 DDP 模型之前、权重更新之后、开启新一轮 epoch 之前

  • 计算 accuracy 时,可以使用 dist.all_reduce(score, op=dist.ReduceOp.SUM),将各个进程计算的准确率求平均

  • 计算 f1-score 时,可以使用 dist.all_gather(all_prediction_list, prediction_list),将各个进程获得的预测值和真实值汇总到 all_list,再统一代入公式

启动方式

torch.distributed.launch

# 此处 --nproc_per_node 4 的含义是 server 有 4 张显卡
python torch.distributed.launch --nproc_per_node 4 train.py
# 倘若使用 nohup, 则注意输入命令后 exit 当前终端
python torch.distributed.launch --nproc_per_node 4 train.py
  • torchrun,推荐使用这种方式,因为 torch.distributed.launch 即将弃用

代码中,只需将 Argument Parser 相关的部分替换为

local_rank = int(os.environ['LOCAL_RANK'])

然后将 args.local_rank 全部改为 local_rank 即可

启动命令

# 单机多卡训练时, 可以不指定 nnodes
torchrun --nnodes=1 --nproc_per_node=4 train.py
# 倘若使用 nohup, 则注意输入命令后 exit 当前终端
nohup torchrun --nnodes=1 --nproc_per_node=4 train.py > nohup.out &

来源:https://juejin.cn/post/7211775098310950973

0
投稿

猜你喜欢

  • 上篇博客转载了关于感知器的用法,遂这篇做个大概总结,并实现一个简单的感知器,也为了加深自己的理解。感知器是最简单的神经网络,只有一层。感知器
  • CPU活动展示导入模块,创建画板,创建画笔进行绘画出cpu的数据,一定要用线程,负责会卡住哦实现代码import tkinterfrom t
  • 按照本文操作和体会,会对sql优化有个基本最简单的了解,其他深入还需要更多资料和实践的学习: 1. 建表:  代码如下:creat
  • 解决SQL2000最大流水号的两个好方法问:请问怎样才能解决ms serer 2000 最大流水号的问题?答:我可以介绍两种方法给你:方法1
  • 在上一节《Django是什么》中,我们对 Django 的诞生以及 Web 框架的概念有了基本的了解,本节我们介绍 Django 的设计模式
  • 这阵子没有精力完整翻译和发到译言(  现下正渐入状态,预计写博客量会逐步提升回来),简短做一个概要翻译,为近期工作需要做一个参考。
  • 前言玩博客一个多月了,渐渐发现了一些有意思的事,经常会有人用同样的评论到处刷,不知道是为了加没什么用的积分,还是纯粹为了表达楼主好人。那么问
  • Django是一个基于Python Web框架的高级Web框架,允许快速开发和干净,务实的设计。今天,我们将创建一个待办事项应用程序,以了解
  • 1.包: package PaintBrush; /** * * @author lucifer */ public class Paint
  • 前言在python基础知识中有说过,字典是可变的数据类型,其参数又是键对值。setdefault()方法和字典的get()方法在一些地方比较
  • 本文摘自 《深度学习原理与PyTorch实战》我们将从预测某地的共享单车数量这个实际问题出发,带领读者走进神经网络的殿堂,运用PyTorch
  •         国庆假期快到了,想查查还有几天几小时到假期,这对程序员
  • pandas处理大数据的限制现在的数据科学比赛提供的数据量越来越大,动不动几十个G,甚至上百G,这就要考验机器性能和数据处理能力。Pytho
  • -crop参数是从一个图片截取一个指定区域的子图片.格式如下:convert -crop widthxheight{+-
  • 在使用ORACLE的过程过,我们会经常遇到一些ORACLE产生的错误,对于初学者而言,这些错误可能有点模糊,而且可能一时不知怎么去处理产生的
  • 前言圣诞节快到了,是不是想用python画一个可爱的圣诞树,我在各大网站都查了一下,都不太美观,然后我就学习了一下别人的代码改写了一下,自己
  • 装饰器本质上是一个Python函数,它可以让其他函数在不需要做任何代码变动的前提下增加额外功能,装饰器的返回值也是一个函数对象.经常用于有切
  • Flask是一个Python编写的Web 微框架,让我们可以使用Python语言快速实现一个网站或Web服务。本文参考自Flask官方文档,
  • J2ME是利用HttpConnection建立HTTP连接,然后获取数据,ASP也是利用HTTP协议,因而可以利用J2ME与ASP建立连接,
  • 需求我在最近的一个任务中,存在一个redis高并发计算多个客户端接收预警信息的时长问题。模型是首先模拟多个客户端连接预警服务器集群,然后向预
手机版 网络编程 asp之家 www.aspxhome.com