pytorch中的model=model.to(device)使用说明
作者:Wanderer001 发布时间:2023-02-23 15:07:48
这代表将模型加载到指定设备上。
其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")则代表的使用GPU。
当我们指定了设备之后,就需要将模型加载到相应设备中,此时需要使用model=model.to(device),将模型加载到相应的设备中。
将由GPU保存的模型加载到CPU上。
将torch.load()函数中的map_location参数设置为torch.device('cpu')
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
将由GPU保存的模型加载到GPU上。确保对输入的tensors调用input = input.to(device)方法。
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
将由CPU保存的模型加载到GPU上。
确保对输入的tensors调用input = input.to(device)方法。map_location是将模型加载到GPU上,model.to(torch.device('cuda'))是将模型参数加载为CUDA的tensor。
最后保证使用.to(torch.device('cuda'))方法将需要使用的参数放入CUDA。
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
补充:pytorch中model.to(device)和map_location=device的区别
一、简介
在已训练并保存在CPU上的GPU上加载模型时,加载模型时经常由于训练和保存模型时设备不同出现读取模型时出现错误,在对跨设备的模型读取时候涉及到两个参数的使用,分别是model.to(device)和map_location=devicel两个参数,简介一下两者的不同。
将map_location函数中的参数设置 torch.load()为 cuda:device_id。这会将模型加载到给定的GPU设备。
调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量,无论在cpu上训练还是gpu上训练,保存的模型参数都是参数张量不是cuda张量,因此,cpu设备上不需要使用torch.to(torch.device("cpu"))。
二、实例
了解了两者代表的意义,以下介绍两者的使用。
1、保存在GPU上,在CPU上加载
保存:
torch.save(model.state_dict(), PATH)
加载:
device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
解释:
在使用GPU训练的CPU上加载模型时,请传递 torch.device('cpu')给map_location函数中的 torch.load()参数,使用map_location参数将张量下面的存储器动态地重新映射到CPU设备 。
2、保存在GPU上,在GPU上加载
保存:
torch.save(model.state_dict(), PATH)
加载:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
解释:
在GPU上训练并保存在GPU上的模型时,只需将初始化model模型转换为CUDA优化模型即可model.to(torch.device('cuda'))。
此外,请务必.to(torch.device('cuda'))在所有模型输入上使用该 功能来准备模型的数据。
请注意,调用my_tensor.to(device) 返回my_tensorGPU上的新副本。
它不会覆盖 my_tensor。
因此,请记住手动覆盖张量: my_tensor = my_tensor.to(torch.device('cuda'))
3、保存在CPU,在GPU上加载
保存:
torch.save(model.state_dict(), PATH)
加载:
device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
解释:
在已训练并保存在CPU上的GPU上加载模型时,请将map_location函数中的参数设置 torch.load()为 cuda:device_id。
这会将模型加载到给定的GPU设备。
接下来,请务必调用model.to(torch.device('cuda'))将模型的参数张量转换为CUDA张量。
最后,确保.to(torch.device('cuda'))在所有模型输入上使用该 函数来为CUDA优化模型准备数据。
请注意,调用 my_tensor.to(device)返回my_tensorGPU上的新副本。
它不会覆盖my_tensor。
因此,请记住手动覆盖张量:my_tensor = my_tensor.to(torch.device('cuda'))
来源:https://blog.csdn.net/weixin_36670529/article/details/104367696
猜你喜欢
- 在我们建立一个数据库时,并且想将分散在各处的不同类型的数据库分类汇总在这个新建的数据库中时,尤其是在进行数据检验、净化和转换时,将会面临很大
- udf_WeekDayName 代码如下:CREATE FUNCTION [dbo].[udf_WeekDayName] ( ) RETUR
- 前言项目中要实现多选,就想到用插件,选择了bootstrap-select。附上官网api链接,http://silviomoreto.gi
- 阅读上一篇:垂直栅格与渐进式行距(上) 新问题来也匆匆,去也“冲冲”。距上次发布垂直栅格与渐进式行距(上)发布,已经不知不觉过去了
- 1:readline()file = open("sample.txt") while 1: line =
- 说起来惭愧,总是犯一些小错误,纠结半天,这不应为一个分号的玩意折腾了好半天! 错误时在执行SQL语句的时候发出的,信息如下: Java代码
- 这是一套适用于JavaScript程序的编码规范。它基于Sun的Java程序编码规范。但进行了大幅度的修改, 因为JavaScript不是J
- 在支持FSO的情况下,可以显示本站内的所有ASP页面的代码适用于代码演示时在效果页面上直接显示该页面的代码而不用再对代码制作专门的页面使用方
- 函数重载的替代方法-伪重载,下面看一个具体的实例代码。<? php//函数重载的替代方法-伪重载////确实,在PHP中没有函数重载这
- 前言我们经常需要将大量数据保存起来以备后续使用,数据库是一个很好的解决方案。在众多数据库中,MySQL数据库算是入门比较简单、语法比较简单,
- 两个函数的原型为:np.identity(n, dtype=None)np.eye(N, M=None, k=0, dtype=<ty
- Python 格式化输出字符串(输出字符串+数字的几种方法)1. 介绍字符串格式化输出是python非常重要的基础语法。格式化输出:内容按照
- 一定要注重代码规范,按照平时的代码管理,可以将Python代码规范检测分为两种:静态本地检测:可以借助静态检查工具,比如:Flake8,Py
- 让Python提速超过40倍的神器:Cython人工智能最火的语言,自然是被誉为迄今为止最容易使用的代码之一的Python。Python代码
- 1. 基本环境安装 anaconda 环境, 由于国内登陆不了他的官网 https://www.continuum.io/downloads
- 推荐阅读:使用python检测主机存活端口及检查存活主机下面给大家分享使用python语言实现获取主机名根据端口杀死进程代码。ip=os.p
- HTML文件其实就是由一组尖括号构成的标签组织起来的,每一对尖括号形式一个标签,标签之间存在上下关系,形成标签树;XPath 使用路径表达式
- 实现思路是用深度遍历,对图片进行二值化处理,先找到一个黑色像素,然后对这个像素的周围8个像素进行判断,如果没有访问过,就保存起来,然后最后这
- 1.效果图:2.代码# 作用域 是 对象生效的区域(对象能被使用的区域)# 全局作用域在任意位置可生效# 局部作用域在函数内生效c = 20
- 前言本文将教你如何使用YOLOV3对象检测器、OpenCV和Python实现对图像和视频流的检测。用到的文件有yolov3.weights、