pytorch中Tensor.to(device)和model.to(device)的区别及说明
作者:康康同学97 发布时间:2021-10-20 05:26:06
标签:pytorch,Tensor.to,model.to,device
Tensor.to(device)和model.to(device)的区别
区别所在
使用GPU训练的时候,需要将Module对象和Tensor类型的数据送入到device。通常会使用 to.(device)。但是需要注意的是:
对于Tensor类型的数据,使用to.(device) 之后,需要接收返回值,返回值才是正确设置了device的Tensor。
对于Module对象,只用调用to.(device) 就可以将模型设置为指定的device。不必接收返回值。
来自pytorch官方文档的说明:
Tensor.to(device)
Module.to(device)
举例
# Module对象设置device的写法
model.to(device)
# Tensor类型的数据设置 device 的写法。
samples = samples.to(device)
pytorch学习笔记--to(device)用法
在学习深度学习的时候,我们写代码经常会见到类似的代码:
img = img.to(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
model = models.vgg16_bn(pretrained=True).to(device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
也可以先定义device:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
img = img.to(device)
这段代码到底有什么用呢?
这段代码的意思就是将所有最开始读取数据时的tensor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行。
为什么要在GPU上做运算呢?
首先,在做高维特征运算的时候,采用GPU无疑是比用CPU效率更高,如果两个数据中一个加了.cuda()或者.to(device),而另外一个没有加,就会造成类型不匹配而报错。
tensor和numpy都是矩阵,前者能在GPU上运行,后者只能在CPU运行,所以要注意数据类型的转换。
.cuda()和.to(device)的效果一样吗?为什么后者更好?
两个方法都可以达到同样的效果,在pytorch中,即使是有GPU的机器,它也不会自动使用GPU,而是需要在程序中显示指定。调用model.cuda(),可以将模型加载到GPU上去。这种方法不被提倡,而建议使用model.to(device)的方式,这样可以显示指定需要使用的计算资源,特别是有多个GPU的情况下。
如果你有多个GPU
那么可以参考以下代码:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model()
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model,device_ids=[0,1,2])
model.to(device)
来源:https://blog.csdn.net/qq_38436266/article/details/117784900


猜你喜欢
- 一、什么是数字识别? 所谓的数字识别,就是使用算法自动识别出图片中的数字。具体的效果如下图所示:上图展示了算法的处理效果,算法能够自动的识
- 说到JavaScript中声明变量的几种方法也就是var、let、const了,let和const是es6中新增的命令。那么它们之间有什么区
- 以下方案皆为引用,仅供参考。方案一:1.先声明一下,这种解决方法适用于任何版本的永久破解启动不了的情况(包括:2019版本的)2.下面直接切
- 英文的文档在这里,详细全面,本文仅为自己的学习笔记,只是试图通过转述加深自己的学习,不详细不全面。由于浏览器之间的差异,所以在JS中监听事件
- 设计一个算法,将URL转换成5部分,分别是:schema、netloc、path、query_params、fragment。问题URL的中
- 本文实例讲述了Python使用当前时间、随机数产生一个唯一数字的方法。分享给大家供大家参考,具体如下:Python生成当前时间很简单,比Ja
- 本文主要关于python的正则表达式的符号与方法。findall: 找寻所有匹配,返回所有组合的列表search: 找寻第一个匹配并返回su
- 类似Java打包操作,若不想让人看到Python程序内部逻辑,也可将其转换为exe可执行文件首先自己写一个Python程序,如下:print
- Py2 时代,访问 MySQL 数据库的模块除了 PyMySQL 和 MySQL-python 之外,还有以速度见长的 Umysql,以及非
- 对于python开发用户而言,经常需要安装一些python的第三方库,但是第三方库的安装经常出错,以下给大家介绍一下python安装第三方库
- WaitGroup概念Go标准库提供了WaitGroup原语, 可以用它来等待一批 Goroutine 结束底层数据结构// A WaitG
- 本文实例讲述了python实现统计代码行数的方法。分享给大家供大家参考。具体实现方法如下:'''Author: li
- 最近入了一块树莓派,想让其实现摄像头的调用,因此写下此博客备忘一、树莓派网络的配置首先,对树莓派进行网络配置,否则就无法进行软件的安装我们知
- Go语言在进行文件操作的时候,可以有多种方法。最常见的比如直接对文件本身进行Read和Write; 除此之外,还可以使用bufio库的流式处
- 每个矿工将从先前创建的交易池中获取交易.要跟踪已挖掘的消息数量,我们必须创建一个全局变量 :last_transaction_index =
- 为什么要用python调用matlab?我自己的有些数据结构涉及到hash查找,在python中key是tuple形式,在matlab中支持
- 我们最近的项目中需要使用谷歌机器人验证,这个最主要的就是要有vpn,还需要有公司申请的google账号(自己申请的没用)用于商用的,利用这个
- 写了个多层感知器,用bp梯度下降更新,拟合正弦曲线,效果凑合。# -*- coding: utf-8 -*-import numpy as
- 本文实例讲述了mysql中left join设置条件在on与where时的用法区别。分享给大家供大家参考,具体如下:一、首先我们准备两张表来
- 单页应用程序是或多或少复杂的应用程序,它加载一个单一的HTML页面。每当用户与它们互动时,它们就会使用JavaScript动态地更新其内容。