Pytorch 多块GPU的使用详解
作者:R.X.Geng 发布时间:2021-01-21 09:19:09
标签:Pytorch,GPU
注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。
在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。
1. 设置需要使用的GPU编号
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]
比如我们需要使用第0和第4块GPU,只用上述三行代码即可。
其中第二行指程序只能看到第1块和第4块GPU;
第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。
2.更改网络,可以理解为将网络放入GPU
class CNN(nn.Module):
def __init__(self):
super(CNN,self).__init__()
self.conv1 = nn.Sequential(
......
)
......
self.out = nn.Linear(Liner_input,2)
......
def forward(self,x):
x = self.conv1(x)
......
output = self.out(x)
return output,x
cnn = CNN()
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()
3. 更改输出数据(如向量/矩阵/张量):
for epoch in range(EPOCH):
epoch_loss = 0.
for i, data in enumerate(train_loader2):
image = data['image'] # data是字典,我们需要改的是其中的image
#############更改!!!##################
image = Variable(image).float().cuda()
############################################
label = inputs['label']
#############更改!!!##################
label = Variable(label).type(torch.LongTensor).cuda()
############################################
label = label.resize(BATCH_SIZE)
output = cnn(image)[0]
loss = loss_func(output, label) # cross entropy loss
optimizer.zero_grad() # clear gradients for this training step
loss.backward() # backpropagation, compute gradients
optimizer.step()
... ...
4. 更改其他CPU与GPU冲突的地方
有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。
若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:
... ...
#################################################
pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))
假如不加.cpu()便会报错,此时再改即可。
5. 更改前向传播函数,从而使用多块GPU
以VGG为例:
class VGG(nn.Module):
def __init__(self, features, num_classes=2, init_weights=True):
super(VGG, self).__init__()
... ...
def forward(self, x):
#x = self.features(x)
#################Multi GPUS#############################
x = nn.parallel.data_parallel(self.features,x,ids)
x = x.view(x.size(0), -1)
# x = self.classifier(x)
x = nn.parallel.data_parallel(self.classifier,x,ids)
return x
... ...
然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:
可以看到0和4都被使用啦
来源:https://blog.csdn.net/qazwsxrx/article/details/89672578
0
投稿
猜你喜欢
- 下面,我们就从当前时间来取得随机数,调用的时候用包含文件就可以了:<!--#INCLUDE VIRTUAL="/q
- 废话不多说,我就直接上代码让大家看看吧!#!/usr/bin/env python# -*- coding: utf-8 -*-# @Fil
- Dim iSet conn=Server.CreateObject("ADODB.Connecti
- 想必Java 的开发者没有不知道或者没用过 jps 这个命令的,这个命令是用来在主机上查看有哪些 Java 程序在运行的。我刚用 Go 语言
- 列表转化为字符串如下所示:>>> list1=['ak','uk',4]>>&
- 正则表达式,贪婪匹配与非贪婪匹配正则表达式前戏以某app注册页面获取手机号为例. 其有很多校验规则: 国内手机号必须是11位,纯数字,是常规
- 1.find函数find() 方法检测字符串中是否包含子字符串 str ,如果指定 beg(开始) 和 end(结束) 范围,则检查是否包含
- 实战场景本篇博客为大家介绍一款新的自动化测试工具,效果类似 selenium,但是这个模块年轻。模块名称为 playwright-pytho
- 前言Tkinter是python内置的标准GUI库,基于Tkinter实现了简易人员管理系统,所用数据库为Mongodb代码时间宝贵!直接上
- 前言最近看到一个有意思的机器学习项目——GFPGAN,他可以将模糊的人脸照片恢复清晰。开源项目的Github地址:https://githu
- vi /etc/freetds/freetds.conf [global]# TDS protocol versiontds version
- 一、开始之前必须安装itchat库pip install itchat(使用pip必须在电脑的环境变量中添加Python的路径)或 cond
- 前言上篇文章讲的进阶一些的PHP特性不知道大家吸收的怎么样了,今天作为本PHP特性函数的最后一篇,我也会重点介绍一些有趣的PHP特性以及利用
- 第一种import win32clipboardimport time#速度快 容易出错class niubi(): def lihai(s
- 我就废话不多说了,大家还是直接看代码吧~#!/usr/bin/env python# encoding: utf-8''
- 不得不说python的上手非常简单。在网上找了一下,大都是python2的帖子,于是随手写了个python3的。代码非常简单就不解释了,直接
- 实例如下所示:tcode={}transcode={}def GetTcode():#从文本中获取英文对应的故障码,并保存在tcode字典(
- 有时候需要比较大的计算量,这个时候Python的效率就很让人捉急了,此时可以考虑使用numba 进行加速,效果提升明显~(numba 安装貌
- python提取照片坐标信息的代码如下所示:from PIL import Imagefrom PIL.ExifTags import TA
- 因为最近公司有python项目维护,所以把python的基础入门的书整理一遍,因为有些忘记了,同时在看<<python编程>