解决pytorch 的state_dict()拷贝问题
作者:Luke_Ye 发布时间:2022-10-05 22:03:57
先说结论
model.state_dict()
是浅拷贝,返回的参数仍然会随着网络的训练而变化。
应该使用deepcopy(model.state_dict())
,或将参数及时序列化到硬盘。
再讲故事,前几天在做一个模型的交叉验证训练时,通过model.state_dict()保存了每一组交叉验证模型的参数,后根据效果选择准确率最佳的模型load回去,结果每一次都是最后一个模型,从地址来看,每一个保存的state_dict()都具有不同的地址,但进一步发现state_dict()下的各个模型参数的地址是共享的,而我又使用了in-place的方式重置模型参数,进而导致了上述问题。
补充:pytorch中state_dict的理解
在PyTorch中,state_dict是一个Python字典对象(在这个有序字典中,key是各层参数名,value是各层参数),包含模型的可学习参数(即权重和偏差,以及bn层的的参数) 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息。
其实看了如下代码的输出应该就懂了
import torch
import torch.nn as nn
import torchvision
import numpy as np
from torchsummary import summary
# Define model
class TheModelClass(nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# Initialize model
model = TheModelClass()
# Initialize optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor,"\t", model.state_dict()[param_tensor].size())
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
输出如下:
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2238501264336, 2238501329800, 2238501330016, 2238501327136, 2238501328576, 2238501329728, 2238501327928, 2238501327064, 2238501330808, 2238501328288]}]
我是刚接触深度学西的小白一个,希望大佬可以为我指出我的不足,此博客仅为自己的笔记!!!!
补充:pytorch保存模型时报错***object has no attribute 'state_dict'
定义了一个类BaseNet并实例化该类:
net=BaseNet()
保存net时报错 object has no attribute 'state_dict'
torch.save(net.state_dict(), models_dir)
原因是定义类的时候不是继承nn.Module类,比如:
class BaseNet(object):
def __init__(self):
把类定义改为
class BaseNet(nn.Module):
def __init__(self):
super(BaseNet, self).__init__()
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://www.cnblogs.com/LukeStepByStep/p/11248361.html
猜你喜欢
- python实现超市扫码仪计费的程序主要是使用超市扫码仪扫商品的条形码,读取商品信息,实现计费功能。主要用到的技术是串口通信,数据库的操作,
- 首先呢,需要有两个mysql服务器。如果做测试的话可以在同一台机器上装两个mysql服务程序,注意要两个运行程序的端口不能一样。我用的是一个
- 在自动化中, Selenium 自动化测试中有一个名字经常被提及 PageObject( 思想与面向对象的特征相 同 ) ,通常 PO 模型
- 假设有这样一个任务,希望对某个文件夹(包括所有子文件夹与文件)中的所有文件进行处理。这就需要遍历整理目录树, 处理遇到的每个文件。impor
- 一、Python开机自动运行假如Python自启动脚本为 auto.py 。那么用root权限编辑以下文件:sudo vim /etc/rc
- 下面是asp代码实现列出sql数据库中存储过程的功能,可自行添加其它功能:< HTML >< 
- 本文实例讲述了Python及Django框架生成二维码的方法。分享给大家供大家参考,具体如下:一、包的安装和简单使用1.1 用Python来
- 以前的服务器,由于内存的价格过高,一般配置的内存不是很多,超过4GB的当然就不多了.现在的服务器,配置超过4GB就很多,在配作SQL 数据库
- “Be conservative in what you send; be liberal in what you accept. &nbs
- 前言:因为研究工作的需要,要更改激活函数以适应自己的网络模型,但是单纯的函数替换会训练导致不能收敛。这里还有些不清楚为什么,希望有人可以给出
- CGArt®2008“贺岁刊”玉鼠闹春,700页再造巅峰本期CGArt杂志信息:下载地址:http://cgart.cgfi
- 接着上一篇,统一思想,遵循标准。如何遵循标准,其实标准有很多,结构标准,表现标准,行为标准。选择标准规范,就优先选择W3C推荐的标准。结构标
- 这篇文章将介绍在Python中使用 "frozenset "函数的指南,该函数返回一个新的frozenset类型的Pyt
- 本文实例讲述了Python3爬虫学习之应对网站反爬虫机制的方法。分享给大家供大家参考,具体如下:如何应对网站的反爬虫机制在访问某些网站的时候
- 如下所示:'''@author: Jacobpc'''import osimport sys
- 用python写了一个简单版本的textrank,实现提取关键词的功能。import numpy as np import jieba im
- 文件目录的创建和删除package mainimport( "fmt" "os")func main
- 1. 生成源码# -*- coding: utf-8 -*-import randomdef generate_verification_c
- 本文实例讲述了JS实现简易图片轮播效果的方法。分享给大家供大家参考。具体如下:这里使用JS制作简易图片轮播效果:制作比较粗糙,使用的图片是w
- 本文实例讲述了Python3实现的反转单链表算法。分享给大家供大家参考,具体如下:反转一个单链表。方案一:迭代# Definition fo