从Pytorch模型pth文件中读取参数成numpy矩阵的操作
作者:木盏 发布时间:2021-12-27 11:05:53
标签:Pytorch,pth,numpy,矩阵
目的:
把训练好的pth模型参数提取出来,然后用其他方式部署到边缘设备。
Pytorch给了很方便的读取参数接口:
nn.Module.parameters()
直接看demo:
from torchvision.models.alexnet import alexnet
model = alexnet(pretrained=True).eval().cuda()
parameters = model.parameters()
for p in parameters:
numpy_para = p.detach().cpu().numpy()
print(type(numpy_para))
print(numpy_para.shape)
上面得到的numpy_para就是numpy参数了~
Note:
model.parameters()是以一个生成器的形式迭代返回每一层的参数。所以用for循环读取到各层的参数,循环次数就表示层数。
而每一层的参数都是torch.nn.parameter.Parameter类型,是Tensor的子类,所以直接用tensor转numpy(即p.detach().cpu().numpy())的方法就可以直接转成numpy矩阵。
方便又好用,爆赞~
补充:pytorch训练好的.pth模型转换为.pt
将python训练好的.pth文件转为.pt
import torch
import torchvision
from unet import UNet
model = UNet(3, 2)#自己定义的网络模型
model.load_state_dict(torch.load("best_weights.pth"))#保存的训练模型
model.eval()#切换到eval()
example = torch.rand(1, 3, 320, 480)#生成一个随机输入维度的输入
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("model.pt")
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://muzhan.blog.csdn.net/article/details/113066030


猜你喜欢
- 前言在进行数据库连接之前,一般都需要导入依赖的库,通过专门的库去处理对应的数据库连接,所以没安装对应的处理库的话,需要先进行安装、再导入,导
- 前言前面我们介绍了 pandas 的基础语法操作,下面我们开始介绍 pandas 的数据读写操作。pandas 的 IO API 是一组顶层
- 1、安装执行命令pip install virtualenv为了使用virtualenv更方便,可以借助 virtualenvwrapper
- 概述前段时间突然发现,我之前对git stash的使用都是错误的。具体说来,我是这么使用的:在远端有新的提交,需要git pull来拉取合并
- 引子首先说 正则表达式是什么?正则表达式,又称正规表示式、正规表示法、正规表达式、规则表达式、常规表示法(英语:Regular Expres
- MySQL Group By用法我们现在回到函数上。记得我们用 SUM 这个指令来算出所有的 Sales (营业额)吧!如果我们的需求变成是
- 1.多边形的绘制案例# 多边形的绘制案例import turtledef main():turtle.color("green&q
- 什么是中间件我们从一个简单的例子开始。高流量的站点通常需要将Django部署在负载平衡proxy之后。 这种方式将带来一些复杂性,其一就是每
- 简介有兴趣可以看看: 解释性语言+动态类型语言+强类型语言交互模式:(主要拿来试验,可以试试 ipython)$python>>
- 最近项目用到了bootstrap框架,其中前端用的校验,采用的是bootstrapvalidator插件,也是非常强大的一款插件。我这里用的
- 1. @@rowcount: 获取受影响行数 代码如下:update SNS_TopicData set TopicCount=TopicC
- 背景Python 作为一门成熟的编程语言,拥有无数优秀的第三方包以方便开发者能够快速地构建应用。一般来说,如果你开发了一个 Python 软
- Python获取当前时间_获取格式化时间:Python获取当前时间:使用 time.time( ) 获取到距离1970年1月1日的秒数(浮点
- 如下所示:# coding = utf-8import requestsimport jsonhost = "http://47.
- 本文实例讲述了Python3.5 Pandas模块之DataFrame用法。分享给大家供大家参考,具体如下:1、DataFrame的创建(1
- 知识点: 1、拼接SQL 2、UNION ALL 3、EXEC 其代码如下: 代码如下:--测试示例 declare @sql
- 哪个Python版本?当我提及Python,所指的就是CPython 2(准确的是2.7).我会显式提醒那些相同的代码在CPython 3
- 前言python对动态验证码、滑动验证码的降噪和识别,在各种自动化操作中,我们经常要遇到沿跳过验证码的操作,而对于验证码的降噪和识别,的确困
- 协程协程简单来说就是一个更加轻量级的线程,并且不由操作系统内核管理,完全由程序所控制(在用户态执行)。协程在子程序内部是可中断的,然后转而执
- django创建自定义模板处理器:一、需求来源:在django开发中,页面是通过template(模板)进行渲染的,对于一些数据,可以通过{