PyTorch深度学习模型的保存和加载流程详解
作者:软耳朵DONG 发布时间:2023-07-10 04:58:33
标签:PyTorch,模型的保存,模型的加载
一、模型参数的保存和加载
torch.save(module.state_dict(), path)
:使用module.state_dict()
函数获取各层已经训练好的参数和缓冲区,然后将参数和缓冲区保存到path
所指定的文件存放路径(常用文件格式为.pt
、.pth
或.pkl
)。torch.nn.Module.load_state_dict(state_dict)
:从state_dict
中加载参数和缓冲区到Module
及其子类中 。torch.nn.Module.state_dict()
函数返回python
中的一个OrderedDict
类型字典对象,该对象将每一层与它的对应参数和缓冲区建立映射关系,字典的键值是参数或缓冲区的名称。只有那些参数可以训练的层才会被保存到OrderedDict
中,例如:卷积层、线性层等。Python
中的字典类以“键:值
”方式存取数据,OrderedDict
是它的一个子类,实现了对字典对象中元素的排序(OrderedDict
根据放入元素的先后顺序进行排序)。由于进行了排序,所以顺序不同的两个OrderedDict
字典对象会被当做是两个不同的对象。示例:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 获取state_dict
state_dict = net.state_dict()
# 字典的遍历默认是遍历key,所以param_tensor实际上是键值
for param_tensor in state_dict:
print(param_tensor,':\n',state_dict[param_tensor])
# 保存模型参数
torch.save(state_dict,"net_params.pth")
# 通过加载state_dict获取模型参数
net.load_state_dict(state_dict)
输出:
二、完整模型的保存和加载
torch.save(module, path)
:将训练完的整个网络模型module
保存到path
所指定的文件存放路径(常用文件格式为.pt
或.pth
)。torch.load(path)
:加载保存到path
中的整个神经网络模型。示例:
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 2, 3)
self.pool1 = nn.MaxPool2d(2, 2)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
return x
# 初始化网络
net = Net()
net.conv1.weight[0].detach().fill_(1)
net.conv1.weight[1].detach().fill_(2)
net.conv1.bias.data.detach().zero_()
# 保存整个网络
torch.save(net,"net.pth")
# 加载网络
net = torch.load("net.pth")
来源:https://blog.csdn.net/m0_52650517/article/details/120836999
0
投稿
猜你喜欢
- 最近群里好多人讨论oracle安全问题,今天找了些资料学习了下 获取Oracle当前会话的一些属性 (对于sql注射的环境判断很有用哦) S
- 年前帮manager 招GUI设计实习生 (PS. 这个实习生职位依然open,欢迎有兴趣的同学来投,邮箱jj.ying [at] hp.c
- 在国内外大中型数据库管理系统中,把ORACLE作为数据库管理平台的用户比较多。RACLE 不论是数据库管理能力还是安全性都是无可非
- 用户登录验证脚本,Chkpwd.asp<% '=======用户登录验证脚本======= '如果尚未定义Passed
- 如何保持数据库中原有格式不变:这些问题在论坛里面几乎天天有人问~!其实当在输入信息,然后提交信息的时候,所有内容的格式是没有变的。只是在当提
- 小贤是一条可爱的小狗(Dog), 它的叫声很好听(wow), 每次看到主人的时候就会乖乖叫一声(yelp).从这段描述可以得到以下对象:fu
- 有使用过VS2005开发工具的朋友或者其他语句如js中都有Try catch 语句块,那么在mysql中是否能有SQLserver的@@er
- 代码如下: <% '屏蔽主流的下载工具 Dimxurl,xtool '获取浏览器AGENT xurl=lcase(Re
- Microsoft SQL Server 2000 能提供超大型系统所需的数据库服务。大型服务器可能有成千上万的用户同时连接到 SQL Se
- 假设要生成一千万个随机数,常规的做法如下:var numbers = [];for (var&nbs
- 一种很常见的写法: document.write('<scr'+'ipt src=&quo
- ASP里两种常用的生成文件的方式是:利用ADODB.Stream生成文件和利用Scripting.FileSystemObject生成文件1
- 多表操作 在一个数据库中,可能存在多个表,这些表都是相互关联的。我
- CSS圆角的现实一直是大家所热衷的话题,我们进行CSS布局一直强调语义,强调文档的结构。圆角作为页面的外面表现,应该分离到CSS文件中,可以
- 描述Python strip() 方法用于移除字符串头尾指定的字符(默认为空格)。语法strip()方法语法:str.strip([char
- Delphi连接MySQL真麻烦,研究了一天,从网上找了无数文章,下载了无数插件都没解决。最后返璞归真,老老实实用ADO来连接,发现也不是很
- 沟通的时候,一般我不主动说自己是做用户体验设计,也不说做以用户为中心的设计,包括UED, UCD。这种专业名词传达的太虚,你也许是名用户体验
- 1. ASCII 返回与指定的字符对应的十进制数; SQL> select ascii(A) A,ascii(a) a,as
- 内容摘要:本文介绍了通过获取访问者的IP地址来统计在线人数的方法,本文只是给出了实现统计在线人数的方法思路,具体代码的实现过程还得自己动手(
- 思考一个问题:怎么实现在第一次检索的基础上进行二次检索?通常,我们的做法是第一次检索时保存检索条件,在第二次行检索时组合两次检索条件对数据库