pytorch加载自定义网络权重的实现
作者:wuming无名 发布时间:2022-06-16 14:39:10
标签:pytorch,加载,网络,权重
在将自定义的网络权重加载到网络中时,报错:
AttributeError: 'dict' object has no attribute 'seek'. You can only torch.load from a file that is seekable. Please pre-load the data into a buffer like io.BytesIO and try to load from it instead.
我们一步一步分析。
模型网络权重保存额代码是:torch.save(net.state_dict(),'net.pkl')
(1)查看获取模型权重的源码:
pytorch源码:net.state_dict()
def state_dict(self, destination=None, prefix='', keep_vars=False):
r"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Returns:
dict:
a dictionary containing a whole state of the module
Example::
>>> module.state_dict().keys()
['bias', 'weight']
"""
将网络中所有的状态保存到一个字典中了,我自己构建的就是一个字典,没问题!
(2)查看保存模型权重的源码:
pytorch源码:torch.save()
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
"""Saves an object to a disk file.
See also: :ref:`recommend-saving-models`
Args:
obj: saved object
f: a file-like object (has to implement write and flush) or a string
containing a file name
pickle_module: module used for pickling metadata and objects
pickle_protocol: can be specified to override the default protocol
.. warning::
If you are using Python 2, torch.save does NOT support StringIO.StringIO
as a valid file-like object. This is because the write method should return
the number of bytes written; StringIO.write() does not do this.
Please use something like io.BytesIO instead.
函数功能是将字典保存为磁盘文件(二进制数据),那么我们在torch.load()时,就是在内存中加载二进制数据,这就是报错点。
解决方案:将字典保存为BytesIO文件之后,模型再net.load_state_dict()
#b为自定义的字典
torch.save(b,'new.pkl')
net.load_state_dict(torch.load(b))
解决方法很简单,主要记录解决思路。
来源:https://blog.csdn.net/qq_34789262/article/details/83376374


猜你喜欢
- 如果没有设置分页,django-rest-framework 会将所有资源类表序列化后返回,如果资源很多,就会对网站性能造成影响。为此,我们
- 手动备份1)cmd控制台:mysqldump -uroot -proot 数据库名 [表名1,表名2...] > 文件路径比如:把 d
- 本文针对Python time模块进行分类学习,希望对大家的学习有所帮助。一.壁挂钟时间1.time()time模块的核心函数time(),
- javascript sort()排序用法sort() 方法用于对数组的元素进行排序,并返回数组。默认排序顺序是根据字符串UniCode码。
- 引言人工智能是计算机科学中一个非常热门的领域,近年来得到了越来越多的关注。它通过模拟人类思考过程和智能行为来实现对复杂任务的自主处理和学习,
- 使用诸如Lock、RLock、Semphore之类的锁原语时,必须多加小心,锁的错误使用很容易导致死锁或相互竞争。依赖锁的代码应该保证当出现
- scrapy是一个基于Twisted的异步处理框架,可扩展性很强。优点此处不再一一赘述。下面介绍一些概念性知识,帮助大家理解scrapy。一
- 定义和用法fopen() 函数打开文件或者 URL。如果打开失败,本函数返回 FALSE。语法fopen(filename,mode,inc
- 如果一个数字能表示成 p^q,且p是一个素数,q为大于1的正整数,则此数字就是超级素数幂。 param number: 测试该数字是否是超级
- 1. 概述若要将数据库移动或更改到同一计算机的不同 SQL Server 实例,分离和附加数据库会很有用;用户可以分离数据库的数据和事务日志
- 相信没有人不知道 Firebug 是什么东西,但有时候我们糟糕的代码不想让同行轻松的使用 F12 就能一览无遗。那么怎么办呢?这里有个猥琐的
- 无论是在小得可怜的免费数据库空间或是大型电子商务网站,合理的设计表结构、充分利用空间是十分必要的。这就要求我们对数据库系统的常用数据类型有充
- 本文实例为大家分享了python定时按日期备份MySQL数据并压缩的具体代码,供大家参考,具体内容如下#-*- coding:utf-8 -
- 大家都知道系统存储过程是无法用工具导出的(大家可以试试 >任务>生成SQL脚本) 因为系统存储过程一般是不让开发人员修改的。 需
- 一、什么是执行计划(explain plan) 执行计划:一条查询语句在ORACLE中的执行过程或访问路径的描述。 二、如何查看执行计划 1
- 昨天晚上群里有朋友采集网页时发现file_get_contents 获得的网页保存到本地为乱码,响应的header 里 Content-En
- 在上一篇文章中讲解了什么是反射,以及利用反射可以获取程序集里面的哪些内容。在平时的项目中,可能会遇到项目需要使用多种数据库,这篇文章中将会讲
- 前言这篇文章主要介绍了Python 字符串去除空格的6种方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,来
- python 迭代器与生成器,装饰器迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。迭代器有两个基本的方法:iter()
- 1 新建类库MyTestDLL2 右击项目“MyTestDLL”-》属性-》生成-》勾选“为COM互操作注册”3 打开 AssemblyIn