Pytorch中实现CPU和GPU之间的切换的两种方法
作者:三个刺客 发布时间:2021-08-21 07:24:18
标签:Pytorch,CPU,GPU,切换
如何在pytorch中指定CPU和GPU进行训练,以及cpu和gpu之间切换
由CPU切换到GPU,要修改的几个地方:
网络模型、损失函数、数据(输入,标注)
# 创建网络模型
tudui = Tudui()
if torch.cuda.is_available():
tudui = tudui.cuda()
# 损失函数
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
loss_fn = loss_fn.cuda()
# 数据输入 包括训练和测试的代码,二者都需要添加此代码
if torch.cuda.is_available():
imgs = imgs.cuda()
targets = targets.cuda()
方法一:.to(device)
1.不知道电脑GPU可不可用时:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu' )
a.to(device)
第一行代码的意思是判断电脑GPU可不可用,如果可用的话device就采用cuda()即调用GPU,不可用的话就采用cpu()即调用CPU。
第二行代码的意思就是把变量放到对应的device上(当然如果你用的是CPU的话就不用这一步了,因为变量默认是存在CPU上的,调用GPU的话要先把变量放到GPU上跑,跑完之后再调回CPU上)
2.指定GPU时
# 定义训练的设备
device = torch.device("cuda:0")
# 网络模型创建
tudui = Tudui()
tudui = tudui.to(device)
# 损失函数
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)
# 训练步骤开始
tudui.train()
for data in train_dataloader:
imgs, targets=data
imgs = imgs.to(device)
targets = targets.to(device)
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)
# 测试步骤开始
tudui.eval()
total_test_loss = 0
total_accuracy = 0
with torch.no_grad():
for data in test_dataloader:
imgs, targets=data
imgs = imgs.to(device)
targets = targets.to(device)
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)
total_test_loss = total_test_loss + loss.item()
accuracy = (outputs.argmax(1)==targets).sum()
total_accuracy = total_accuracy + accuracy
3.指定cpu时:
device = torch.device('cpu')
方法二:
1、需要修改的
# 三种常见的写法
device = torch.device('cuda')
device = torch.device('cuda: 0')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
2、代码
# 创建模型
tudui = Tudui()
if torch.cuda.is_available():
tudui = tudui.cuda()
# 损失函数
loss_fn = nn.CrossEntropyLoss()
if torch.cuda.is_available():
loss_fn = loss_fn.cuda()
# 训练步骤开始
tudui.train()
for data in train_dataloader:
imgs, targets=data
if torch.cuda.is_available():
imgs = imgs.cuda()
targets = targets.cuda()
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)
# 测试步骤开始
tudui.eval()
total_test_loss = 0
total_accuracy = 0
with torch.no_grad():
for data in test_dataloader:
imgs, targets=data
if torch.cuda.is_available():
imgs = imgs.cuda()
targets = targets.cuda()
outputs = tudui(imgs)
loss = loss_fn(outputs, targets)
total_test_loss = total_test_loss + loss.item()
accuracy = (outputs.argmax(1)==targets).sum()
total_accuracy = total_accuracy + accuracy
总结:
推荐方法一,如果自己电脑是只有CPU,可以推荐使用云端服务器,比如PaddlePaddle,Google colab,这些服务器由每周免费八个小时的使用时间,可供我们基本的需求。
来源:https://blog.csdn.net/mxh3600/article/details/124460988


猜你喜欢
- asp+access用户登录代码,loginnew.asp网面包含了登录框及验证用户的代码an.mdb数据库名fd表名y_username用
- 本文实例为大家分享了python3实现qq邮箱登陆并发送邮件功能的具体代码,供大家参考,具体内容如下基于selenium,使用chrome浏
- PDOStatement::bindParamPDOStatement::bindParam — 绑定一个参数到指定的变量名(PHP 5 &
- Rand()函数是系统自带的获取随机数的函数,可以直接运行select rand() 获取0~1之间的float型的数字。如果想要获取0~1
- 这篇文章所说的视觉元素是指:在一个网站中除去内容(文本、图片、视频、音频等)之外的一些元素。比如图标,背景色,以及背景图案。视觉元素的设计是
- 有些 MySQL 数据表中可能存在重复的记录,有些情况我们允许重复数据的存在,但有时候我们也需要删除这些重复的数据。本章节我们将为大家介绍如
- 运行环境:Windows 8.1Python:2.7.6在安装的时候,我使用的pip来进行安装,命令如下:pip install beaut
- 该章节我们来学习一下在 Python 中去创建并使用多进程的方法,通过学习该章节,我们将可以通过创建多个进程来帮助我们提高脚本执行的效率。可
- 前言昨天,因为项目需求要添加表的更新接口,来存储预测模型训练的数据,所以自己写了一段代码实现了该功能,在开始之前,给大家分享python 操
- 很简单的教程,献给喜欢SEO的朋友们。把article.asp?logID=26 替换成article.asp?/a
- 发现问题Python中的urllib模块用来处理url相关的操作,unquote方法对应javascript中的urldecode方法,它对
- 首先创建公用js在static中创建js—>utils.jsutils.js内容如下:export default { install
- 主库执行CREATE DATABASE test CHARACTER SET utf8 COLLATE utf8_general_ci;us
- 有这样一个文本文件,内容有多行如下,数量不定。Lif(__amscript_cd("www.jb51.net")){__
- 这篇文章主要介绍了python代码如何实现余弦相似性计算,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的
- 一、操作步骤导入:import flask,json实例化:api = flask.Flask(name)定义接口访问路径及访问方式:@ap
- 什么是锁锁的本质,就是一种资源,是由操作系统维护的一种专门用于同步的资源比如说互斥锁,说白了就是一种互斥的资源。只能有一个进程(线程)占有。
- pyecharts 是一个用于生成 Echarts 图表的类库。Echarts 是百度开源的一个数据可视化 JS 库。用 Echarts 生
- 导语电脑桌面文件太多查找起来比较花费时间,并且凌乱的电脑桌面也会影响工作心情,于是利用python根据时间自动建立当日文件夹,这样就可以把桌
- 最近在作图时需要将输出的图片紧密排布,还要去掉坐标轴,同时设置输出图片大小。要让程序自动将图表保存到文件中,代码为:plt.savefig(