解决pytorch 的state_dict()拷贝问题
作者:Luke_Ye 发布时间:2022-10-05 22:03:57
先说结论
model.state_dict()
是浅拷贝,返回的参数仍然会随着网络的训练而变化。
应该使用deepcopy(model.state_dict())
,或将参数及时序列化到硬盘。
再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。
补充:pytorch中state_dict的理解
在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。
其实看了如下代码的输出应该就懂了
import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
输出如下:
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]
我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!
补充:pytorch保存模型时报错***object has no attribute 'state_dict'
定义了一个类BaseNet并实例化该类:
net=BaseNet()
保存net时报错 object has no attribute 'state_dict'
torch.save(net.state_dict(), models_dir)
原因是定义类的时候不是继承nn.Module类,比如:
class BaseNet(object):
def __init__(self):
把类定义改为
class BaseNet(nn.Module):
def __init__(self):
super(BaseNet, self).__init__()
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://www.cnblogs.com/LukeStepByStep/p/11248361.html


猜你喜欢
- mysql_result定义和用法mysql_result() 函数返回结果集中一个字段的值。mysql_result() 返回 MySQL
- 引入:Python中有个logging模块可以完成相关信息的记录,在debug时用它往往事半功倍一、日志级别(从低到高):DEBUG :详细
- Python中的random函数random模块提供生成伪随机数的函数,在使用时需要导入random模块1. random.random()
- 目录一、索引基础1. 索引的类型1.1 B-Tree 索引1.2 哈希索引1.3 空间数据索引(R-Tree)1.4 全文索引二、索引的优缺
- 最近想实现PHP实现短信验证的效果,做PC网站的时候,可以通过注册用户需要使用短信验证的功能,或者找回密码,以及验证用户的信息等等功能,发现
- 前言很多开发同学对SQL优化如数家珍,却对MySQL架构一知半解。岂不是只见树叶,不见森林,终将陷入细节中不能自拔。今天就一块学习MySQL
- Create trigger tri_wk_CSVHead_History on wk_CSVHead_History --声明一个tri_
- MySQL内外连接表的连接分为内连接和外连接。内连接内连接内连接的SQL如下:SELECT ... FROM t1 INNER JOIN t
- IT界的每个人都应该知道终端(Terminal)的基本知识,数据科学家也不例外。有时,终端是你的全部,尤其是在将模型和数据管道部署到远程机器
- 在asp里通过以下两个函数实现javascript里的escape函数和unescape函数加密功能。在ajax post或get时内容存在
- matplotlib是python最著名的绘图库,它提供了一整套和matlab相似的命令API,十分适合交互式地进行制图。而且也
- 大多数网站维护都采用“多人协作,共同管理”方式。某个人负责一个(或者多个)栏目,他只能对他负责的栏目进
- 1.创建空字典>>> dic = {}>>> type(dic)<type 'dict
- forEaches5出来的方法,这是我在react中用的最多的遍历方法之一,用法如下:models.forEach(model =>
- 一、遇到的问题在向数据库中存入汉字时遇到这样的问题:Cause: java.sql.SQLException: Incorrect stri
- 今天遇到下图这种问题,文字过长,显示不全。折腾了老半天,在网上搜了半天也找不到解决方案。于是问了下同事,同事提到了<optgroup&
- 推荐阅读:Oracle读取excel数据oracle导出excel(非csv)的方法有两种,1、使用sqlplus spool,2、使用包体
- 一、在访客的内心深处做导航我讨厌迷失,不管是在道路上或是在线网络上。猜想一下?您的访客也是这样的。就像我们期望看到的道路上的路标一样,来帮助
- 1、python教程基于 python3.10 的持续解读,旨在快速回忆加深理解,节约自己的时间成本1.1 概述python 是一门易于学习
- 作者:做梦的人(小姐姐)出处:https://www.cnblogs.com/chongyou/python读取yaml文件使用,有两种方式