pytorch 状态字典:state_dict使用详解
作者:wzg2016 发布时间:2023-01-16 11:42:52
pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)
(注意,只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等)
优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)
备注:
1) state_dict是在定义了model或optimizer之后pytorch自动生成的,可以直接调用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"
torch.save(model.state_dict(), PATH)
2) load_state_dict 也是model或optimizer之后pytorch自动具备的函数,可以直接调用
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因为,只有在执行该命令后,"dropout层"及"batch normalization层"才会进入 evalution 模态. 而在"训练(training)模态"与"评估(evalution)模态"下,这两层有不同的表现形式.
模态字典(state_dict)的保存(model是一个网络结构类的对象)
1.1)仅保存学习到的参数,用以下命令
torch.save(model.state_dict(), PATH)
1.2)加载model.state_dict,用以下命令
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
备注:model.load_state_dict的操作对象是 一个具体的对象,而不能是文件名
2.1)保存整个model的状态,用以下命令
torch.save(model,PATH)
2.2)加载整个model的状态,用以下命令:
# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()
state_dict 是一个python的字典格式,以字典的格式存储,然后以字典的格式被加载,而且只加载key匹配的项
如何仅加载某一层的训练的到的参数(某一层的state)
If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
加载模型参数后,如何设置某层某参数的"是否需要训练"(param.requires_grad)
for param in list(model.pretrained.parameters()):
param.requires_grad = False
注意: requires_grad的操作对象是tensor.
疑问:能否直接对某个层直接之用requires_grad呢?例如:model.conv1.requires_grad=False
回答:经测试,不可以.model.conv1 没有requires_grad属性.
全部测试代码:
#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass,self).__init__()
self.conv1 = nn.Conv2d(3,6,5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16*5*5,120)
self.fc2 = nn.Linear(120,84)
self.fc3 = nn.Linear(84,10)
def forward(self,x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1,16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# initial model
model = TheModelClass()
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor,'\t',model.state_dict()[param_tensor].size())
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
print(var_name,'\t',optimizer.state_dict()[var_name])
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 仅仅加载某一层的参数
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)
来源:https://blog.csdn.net/Strive_For_Future/article/details/83240081


猜你喜欢
- 一、yield关键字实现方式以yield关键字方式实现协程代码如下所示:def fun1(): yield 1 &
- 前言Python是一门实现数据可视化很好的语言,他们里面的很多库可以很好的画出图形,形象明了。今天我们就来说说:Pandas数据分析核心支持
- 前言目前,Golang 可以认为是服务器开发语言发展的趋势之一,特别是在流媒体服务器开发中,已经占有一席之地。很多音视频技术服务提供商也大多
- 这是一款简单,方便,功能齐全的分页类,可以根据自己的需要更改CSS样式文件以实现分页颜色的控制,利用p
- 本文实例讲述了python单例模式。分享给大家供大家参考。具体分析如下:__new__()在__init__()之前被调用,用于生成实例对象
- python解析网页,无出BeautifulSoup左右,此是序言安装BeautifulSoup4以后的安装需要用eazy_install,
- 本文实例讲述了python统计字符串中指定字符出现次数的方法。分享给大家供大家参考。具体如下:python统计字符串中指定字符出现的次数,例
- django是python语言快速实现web服务的大杀器,其开发效率可以非常的高!但因为秉承了语言的灵活性,django框架又太灵活,以至于
- 废话不多说啦,直接看代码吧!tf.concatt1 = [[1, 2, 3], [4, 5, 6]]t2 = [[7, 8, 9], [10
- 如何正确理解和使用Command、Connection和 Recordset三个对象?我知道它们都是连接数据库的“好手”,但在编程的具体应用
- 路由路由是指如何定义应用的端点(URIs)以及如何响应客户端的请求。路由是由一个 URI、HTTP 请求(GET、POST等)和若干个句柄组
- 继续练手,根据之前获取汽油价格的方式获取了金价,暂时没钱投资,看看而已#!/usr/bin/env python# -*- coding:
- PDO常用方法:PDO::query()主要用于有记录结果返回的操作(PDOStatement),特别是select操作。PDO::exec
- 需求背景在很多时候我们需要抽取视频的某一帧做一些分析或修改等;比如笔者需求就是判断一个人在该视频中出现的频率,以判断他是否是这段视频的主角;
- 在利用python进行flask等开发过程中经常需要配置虚拟环境以方便针对不同的项目需求配置不同的生产环境。在python3.3之前,需要利
- 前言在实际生产环境中,如果对mysql数据库的读和写都在一台数据库服务器中操作,无论是在安全性、高可用性,还是高并发等各个方面都是不能满足实
- 一、文章前言此文主要通过小程序实现对比人脸相似度,并返回相似度分值,可以基于分值判断是否为同一人。人脸登录、用户认证等场景都可以用到。二、具
- 目录前言线程安全锁的作用Lock() 同步锁基本介绍使用方式死锁现象with语句RLock() 递归锁基本介绍使用方式with语句Condi
- 背景写一个python脚本,实现简单的http服务器功能:1.浏览器中输入网站地址:172.20.52.163:200142.server接
- 背景最近在写一个echarts数据看板,要在一个页面中展示多张图表,所以留给每张图表的尺寸就很小。这也就使得图表x轴的刻度文字全部挤到一起了