pytorch 状态字典:state_dict使用详解
作者:wzg2016 发布时间:2023-01-16 11:42:52
pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)
(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)
优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)
备注:
1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"
torch.save(model.state_dict(), PATH)
2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.
模态字典(state_dict)的保存(model是一个网络结构类的对象)
1.1)仅保存学习到的参数,用以下命令
torch.save(model.state_dict(), PATH)
1.2)加载model.state_dict,用以下命令
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名
2.1)保存整个model的状态,用以下命令
torch.save(model,PATH)
2.2)加载整个model的状态,用以下命令:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项
如何仅加载某一层的训练的到的参数(某一层的state)
If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)
for param in list(model.pretrained.parameters()):
param.requires_grad = False
注意: requires_grad的操作对象是tensor.
疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False
回答:经测试,不可以.model.conv1 没有requires_grad属性.
全部测试代码:
#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass,self).__init__()
self.conv1 = nn.Conv2d(3,6,5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10)
def forward(self,x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1,16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# initial model
model = TheModelClass()
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor,'\t',model.state_dict()[param_tensor].size())
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name])
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)
来源:https://blog.csdn.net/Strive_For_Future/article/details/83240081
猜你喜欢
- CSS处理斜角导航条的一个例子,这个是写着测试用的。暂没有实际的应用。斜角处理比较麻烦,主要有两个地方。1、图片处理。2、负数的理解。这两个
- 转:coolcode.cn前几天写了一篇在任意字符集下正常显示网页的方法,里面介绍的很简单,就是把前128个字符以外的字符集都用
- python是个很好玩的东西?好吧我随口说的,反正因为各种原因(其实到底是啥我也不知道),简单的学习了下python,然后写了一个上传文件上
- 前言二维码现在是随处度可以看到,买东西,支付,添加好友只要你扫一扫就能完成整个工作,简单且方便。所以利用这个新春佳节做一个带着新春祝福的二维
- 1. 安装 Sublime Text 3虽然现在的 Sublime 3 还处于 beta 阶段, 但已经非常稳定了, 而且速度比 Subli
- 本文实例讲述了PHP会话控制技巧。分享给大家供大家参考,具体如下:Demo1.php<form method="get&qu
- 网页上的图片如果设置了alt属性,当鼠标移经时就会有tooltip出现,但是只能显示一行文本,有时需要多行文本,乃至图片来显示图片、链接或者
- Google Chrome 的发布,使我们更加的注重基于 WebKit 核心的浏览器的表现情况,但我们很多时候“不小心”就会出现
- 代码如下:Function splitx(strs1 As String, strs2 A
- 一、安装 wordcloudpip install wordcloud二、加载包、设置路径import osfrom wordcloud i
- 这是asp利用dictionary创建二维数组的例子,这样做的优点是:1、数组下标可以是字符串2、长度不是固定的<'% ’==
- 每一字符串字符文字有一个字符集和一个校对规则,它不能为空。一个字符串文字可能有一个可选的字符集引介词和COLLATE子句:[_charset
- match()方法用于从字符串中查找指定的值本方法类似于indexOf()和lastindexOf(),不同的是它返回的是指定的值,而不是指
- 本文实例讲述了Python实现对一个函数应用多个装饰器的方法。分享给大家供大家参考,具体如下:下面的例子展示了对一个函数应用多个装饰器,可以
- 页面重构需要考虑的一个重点是XHTML代码语义化,就算是在无任何CSS样式修饰的情况下也能给他人在阅读时带来便利,甚至可以夸张点说在搜索引擎
- django中的超链接,在template中可以用{% url 'app_name:url_name' param%}其中a
- 本文实例讲述了Python实现求解一元二次方程的方法。分享给大家供大家参考,具体如下:1. 引入math包2. 定义返回的对象3. 判断b*
- 在这里奉上源代码,没有做样式处理,不过功能是可以的,希望大家可以和我交流交流!<html> <head>&
- 0x00 环境系统环境:win10编写工具:JetBrains PyCharm Community Edition 2017.1.2 x64
- 问题描述我在用Keras的Embedding层做nlp相关的实现时,发现了一个神奇的问题,先上代码:a = Input(shape=[15]