将pytorch的网络等转移到cuda
作者:aleien1 发布时间:2023-08-10 08:33:46
标签:pytorch,网络,cuda
神经网络一般用GPU来跑,我们的神经网络框架一般也都安装的GPU版本,本文就简单记录一下GPU使用的编写。
GPU的设置不在model,而是在Train的初始化上。
第一步是查看是否可以使用GPU
self.GPU_IN_USE = torch.cuda.is_available()
就是返回这个可不可以用GPU的函数,当你的pytorch是cpu版本的时候,他就会返回False。
然后是:
self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu')
torch.device是代表将torch.tensor分配到哪个设备的函数
接着是,我看到了一篇文章,原来就是将网络啊、数据啊、随机种子啊、损失函数啊、等等等等直接转移到CUDA上就好了!
于是下面就好理解多了:
转移模型:
self.model = Net(num_channels=1, upscale_factor=self.upscale_factor, base_channel=64, num_residuals=4).to(self.device)
设置cuda的随机种子:
torch.cuda.manual_seed(self.seed)
转移损失函数:
self.criterion.cuda()
转移数据:
data, target = data.to(self.device), target.to(self.device)
pytorch 网络定义参数的后面无法加.cuda()
pytorch定义网络__init__()的时候,参数不能加“cuda()", 不然参数不包含在state_dict()中,比如下面这种写法是错误的
self.W1 = nn.Parameter(torch.FloatTensor(3,3), requires_grad=True).cuda()
应该去掉".cuda()"
self.W1 = nn.Parameter(torch.FloatTensor(3,3), requires_grad=True)
来源:https://blog.csdn.net/weixin_42128941/article/details/103048866


猜你喜欢
- 一. np.dot()1.同线性代数中矩阵乘法的定义。np.dot(A, B)表示:对二维矩阵,计算真正意义上的矩阵乘积。对于一
- 目录range函数的使用第一种创建方式第二种创建方式第三种创建方式判断指定的数有没有在当前序列中循环结构总结range函数的使用作为循环遍历
- 本文实例讲述了python使用装饰器和线程限制函数执行时间的方法。分享给大家供大家参考。具体分析如下:很多时候函数内部包含了一些不可预知的事
- 本文实例讲述了Thinkphp 框架基础之源码获取、环境要求与目录结构。分享给大家供大家参考,具体如下:获取ThinkPHP获取ThinkP
- 安装好所需要的插件和包:python、django、pip等版本如下:采用Django REST框架3.01、在python文件夹下D:\p
- MySQL是一个开源的关系型数据库管理系统,支持多种操作语言,其中最基础、最常用的命令之一就是SELECT语句。在本篇文章中,这里将详细介绍
- 我们将研究一种判别式分类方法,其中直接学习评估 g(x)所需的 w 参数。我们将使用感知器学习算法。感知器学习算法很容易实现,但为了节省时间
- innodb_flush_method的几个典型取值fsync: InnoDB uses the fsync() system call t
- 前言最近工作中遇到一个需求,是根据用户连续记录天数来计算的,求出用户在一段时间内最大的连续记录时间,例如在 2016-01-01 和 201
- Doug Bowman,Google的Visual Design Lead离职了,一封带有感 * 彩的离职信惹发了大家不少的讨论。甚至还有人用
- 使用Python内置函数:bin()、oct()、int()、hex()可实现进制转换。 先看Python官方文档中对这几个内置函数的描述:
- 页面上有些重要内容需要提醒客户,可采用的方法有很多。提醒用户关注某一区域(div),可以给该div加上边框闪烁的效果,达到吸引用户眼球的效果
- 一、python-yml文件读写使用库 :import yaml安装:pip install pyyaml示例:文件config2.ymlg
- 本文实例讲述了MySQL触发器简单用法。分享给大家供大家参考,具体如下:mysql触发器和存储过程一样,是嵌入到mysql的一段程序,触发器
- python replace函数替换无效问题str = "hello,china!"str.replace("
- 如题,首先当然是要打开京东的手机页面因为要获取不同页面的所有手机图片,所以我们要跳转到不同页面观察页面地址的规律,这里观察第二页页面由观察可
- 一、前言刚刚学了一些python文件读写的内容,先跑过来整活了。顺便复习一下之前学的东西。import timedoc_local='
- 相信用过thinkphp的用户都知道thinkphp的模型可以完成很多辅助功能,比如自动验证、自动完成等,今天在开发中遇到自动完成中需要获取
- .Net新手通常容易把属性(Property)跟特性(Attribute)搞混,其实这是两种不同的东西属性指的类中封装的数据字段;而特性是对
- 多线程锁lock=threading.Lock()使用疑问多线程任务是同时执行的,如果我们需要先执行线程a,再执行线程b,需要怎么办呢?解决