Pytorch 多块GPU的使用详解
作者:R.X.Geng 发布时间:2021-01-21 09:19:09
标签:Pytorch,GPU
注:本文针对单个服务器上多块GPU的使用,不是多服务器多GPU的使用。
在一些实验中,由于Batch_size的限制或者希望提高训练速度等原因,我们需要使用多块GPU。本文针对Pytorch中多块GPU的使用进行说明。
1. 设置需要使用的GPU编号
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,4"
ids = [0,1]
比如我们需要使用第0和第4块GPU,只用上述三行代码即可。
其中第二行指程序只能看到第1块和第4块GPU;
第三行的0即为第二行中编号为0的GPU;1即为编号为4的GPU。
2.更改网络,可以理解为将网络放入GPU
class CNN(nn.Module):
def __init__(self):
super(CNN,self).__init__()
self.conv1 = nn.Sequential(
......
)
......
self.out = nn.Linear(Liner_input,2)
......
def forward(self,x):
x = self.conv1(x)
......
output = self.out(x)
return output,x
cnn = CNN()
# 更改,.cuda()表示将本存储到CPU的网络及其参数存储到GPU!
cnn.cuda()
3. 更改输出数据(如向量/矩阵/张量):
for epoch in range(EPOCH):
epoch_loss = 0.
for i, data in enumerate(train_loader2):
image = data['image'] # data是字典,我们需要改的是其中的image
#############更改!!!##################
image = Variable(image).float().cuda()
############################################
label = inputs['label']
#############更改!!!##################
label = Variable(label).type(torch.LongTensor).cuda()
############################################
label = label.resize(BATCH_SIZE)
output = cnn(image)[0]
loss = loss_func(output, label) # cross entropy loss
optimizer.zero_grad() # clear gradients for this training step
loss.backward() # backpropagation, compute gradients
optimizer.step()
... ...
4. 更改其他CPU与GPU冲突的地方
有些函数必要在GPU上完成,例如将Tensor转换为Numpy,就要使用data.cpu().numpy(),其中data是GPU上的Tensor。
若直接使用data.numpy()则会报错。除此之外,plot等也需要在CPU中完成。如果不是很清楚哪里要改的话可以先不改,等到程序报错了,再哪里错了改哪里,效率会更高。例如:
... ...
#################################################
pred_y = torch.max(test_train_output, 1)[1].data.cpu().numpy()
accuracy = float((pred_y == label.cpu().numpy()).astype(int).sum()) / float(len(label.cpu().numpy()))
假如不加.cpu()便会报错,此时再改即可。
5. 更改前向传播函数,从而使用多块GPU
以VGG为例:
class VGG(nn.Module):
def __init__(self, features, num_classes=2, init_weights=True):
super(VGG, self).__init__()
... ...
def forward(self, x):
#x = self.features(x)
#################Multi GPUS#############################
x = nn.parallel.data_parallel(self.features,x,ids)
x = x.view(x.size(0), -1)
# x = self.classifier(x)
x = nn.parallel.data_parallel(self.classifier,x,ids)
return x
... ...
然后就可以看运行结果啦,nvidia-smi查看GPU使用情况:
可以看到0和4都被使用啦
来源:https://blog.csdn.net/qazwsxrx/article/details/89672578
0
投稿
猜你喜欢
- 前言:多线程简单理解就是:一个CPU,也就是单核,将时间切成一片一片的,CPU轮转着去处理一件一件的事情,到了规定的时间片就处理下一件事情。
- 一、开发环境集成开发工具:jupyter notebook 6.5.2集成开发环境:Python 3.10.6第三方库:tensorflow
- 自从jQuery搞出特性侦探这东东,西方从来没有如此狂热研究浏览器。在以前javascript与DOM遍地是bug,美工主宰前端的年代,人们
- 本文实例讲述了vue动态组件和v-once指令。分享给大家供大家参考,具体如下:点击按钮时,自动切换两个组件<component :i
- 如果是django2.0 必须下载xadmin2.0 不然很多地方不兼容xadmin2.0下载地址https://github.com/ss
- Python 的元组与列表类似,不同之处在于元组的元素不能修改。元组使用小括号,列表使用方括号。元组创建很简单,只需要在括号中添加元素,并使
- 上一课:ACCESS入门教程:初识Access 2000窗口接口简介 通过上一课的学习,你是否感觉Access的窗口和接口还有点搞不清楚,对
- 前言简单的爬虫只有一个进程、一个线程,因此称为单线程爬虫。单线程爬虫每次只访问一个页面,不能充分利用计算机的网络带宽。一个页面最多也就几百K
- v-model指令 所谓的“指令”其实就是扩展了HTML标签功能(属性)。先来一个组件,不用vue-model,正常父子通信<!--
- 最近在看《Effective Python》,里面提到判断字符串或者集合是否为空的原则,原文如下:Don't check for e
- 本文将展示一个开源JavaScript库,该脚本库给AJAX应用程序带来了书签和后退按钮支持。在学习完这个教程后,开发人员将能够获得对一个A
- element-ui自带的图标库还是不够全,还是需要需要引入第三方icon,自己在用的时候一直有些问题,参考了些教程,详细地记录补充下对于我
- 在 Go 语言中切片是使用非常频繁的一种聚合类型,它代表变长的序列,底层引用一个数组对象。一个切片由三个部分构成:指针、长度和容量。指针指向
- 本文实例讲述了Python基于lxml模块解析html获取页面内所有叶子节点xpath路径功能。分享给大家供大家参考,具体如下:因为需要使用
- JavaScript是一门OOP,而有些人说,JavaScript是基于对象的。1) 如何创建对象:1. 使用constructor,例如:
- 前言如果你的 Python 程序程序有大量的 import,而且启动非常慢,那么你应该尝试懒导入,本文分享一种实现惰性导入的一种方法。虽然P
- 因为我们现在的前端框架做性能优化,为了找到各个组件及框架的具体解析耗时,需要在框架中嵌入一个耗时测试工具,性能测试跟不同的计算机硬件配置有很
- 本文实例为大家分享了js实现选项卡效果的具体代码,供大家参考,具体内容如下<!DOCTYPE html><html>
- 你知道吗?实际上Python早在20世纪90年代初就已经诞生,可是火爆时间却并不长,就小编本人来说,也是前几年才了解到它。据统计,目前Pyt
- 对于每个程序开发者来说,调试几乎是必备技能。代码写到一半卡住了,不知道这个函数执行完的返回结果是怎样的?调试一下看看代码运行到一半报错了,什