详解如何使用Pytorch进行多卡训练
作者:实力 发布时间:2023-08-02 06:43:29
Python PyTorch深度学习框架
PyTorch是一个基于Python的深度学习框架,它支持使用CPU和GPU进行高效的神经网络训练。
在大规模任务中,需要使用多个GPU来加速训练过程。
数据并行
“数据并行”是一种常见的使用多卡训练的方法,它将完整的数据集拆分成多份,每个GPU负责处理其中一份,在完成前向传播和反向传播后,把所有GPU的误差累积起来进行更新。数据并行的代码结构如下:
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torch.distributed as dist
import torch.multiprocessing as mp
# 定义网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(4608, 64)
self.fc2 = nn.Linear(64, 10)
def forward(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.pool(x)
x = x.view(-1, 4608)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# 定义训练函数
def train(gpu, args):
rank = gpu
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
torch.cuda.set_device(gpu)
train_loader = data.DataLoader(...)
model = Net()
model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
for epoch in range(args.epochs):
epoch_loss = 0.0
for i, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.cuda(), labels.cuda()
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
print('GPU %d Loss: %.3f' % (gpu, epoch_loss))
# 主函数
if __name__ == '__main__':
mp.set_start_method('spawn')
args = parser.parse_args()
args.world_size = args.num_gpus * args.nodes
mp.spawn(train, args=(args,), nprocs=args.num_gpus, join=True)
首先,我们需要在主进程中使用torch.distributed.launch启动多个子进程。每个子进程被分配一个GPU,并调用train函数进行训练。
在train函数中,我们初始化进程组,并将模型以及优化器包装成DistributedDataParallel对象,然后像CPU上一样训练模型即可。在数据并行的过程中,模型和优化器都会被复制到每个GPU上,每个GPU只负责处理一部分的数据。所有GPU上的模型都参与误差累积和梯度更新。
模型并行
“模型并行”是另一种使用多卡训练的方法,它将同一个网络分成多段,不同段分布在不同的GPU上。每个GPU只运行其中的一段网络,并利用前后传播相互连接起来进行训练。代码结构如下:
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import torch.distributed as dist
# 定义模型段
class SubNet(nn.Module):
def __init__(self, in_features, out_features):
super(SubNet, self).__init__()
self.linear = nn.Linear(in_features, out_features)
def forward(self, x):
return self.linear(x)
# 定义整个模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.subnets = nn.ModuleList([
SubNet(1024, 512),
SubNet(512, 256),
SubNet(256, 100)
])
def forward(self, x):
for subnet in self.subnets:
x = subnet(x)
return x
# 定义训练函数
def train(subnet_id, args):
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=subnet_id)
torch.cuda.set_device(subnet_id)
train_loader = data.DataLoader(...)
model = Net().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
for epoch in range(args.epochs):
epoch_loss = 0.0
for i, (inputs, labels) in enumerate(train_loader):
inputs, labels = inputs.cuda(), labels.cuda()
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward(retain_graph=True) # 梯度保留,用于后续误差传播
optimizer.step()
epoch_loss += loss.item()
if subnet_id == 0:
print('Epoch %d Loss: %.3f' % (epoch, epoch_loss))
# 主函数
if __name__ == '__main__':
mp.set_start_method('spawn')
args = parser.parse_args()
args.world_size = args.num_gpus * args.subnets
tasks = []
for i in range(args.subnets):
tasks.append(mp.Process(target=train, args=(i, args)))
for task in tasks:
task.start()
for task in tasks:
task.join()
在模型并行中,网络被分成多个子网络,并且每个GPU运行一个子网络。在训练期间,每个子网络的输出会作为下一个子网络的输入。这需要在误差反向传播时,将不同GPU上计算出来的梯度加起来,并再次分发到各个GPU上。
在代码实现中,我们定义了三个子网(SubNet),每个子网有不同的输入输出规模。在train函数中,我们初始化进程组和模型,然后像CPU上一样进行多次迭代训练即可。在反向传播时,将梯度保留并设置retain_graph为True,用于后续误差传播。
来源:https://juejin.cn/post/7223728273007558712


猜你喜欢
- 背景想象一下,现在你有一份Word邀请函模板,然后你有一份客户列表,上面有客户的姓名、联系方式、邮箱等基本信息,然后你的老板现在需要替换邀请
- 最近在公司接到一个需求,里面有一个 * 跳转。类似于选择地址的时候,选择的顺序是:省份->市->区。如果分三个页面跳转,那么体验非
- 简介主要是尝试简单的使用pyhton的爬虫功能,于是使用有道进行尝试,并没有进行深入的诸如相关api的调用。以下是需要的POST数据代码以下
- 我就废话不多说了,还是直接看代码吧!#!/usr/bin/env python3#coding = utf-8def is_triangle
- 前言本节我们讲讲一些简单查询语句示例以及需要注意的地方,简短的内容,深入的理解。EOMONTH在SQL Server 2012的教程示例中,
- 本文为Django项目创建的简单介绍,更为详细的Django项目创建,可以参考如下教程:Django入门与实践-https://www.jb
- Paddle模型性能分析Profiler定位性能瓶颈点优化程序提升性能Paddle Profiler是飞桨框架自带的低开销性能分析器,可以对
- 三种方法利用indexOf判断新数组underscore.js中实际上也是使用的类似的indexOf //传入数组 function uni
- 生成一个2000*5的表格,每个单元格的内容是行号+逗号+列号 方法一:使用createElement生成表格,使用insertRow和in
- 目录wtforms使用1(简单版):使用2(复杂版):wtforms安装:pip3 install wtforms使用1(简单版):from
- 一般情况下TextArea区输入的文字数量是没有限制的,但是我们可以通过javascript限制表单的文字字数。如下javascript代码
- 我们之前要想在调度里面实现延时执行,我们可以使用管道阻塞,直到有人往管道里面写东西才变通畅,还可以使用sleep来睡觉,但是睡觉的过程,协程
- 亲测十分靠谱下面是解决该问题的方法第一步:关闭SIP系统保护1.重启系统时按住Command+R进入恢复模式(记住是你在重新启动时,不是启动
- mysql默认varchar类型是对大小写不敏感(不区分),如果想要mysql区分大小写需要设置排序规则:utf8_bin将字符串中的每一个
- []*int是一个指向指针的切片,本质上是切片,只不过切片里面存放的元素是指针;*[]int是一个指向切片的指针,本质上是指针,可以用*来获
- 安装jenkins安装jenkins很简单,可以用多种方式安装,这里知道的有:在官网下载rpm包,手动安装,最费事centos系统通过yum
- 本文详细罗列并说明了Python的标准库与第三方库如下,供对此有需要的朋友进行参考:Tkinter———— Python默认的图形界面接口。
- 元组的特点:是一种不可变序列,一旦创建就不能修改1、拆包将元组的元素取出赋值给不同变量>>> a = ('hell
- 一:使用where少使用having;二:查两张以上表时,把记录少的放在右边;三:减少对表的访问次数;四:有where子查询时,子查询放在最
- //定义编码header( 'Content-Type:text/html;charset=utf-8 ');//Atomh