Pytorch distributed 多卡并行载入模型操作
作者:orientliu96 发布时间:2023-03-01 10:17:09
标签:Pytorch,distributed,多卡,并行
一、Pytorch distributed 多卡并行载入模型
这次来介绍下如何载入模型。
目前没有找到官方的distribute 载入模型的方式,所以采用如下方式。
大部分情况下,我们在测试时不需要多卡并行计算。
所以,我在测试时只使用单卡。
from collections import OrderedDict
device = torch.device("cuda")
model = DGCNN(args).to(device) #自己的模型
state_dict = torch.load(args.model_path) #存放模型的位置
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
# load params
model.load_state_dict (new_state_dict)
二、pytorch DistributedParallel进行单机多卡训练
One_导入库:
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
Two_进程初始化:
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
# 添加必要参数
# local_rank:系统自动赋予的进程编号,可以利用该编号控制打印输出以及设置device
torch.distributed.init_process_group(backend="nccl", init_method='file://shared/sharedfile',
rank=local_rank, world_size=world_size)
# world_size:所创建的进程数,也就是所使用的GPU数量
# (初始化设置详见参考文档)
Three_数据分发:
dataset = datasets.ImageFolder(dataPath)
data_sampler = DistributedSampler(dataset, rank=local_rank, num_replicas=world_size)
# 使用DistributedSampler来为各个进程分发数据,其中num_replicas与world_size保持一致,用于将数据集等分成不重叠的数个子集
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1,drop_last=True, pin_memory=True, sampler=data_sampler)
# 在Dataloader中指定sampler时,其中的shuffle必须为False,而DistributedSampler中的shuffle项默认为True,因此训练过程默认执行shuffle
Four_网络模型:
torch.cuda.set_device(local_rank)
device = torch.device('cuda:'+f'{local_rank}')
# 设置每个进程对应的GPU设备
D = Model()
D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(D).to(device)
# 由于在训练过程中各卡的前向后向传播均独立进行,因此无法进行统一的批归一化,如果想要将各卡的输出统一进行批归一化,需要将模型中的BN转换成SyncBN
D = torch.nn.parallel.DistributedDataParallel(
D, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)
# 如果有forward的返回值如果不在计算loss的计算图里,那么需要find_unused_parameters=True,即返回值不进入backward去算grad,也不需要在不同进程之间进行通信。
Five_迭代:
data_sampler.set_epoch(epoch)
# 每个epoch需要为sampler设置当前epoch
Six_加载:
dist.barrier()
D.load_state_dict(torch.load('D.pth'), map_location=torch.device('cpu'))
dist.barrier()
# 加载模型前后用dist.barrier()来同步不同进程间的快慢
Seven_启动:
CUDA_VISIBLE_DEVICES=1,3 python -m torch.distributed.launch --nproc_per_node=2 train.py --epochs 15000 --batchsize 10 --world_size 2
# 用-m torch.distributed.launch启动,nproc_per_node为所使用的卡数,batchsize设置为每张卡各自的批大小
来源:https://blog.csdn.net/Orientliu96/article/details/104702520


猜你喜欢
- 本文实例讲述了PHP实现判断二叉树是否对称的方法。分享给大家供大家参考,具体如下:问题请实现一个函数,用来判断一颗二叉树是不是对称的。注意,
- 用法:DataFrame.drop(labels=None,axis=0, index=None, columns=None, inplac
- IE6这个东东在前端开发者的眼中恐怕都是一个恶梦之地,我说它万恶想来没人反对吧。依据现在卡当网的访问统计数据来看,从IE6来的访问量还是占到
- 读取docx文档使用的包是python-docx1. 安装python-docx包sudo pip install python-docx2
- 出图是项目里常见的任务,有的项目甚至会要上百张图片,所以批量出土工具很有必要。arcpy.mapping就是ArcGIS里的出图模块,能快速
- 关于如何区分艺术和设计的话题总是玄之又玄,并因此引发的争论也有很长一段时间。艺术家和设计师都基于相同的知识基础来创作视觉作品,但他们创作的理
- 很久没有发表文章了,最近一直在研究产品设计标准的问题,之前有发过一篇关于 Axure的教程 ,相信很多人已经学会如何使用,这次我给大家介绍一
- itchat是一个开源的微信个人号接口,可以使用该库进行微信网页版中的所有操作,比如:所有好友、添加好友、拉好友群聊、微信机器人等等。详细用
- “用户体验”作为舶来品在国内风靡已经有几个年头了,而且从目前情况来看仍旧会继续风靡一段时间。当某产品发布会上,发言人张口闭口就
- python是支持多线程的,主要是通过thread和threading这两个模块来实现的。thread模块是比较底层的模块,th
- 本文实例讲述了Python列表list内建函数用法。分享给大家供大家参考,具体如下:#coding=utf8'''&
- class Helper_Page{ /** 总信息数 */ var $infoCount; /** 总页数 */ var $pageCou
- 背景在实际项目实施中,会编写很多在服务器执行的作业脚本。程序中凡是涉及到数据库链接、操作系统用户链接、IP地址、主机名称的内容都是敏感信息。
- 这篇文章主要介绍了Python实现图片批量加入水印代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要
- 现在有一id=test的下拉框,怎么拿到选中的那个值呢? 分别使用javascript原生的方法和jquery方法 <select i
- 1。如果客户端和服务器端的连接需要跨越并通过不可信任的网络,那么就需要使用SSH隧道来加密该连接的通信。 2。用set password语句
- 关于建立索引的几个准则:1、合理的建立索引能够加速数据读取效率,不合理的建立索引反而会拖慢数据库的响应速度。2、索引越多,更新数据的速度越慢
- 问题:将文件夹a下任意命名的10个文件修改为如下图所示文件?代码:#coding:utf-8import ospath = "./
- 前言许多 Web 应用依赖大量的 I/O (输入/输出) 操作,比如从网站上下载图片、视频等内容;进行网络聊天或者针对后台数据库进行多次查询
- 本文介绍基于Python语言,将一个Excel表格文件中的数据导入到Python中,并将其通过字典格式来存储的方法~ &a