网络编程
位置:首页>> 网络编程>> Python编程>> pytorch使用horovod多gpu训练的实现

pytorch使用horovod多gpu训练的实现

作者:You-wh  发布时间:2022-01-07 16:01:18 

标签:pytorch,horovod,gpu

PyTorch 使用 Horovod 进行分布式训练的步骤如下：


import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.distributed
import horovod.torch as hvd

# Placeholders for this sketch; replace with real values or CLI arguments.
base_lr = 0.01
log_interval = 10

# Initialize Horovod.
hvd.init()

# Pin each process to the GPU matching its local rank (one GPU per process).
torch.cuda.set_device(hvd.local_rank())

# Define dataset...
train_dataset = ...

# Partition the dataset among workers with DistributedSampler so each rank
# iterates over its own shard.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=hvd.size(), rank=hvd.rank())

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=..., sampler=train_sampler)

# Build model...
model = ...
model.cuda()

# Scale the learning rate by the number of workers, since every optimizer
# step now consumes one batch from each of the hvd.size() ranks.
optimizer = optim.SGD(model.parameters(), lr=base_lr * hvd.size())

# Wrap the original optimizer in Horovod's DistributedOptimizer, which
# allreduces gradients across workers before each step.
optimizer = hvd.DistributedOptimizer(
    optimizer, named_parameters=model.named_parameters())

# Broadcast the initial parameters and optimizer state from rank 0 so every
# worker starts from the same point.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

for epoch in range(100):
    # Re-seed the sampler so the shards are reshuffled every epoch.
    train_sampler.set_epoch(epoch)
    for batch_idx, (data, target) in enumerate(train_loader):
        # The model was moved to the GPU above, so the batch must follow it.
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {}'.format(
                epoch, batch_idx * len(data), len(train_sampler), loss.item()))

完整示例代码如下：使用 ResNet-50 在 ImageNet 上进行训练。


 from __future__ import print_function

import torch
 import argparse
 import torch.backends.cudnn as cudnn
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data.distributed
 from torchvision import datasets, transforms, models
import horovod.torch as hvd
import os
import math
from tqdm import tqdm
from distutils.version import LooseVersion

# Training settings
parser = argparse.ArgumentParser(description='PyTorch ImageNet Example',
                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--train-dir', default=os.path.expanduser('~/imagenet/train'),
          help='path to training data')
parser.add_argument('--val-dir', default=os.path.expanduser('~/imagenet/validation'),
          help='path to validation data')
parser.add_argument('--log-dir', default='./logs',
          help='tensorboard log directory')
parser.add_argument('--checkpoint-format', default='./checkpoint-{epoch}.pth.tar',
          help='checkpoint file format')
parser.add_argument('--fp-allreduce', action='store_true', default=False,
          help='use fp compression during allreduce')
parser.add_argument('--batches-per-allreduce', type=int, default=,
          help='number of batches processed locally before '
             'executing allreduce across workers; it multiplies '
             'total batch size.')
parser.add_argument('--use-adasum', action='store_true', default=False,
          help='use adasum algorithm to do reduction')

# Default settings from https://arxiv.org/abs/1706.02677.
parser.add_argument('--batch-size', type=int, default=32,
          help='input batch size for training')
parser.add_argument('--val-batch-size', type=int, default=32,
          help='input batch size for validation')
parser.add_argument('--epochs', type=int, default=90,
          help='number of epochs to train')
parser.add_argument('--base-lr', type=float, default=0.0125,
44           help='learning rate for a single GPU')
45 parser.add_argument('--warmup-epochs', type=float, default=5,
          help='number of warmup epochs')
parser.add_argument('--momentum', type=float, default=0.9,
          help='SGD momentum')
parser.add_argument('--wd', type=float, default=0.00005,
          help='weight decay')

parser.add_argument('--no-cuda', action='store_true', default=False,
          help='disables CUDA training')
parser.add_argument('--seed', type=int, default=42,
          help='random seed')

args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

allreduce_batch_size = args.batch_size * args.batches_per_allreduce

hvd.init()
torch.manual_seed(args.seed)

if args.cuda:
    # Horovod: pin GPU to local rank.
    torch.cuda.set_device(hvd.local_rank())
    torch.cuda.manual_seed(args.seed)

cudnn.benchmark = True

# Resume from the newest checkpoint found on disk (file names use 1-based
# epoch numbers); 0 means train from scratch.
resume_from_epoch = 0
for try_epoch in range(args.epochs, 0, -1):
    if os.path.exists(args.checkpoint_format.format(epoch=try_epoch)):
        resume_from_epoch = try_epoch
        break

# Horovod: broadcast resume_from_epoch from rank 0 (which will have
# checkpoints) to other ranks.
resume_from_epoch = hvd.broadcast(torch.tensor(resume_from_epoch), root_rank=0,
                                  name='resume_from_epoch').item()

# Horovod: print logs on the first worker.
verbose = 1 if hvd.rank() == 0 else 0

# Horovod: write TensorBoard logs on first worker.
try:
    if LooseVersion(torch.__version__) >= LooseVersion('1.2.0'):
        from torch.utils.tensorboard import SummaryWriter
    else:
        from tensorboardX import SummaryWriter
    log_writer = SummaryWriter(args.log_dir) if hvd.rank() == 0 else None
except ImportError:
    # TensorBoard logging is optional; train without it if unavailable.
    log_writer = None

# Horovod: limit # of CPU threads to be used per worker.
torch.set_num_threads(4)

kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}

# One per-channel normalization shared by the train and validation pipelines,
# so both see identically normalized inputs.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_dataset = datasets.ImageFolder(
    args.train_dir,
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))
# Horovod: use DistributedSampler to partition data among workers. Manually specify
# `num_replicas=hvd.size()` and `rank=hvd.rank()`.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
# Each loader batch holds allreduce_batch_size samples; train() splits it into
# sub-batches of args.batch_size.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=allreduce_batch_size,
    sampler=train_sampler, **kwargs)

val_dataset = datasets.ImageFolder(
    args.val_dir,
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]))
val_sampler = torch.utils.data.distributed.DistributedSampler(
    val_dataset, num_replicas=hvd.size(), rank=hvd.rank())
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.val_batch_size,
                                         sampler=val_sampler, **kwargs)

# Set up standard ResNet-50 model.
model = models.resnet50()

# By default, Adasum doesn't need scaling up learning rate.
# For sum/average with gradient accumulation: scale learning rate by
# batches_per_allreduce and the number of workers.
lr_scaler = args.batches_per_allreduce * hvd.size() if not args.use_adasum else 1

if args.cuda:
    # Move model to GPU.
    model.cuda()
    # If using GPU Adasum allreduce, scale learning rate by local_size.
    if args.use_adasum and hvd.nccl_built():
        lr_scaler = args.batches_per_allreduce * hvd.local_size()

# Horovod: scale learning rate by the number of GPUs.
optimizer = optim.SGD(model.parameters(),
                      lr=(args.base_lr * lr_scaler),
                      momentum=args.momentum, weight_decay=args.wd)

# Horovod: (optional) compression algorithm.
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none

# Horovod: wrap optimizer with DistributedOptimizer. backward_passes_per_step
# lets gradients of batches_per_allreduce sub-batches accumulate locally.
optimizer = hvd.DistributedOptimizer(
    optimizer, named_parameters=model.named_parameters(),
    compression=compression,
    backward_passes_per_step=args.batches_per_allreduce,
    op=hvd.Adasum if args.use_adasum else hvd.Average)

# Restore from a previous checkpoint if one was found (resume_from_epoch > 0).
# Horovod: restore on the first worker which will broadcast weights to other workers.
if resume_from_epoch > 0 and hvd.rank() == 0:
    filepath = args.checkpoint_format.format(epoch=resume_from_epoch)
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])

# Horovod: broadcast parameters & optimizer state from rank 0.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

def train(epoch):
    """Run one training epoch with local gradient accumulation.

    Each loader batch (``allreduce_batch_size`` samples) is split into
    sub-batches of ``args.batch_size``. Their gradients are accumulated and
    then applied with a single ``optimizer.step()`` per loader batch.

    Args:
        epoch: zero-based epoch index; also passed to the sampler so the
            shards are reshuffled each epoch.
    """
    model.train()
    train_sampler.set_epoch(epoch)
    train_loss = Metric('train_loss')
    train_accuracy = Metric('train_accuracy')

    with tqdm(total=len(train_loader),
              desc='Train Epoch   #{}'.format(epoch + 1),
              disable=not verbose) as t:
        for batch_idx, (data, target) in enumerate(train_loader):
            adjust_learning_rate(epoch, batch_idx)

            if args.cuda:
                data, target = data.cuda(), target.cuda()
            optimizer.zero_grad()
            # Number of sub-batches in this batch (the last batch may be short).
            num_sub_batches = math.ceil(float(len(data)) / args.batch_size)
            # Split data into sub-batches of size batch_size
            for i in range(0, len(data), args.batch_size):
                data_batch = data[i:i + args.batch_size]
                target_batch = target[i:i + args.batch_size]
                output = model(data_batch)
                train_accuracy.update(accuracy(output, target_batch))
                loss = F.cross_entropy(output, target_batch)
                train_loss.update(loss)
                # Average gradients among sub-batches
                loss.div_(num_sub_batches)
                loss.backward()
            # Gradient is applied across all ranks
            optimizer.step()
            t.set_postfix({'loss': train_loss.avg.item(),
                           'accuracy': 100. * train_accuracy.avg.item()})
            t.update(1)

    if log_writer:
        log_writer.add_scalar('train/loss', train_loss.avg, epoch)
        log_writer.add_scalar('train/accuracy', train_accuracy.avg, epoch)

def validate(epoch):
    """Evaluate this rank's validation shard and log the epoch metrics.

    Loss and accuracy are accumulated in ``Metric`` objects, whose updates are
    allreduced across workers.

    Args:
        epoch: zero-based epoch index (shown 1-based in the progress bar).
    """
    model.eval()
    val_loss = Metric('val_loss')
    val_accuracy = Metric('val_accuracy')

    with tqdm(total=len(val_loader),
              desc='Validate Epoch #{}'.format(epoch + 1),
              disable=not verbose) as t:
        with torch.no_grad():
            for data, target in val_loader:
                if args.cuda:
                    data, target = data.cuda(), target.cuda()
                output = model(data)

                val_loss.update(F.cross_entropy(output, target))
                val_accuracy.update(accuracy(output, target))
                t.set_postfix({'loss': val_loss.avg.item(),
                               'accuracy': 100. * val_accuracy.avg.item()})
                t.update(1)

    if log_writer:
        log_writer.add_scalar('val/loss', val_loss.avg, epoch)
        log_writer.add_scalar('val/accuracy', val_accuracy.avg, epoch)

# Horovod: using the fully scaled learning rate from the very beginning leads to
# worse final accuracy, so ramp it up linearly during the first
# `args.warmup_epochs` epochs. See https://arxiv.org/abs/1706.02677 for details.
# After the warmup, reduce the learning rate by 10x at the 30th, 60th and 80th epochs.
def adjust_learning_rate(epoch, batch_idx):
    """Set the learning rate of every optimizer param group for this batch.

    The peak rate is ``args.base_lr * lr_scaler``. This is the same scaling the
    optimizer was constructed with, so the Adasum-specific ``lr_scaler`` is
    honoured instead of being overwritten on the first batch. During warmup the
    multiplier ramps per batch from ``1 / hvd.size()`` to 1. After warmup it
    steps down by 10x at epochs 30, 60 and 80.

    Args:
        epoch: zero-based epoch index.
        batch_idx: zero-based index of the batch within the epoch.
    """
    if epoch < args.warmup_epochs:
        # Use a fractional epoch so the ramp advances every batch, not every epoch.
        epoch += float(batch_idx + 1) / len(train_loader)
        lr_adj = 1. / hvd.size() * (epoch * (hvd.size() - 1) / args.warmup_epochs + 1)
    elif epoch < 30:
        lr_adj = 1.
    elif epoch < 60:
        lr_adj = 1e-1
    elif epoch < 80:
        lr_adj = 1e-2
    else:
        lr_adj = 1e-3
    for param_group in optimizer.param_groups:
        # For sum/average reduction lr_scaler == batches_per_allreduce * hvd.size(),
        # so this matches the previous hard-coded formula in that case.
        param_group['lr'] = args.base_lr * lr_scaler * lr_adj

def accuracy(output, target):
    """Return top-1 accuracy as a 0-dim float tensor on the CPU.

    The predicted class for each row of ``output`` is the index of its largest
    score along dim 1; ``target`` is reshaped to match the predictions.
    """
    _, predicted = output.max(dim=1)
    hits = predicted.eq(target.view_as(predicted))
    return hits.cpu().float().mean()

def save_checkpoint(epoch):
    """Save model and optimizer state after ``epoch``; only rank 0 writes.

    The file name uses ``epoch + 1``, i.e. the 1-based number of the epoch
    just completed.
    """
    # Non-zero ranks do nothing.
    if hvd.rank() != 0:
        return
    checkpoint_path = args.checkpoint_format.format(epoch=epoch + 1)
    torch.save(
        {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        },
        checkpoint_path,
    )

# Horovod: average metrics from distributed training.
class Metric(object):
    """Running mean of a scalar metric reduced across all Horovod workers.

    Each ``update`` adds ``hvd.allreduce`` of the value to a running sum;
    ``avg`` divides that sum by the number of updates.
    """

    def __init__(self, name):
        # `name` also serves as the allreduce tensor name, so it must match on
        # every rank.
        self.name = name
        self.sum = torch.tensor(0.)
        self.n = torch.tensor(0.)

    def update(self, val):
        """Allreduce ``val`` (detached, moved to CPU) and add it to the sum."""
        self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
        self.n += 1

    @property
    def avg(self):
        """Mean of the reduced values so far (0-dim tensor; NaN before any update)."""
        return self.sum / self.n

# Drive training from the (possibly resumed) start epoch up to args.epochs.
# Each epoch trains, then validates, then saves a checkpoint, in that order.
for epoch in range(resume_from_epoch, args.epochs):
    for run_phase in (train, validate, save_checkpoint):
        run_phase(epoch)

来源:https://www.cnblogs.com/ywheunji/p/12298518.html

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com