PyTorch 多GPU下模型的保存与加载(踩坑笔记)
作者:叶罅 发布时间:2023-07-20 15:39:18
标签:PyTorch,GPU,模型
这几天在一机多卡的环境下,用pytorch训练模型,遇到很多问题。现总结一个实用的做实验方式:
多GPU下训练,创建模型代码通常如下:
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
model = MyModel(args)
if torch.cuda.is_available() and args.use_gpu:
model = torch.nn.DataParallel(model).cuda()
官方建议的模型保存方式,只保存参数:
torch.save(model.module.state_dict(), "model.pkl")
其实,这样很麻烦,我建议直接保存模型(参数+图):
torch.save(model, "model.pkl")
这样做很实用,特别是我们需要反复建模和调试的时候。这种情况下模型的加载很方便,因为模型的图已经和参数保存在一起,我们不需要根据不同的模型设置相应的超参,更换对应的网络结构,如下:
if not (args.pretrained_model_path is None):
print('load model from %s ...' % args.pretrained_model_path)
model = torch.load(args.pretrained_model_path)
print('success!')
但是需要注意,这种方式加载的是多GPU下模型。如果服务器环境变化不大,或者和训练时候是同一个GPU环境,就不会出现问题。
如果系统环境发生了变化,或者,我们只想加载模型参数,亦或是遇到下面的问题:
AttributeError: 'model' object has no attribute 'copy'
或者
AttributeError: 'DataParallel' object has no attribute 'copy'
或者
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found
这时候我们可以用下面的方式载入模型,先建立模型,然后加载参数。
os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda
# 建立模型
model = MyModel(args)
if torch.cuda.is_available() and args.use_gpu:
model = torch.nn.DataParallel(model).cuda()
if not (args.pretrained_model_path is None):
print('load model from %s ...' % args.pretrained_model_path)
# 获得模型参数
model_dict = torch.load(args.pretrained_model_path).module.state_dict()
# 载入参数
model.module.load_state_dict(model_dict)
print('success!')
来源:https://www.cnblogs.com/blog4ljy/p/11711173.html
0
投稿
猜你喜欢
- /** 2 * 检索数组元素(原型扩展或重载) 3 * @param {o} 被检索的元素值 4 * @type int 5 * @retu
- 去听了牛人 dbaron 的一个 Web Page Layout/Display in Mozilla 讲座( via )。讲的东西对我一个
- 前言人脸处理是人工智能中的一个热门话题,人脸处理可以使用计算机视觉算法从人脸中自动提取大量信息,例如身份、意图和情感;而目标跟踪试图估计目标
- 在 Google 搜索结果页面中,将其 Logo 图标右键另存为后可以发现,它并非单纯的
- 项目需求:浏览器中访问django后端某一条url(如:127.0.0.1:8080/get_book/),实时朝数据库中生成一千条数据并将
- 首先看一下目标的验证形态是什么样子的是一种通过验证推理的验证方式,用来防人机破解的确是很有效果,但是,But,这里面已经会有一些破绽,比如:
- 本文实例讲述了Python提示[Errno 32]Broken pipe导致线程crash错误解决方法。分享给大家供大家参考。具体方法如下:
- 一,封装封装是面向对象编程思想的重要特征之一。(一)什么是封装封装是一个抽象对象的过程,它容纳了对象的属性和行为实现细节,并以此对外提供公共
- 首先,我们先来看看,如果是人正常的行为,是如何获取网页内容的。(1)打开浏览器,输入URL,打开源网页(2)选取我们想要的内容,包括标题,作
- 最近开始学机器学习,学习分析垃圾邮件,其中有一部分是要求去除一段字符中的标点符号,查了一下,网上的大多很复杂例如这样import re te
- 由于工作中涉及到生日编辑资料编辑,然后自己改了一下代码:<html><head> <meta charset=
- 我们在flask的学习中,会难免遇到多对多表的查询,今天我也遇到了这个问题。那么我想了好久。也没有想到一个解决的办法,试了几种方法,可能是思
- 本文实例讲述了Python基于Logistic回归建模计算某银行在降低贷款拖欠率的数据。分享给大家供大家参考,具体如下:一、Logistic
- 写程序的人在编写由asp页面生成静态页面html的时候,如果同时生成大量页面,一定遇到过浏览器下方的进度条上显示着3%,6%,10%等缓慢增
- 本文描述通过统计分析出医院信息系统需分区的表,对需分区的表选择分区键,即找出包括在你的分区键中的列(表的属性),对大型数据的管理比较有意义,
- 你是不是在学习python的时候在使用虚拟机系统进行开发,来回切换很是不方便,那么今天给大家推荐一个pycharm强大的功能。接下来我们利用
- 对于软件来说,每一个新版本的推出都应该是一种进步,不可否认,阿里旺旺2008版相较于之前的版本的确是有很多的进步,但进步的同时却也有着比之前
- 大家知道,在js里encodeURIComponent 方法是一个比较常用的编码方法,但因工作需要,在asp里需用到此方法,查了好多资料,没
- 看了下函数本身的docgetattr(object, name[, default]) -> valueGet a named att
- break 语句Python break语句,就像在C语言中,打破了最小封闭for或while循环。break语句用来终止循环语句,即循环条