PyTorch深度学习模型的保存和加载流程详解
作者:软耳朵DONG 发布时间:2023-07-10 04:58:33
标签:PyTorch,模型的保存,模型的加载
一、模型参数的保存和加载
torch.save(module.state_dict(), path)
:使用module.state_dict()
函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path
所指定的文件存放路径(常用文件格式为.pt
、.pth
或.pkl
)。torch.nn.Module.load_state_dict(state_dict)
:从state_dict
中加载参数和缓冲区到Module
及其子类中 。torch.nn.Module.state_dict()
函数返回python
中的一个OrderedDict
类型字典对象,该对象将每一层与它的对应参数和缓冲区建立映射关系,字典的键值是参数或缓冲区的名称。只有那些参数可以训练的层才会被保存到OrderedDict
中,例如:卷积层、线性层等。Python
中的字典类以“键:值
”方式存取数据,OrderedDict
是它的一个子类,实现了对字典对象中元素的排序(OrderedDict
根据放入元素的先后顺序进行排序)。由于进行了排序,所以顺序不同的两个OrderedDict
字典对象会被当做是两个不同的对象。示例:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 获取state_dict
state_dict = net.state_dict()
# 字典的遍历默认是遍历key,所以param_tensor实际上是键值
for param_tensor in state_dict:
print(param_tensor,':\n',state_dict[param_tensor])
# 保存模型参数
torch.save(state_dict,"net_params.pth")
# 通过加载state_dict获取模型参数
net.load_state_dict(state_dict)
输出:
二、完整模型的保存和加载
torch.save(module, path)
:将训练完的整个网络模型module
保存到path
所指定的文件存放路径(常用文件格式为.pt
或.pth
)。torch.load(path)
:加载保存到path
中的整个神经网络模型。示例:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 保存整个网络
torch.save(net,"net.pth")
# 加载网络
net = torch.load("net.pth")
来源:https://blog.csdn.net/m0_52650517/article/details/120836999


猜你喜欢
- 前言采集教务系统成绩单是一个非常有意义的项目。在现代教育中,教务系统已经成为了学校管理和教学工作的重要组成部分。然而,由于各种原因,教务系统
- 我们在使用ASP 内置的ADO组件进行数据库编程时,通常是在脚本的开头打开一个连接,并在脚本的最后关闭它,但是就较大脚本而言,在多数情况下连
- 使用ghost.py 通过搜搜 的微信搜索来爬取微信公共账号的信息# -*- coding: utf-8 -*-import sysrelo
- 本文实例讲述了PHP中substr_count()函数获取子字符串出现次数的方法。分享给大家供大家参考,具体如下:PHP中的substr_c
- 越来越多的网站在logo中添加叶子元素,而此类logo又常常使用绿色,这可以给人希望、清新、健康的感觉,从而很容易被接受和认可。今天我们又收
- 如题,先上效果图:主要分为两大步骤使用python语句,通过百度地图API,对已知的地名抓取经纬度使用百度地图API官网的html例程,修改
- 本文实例讲述了Python实现对一个函数应用多个装饰器的方法。分享给大家供大家参考,具体如下:下面的例子展示了对一个函数应用多个装饰器,可以
- 一、自定义分页1、基础版自定义分页data = []for i in range(1, 302): tmp = {"i
- 要求安装:1.Python2.7z解压软件backup_2.py# Filename: backup_2.py'''
- 本文实例讲述了Python3.5 Pandas模块之DataFrame用法。分享给大家供大家参考,具体如下:1、DataFrame的创建(1
- pygame.transform 模块允许您对加载、创建后的图像进行一系列操作,比如调整图像大小、旋转图片等操作,常用方法如下所示:下面看一
- aspjpeg组件实现加水印函数的调用方法: <%printwater "/images/水印图片.gif",&q
- 创建py文件总是为txt格式问题记录写代码过程中创建.py文件时,一直正常,但创建名称为train.py文件时总是为txt格式,即使选择了p
- CSS选择器目前,除了官方文档之外,市面上及网络详细介绍BeautifulSoup使用的技术书籍和博客软文并不多,而在这仅有的资料中介绍CS
- # 半夜撸代码 正在一顿操作猛如虎的时候,发现删了其中一张表的某条记录,结果发现其他表跟这个字段的关联的也都被删除,我已经写了d
- 下策——查询出结果后将时间排序后取第一条select * from a where create_time<="2017-0
- 总结类的定义很久以前,语言都是面向过程的,经过计算机科学家的探索,出现了面向对象。面向对象可以解释生活中很多东西。比如人,人就是个对象,有参
- 在python中,通过如下两个模块可以实现邮件的自动化操作smtplibemailsmtplib模块是对SMTP协议的封装,用于发送邮件;e
- py 写东西快但是java 生态广比如大数据 py 虽然好 但是利用不到java的整个的生态的代码scala 虽然也好但是毕竟 有些库 需要
- pandas.DataFrame行名(index)和列名(columns)的修改方法如下。rename()任意的行名(index)和列名(c