pytorch加载自定义网络权重的实现
作者:wuming无名 发布时间:2022-06-16 14:39:10
标签:pytorch,加载,网络,权重
在将自定义的网络权重加载到网络中时,报错:
AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
我们一步一步分析。
模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')
(1)查看获取模型权重的源码:
pytorch源码:net.state_dict()
def state_dict(self, destination=None, prefix='', keep_vars=False):
r"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Returns:
dict:
a dictionary containing a whole state of the module
Example::
>>> module.state_dict().keys()
['bias', 'weight']
"""
将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!
(2)查看保存模型权重的源码:
pytorch源码:torch.save()
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
"""Saves an object to a disk file.
See also: :ref:`recommend-saving-models`
Args:
obj: saved object
f: a file-like object (has to implement write and flush) or a string
containing a file name
pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
.. warning::
If you are using Python 2, torch.save does NOT support StringIO.StringIO
as a valid file-like object. This is because the write method should return
the number of bytes written; StringIO.write() does not do this.
Please use something like io.BytesIO instead.
函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。
解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()
#b为自定义的字典
torch.save(b,'new.pkl')
net.load_state_dict(torch.load(b))
解决方法很简单,主要记录解决思路。
来源:https://blog.csdn.net/qq_34789262/article/details/83376374
0
投稿
猜你喜欢
- 最近在研究tensorflow自带的例程speech_command,顺便学习tensorflow的一些基本用法。其中tensorboard
- 为了庆祝jQuery的四周岁生日, jQuery的团队荣幸的发布了jQuery Javascript库的最新主要版本! 这个版本包含了大量的
- 在应用程序的开发中,有些输入信息是动态的,比如我们要注册一个员工的工作经历,比如下图如果做成死的,只能填写三个,如果是四个呢?或者更多呢,那
- 一. 介绍一个计数器工具提供快速和方便的计数,Counter是一个dict的子类,用于计数可哈希对象。它是一个集合,元素像字典键(key)一
- 任务说明:编写一个钱币定位系统,其不仅能够检测出输入图像中各个钱币的边缘,同时,还能给出各个钱币的圆心坐标与半径。效果代码实现Canny边缘
- 有时候,我们需要将文本转换为图片,比如发长微博,或者不想让人轻易复制我们的文本内容等时候。目前类似的工具已经有了不少,不过我觉得用得都不是很
- 一、分析网页网站的页面是 JavaScript 渲染而成的,我们所看到的内容都是网页加载后又执行了JavaScript代码之后才呈现出来的,
- match()函数只检测RE是不是在string的开始位置匹配, search()会扫描整个string查找匹配, 也就是说match()只
- 概述做日志分析工作的经常需要跟成千上万的日志条目打交道,为了在庞大的数据量中找到特定模式的数据,常常需要编写很多复杂的正则表达式。例如枚举出
- 最近用了pycharm,感觉还不错,就是pandas中Series、DataFrame的plot()方法不显示图片就给我结束了,但是我在ip
- 1、介绍在爬虫中经常会遇到验证码识别的问题,现在的验证码大多分计算验证码、滑块验证码、识图验证码、语音验证码等四种。本文就是识图验证码,识别
- Oracle数据库提供了几种不同的数据库启动和关闭方式,本文将详细介绍这些启动和关闭方式之间的区别以及它们各自不同的功能。 一、启动和关闭O
- 封装Python将多个值用逗号隔开,进行赋值。会将这些值封装成一个tuple返回#示例a = 1,2type(a)结果:<class
- 本文介绍了可以帮助简化 PHP 开发的10个项目,包括框架,类库,工具,代码。1.CakePHP Development Framework
- 1. 引言在日常工作中,大家都需要进行字典的相关操作,对于某些初学者,经常会写一堆繁琐的代码来实现某项简单的功能。本篇文章重点介绍一些在Py
- WARNING:低技术力自己无聊写的哥特字体是最好看的:示例代码:#!usr/bin/env python3# -*- coding:UTF
- 示例1我们将要请求五个不同的url:单线程import timeimport urllib2defget_responses(): &nbs
- 之前总结过flask里的基础知识,现在来总结下flask里的前后端数据交互的知识,这里用的是Ajax一、 post方法1、post方法的位置
- 如下所示:#coding=utf-8#方式一print('*'*20 + '方式一' + '*
- Tkinter复选框Checkbutton是否被选中判断定义一个BooleanVar型数据进行获取复选框状态。>>> im