pytorch模型转onnx模型的方法详解
作者:挣扎的笨鸟 发布时间:2021-07-20 06:36:37
学习目标
1.掌握pytorch模型转换到onnx模型
2.顺利运行onnx模型
3.比对onnx模型和pytorch模型的输出结果
学习大纲
pytorch模型转换onnx模型
运行onnx模型
onnx模型输出与pytorch模型比对
学习内容
前提条件:需要安装onnx 和 onnxruntime,可以通过 pip install onnx 和 pip install onnxruntime 进行安装
1 . pytorch 转 onnx
pytorch 转 onnx 只需要一个函数 torch.onnx.export
torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
参数说明:
model——需要导出的pytorch模型
args——模型的输入参数,满足输入层的shape正确即可。
path——输出的onnx模型的位置。例如‘yolov5.onnx’。
export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。
verbose——是否打印模型转换信息。default=False。
input_names——输入节点名称。default=None。
output_names——输出节点名称。default=None。
do_constant_folding——是否使用常量折叠(不了解),默认即可。default=True。
dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道。
格式如下 :
1)仅list(int) dynamic_axes={‘input’:[0,2,3],‘output’:[0,1]}
2)仅dict<int, string> dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:{0:‘batch’,1:‘c’}}
3)mixed dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:[0,1]}opset_version——opset的版本,低版本不支持upsample等操作。
import torch
import torch.nn
import onnx
model = torch.load('best.pt')
model.eval()
input_names = ['input']
output_names = ['output']
x = torch.randn(1,3,32,32,requires_grad=True)
torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')
2 . 运行onnx模型
检查onnx模型,并使用onnxruntime运行。
import onnx
import onnxruntime as ort
model = onnx.load('best.onnx')
onnx.checker.check_model(model)
session = ort.InferenceSession('best.onnx')
x=np.random.randn(1,3,32,32).astype(np.float32) # 注意输入type一定要np.float32!!!!!
# x= torch.randn(batch_size,chancel,h,w)
outputs = session.run(None,input = { 'input' : x })
参数说明:
output_names: default=None
用来指定输出哪些,以及顺序
若为None,则按序输出所有的output,即返回[output_0,output_1]
若为[‘output_1’,‘output_0’],则返回[output_1,output_0]
若为[‘output_0’],则仅返回[output_0:tensor]input:dict
可以通过session.get_inputs().name获得名称
其中key值要求与torch.onnx.export中设定的一致
3.onnx模型输出与pytorch模型比对
import numpy as np
np.testing.assert_allclose(torch_result[0].detach().numpu(),onnx_result,rtol=0.0001)
如前所述,经验表明,ONNX 模型的运行效率明显优于原 PyTorch 模型,这似乎是源于 ONNX 模型生成过程中的优化,这也导致了模型的生成过程比较耗时,但整体效率依旧可观。
此外,根据对 ONNX 模型和 PyTorch 模型运行结果的统计分析(误差的均值和标准差),可以看出 ONNX 模型的运行结果误差很小、基本可靠。
内容参考:https://zhuanlan.zhihu.com/p/422290231
来源:https://blog.csdn.net/weixin_38989668/article/details/123840882


猜你喜欢
- 引入大家在使用谷歌或者百度搜索时,输入搜索内容时,谷歌总是能提供非常好的拼写检查,比如你输入 speling,谷歌会马上返回 spellin
- 最近因为要写一个项目的接口,需要远程的连接oracle数据库,刚开始的时候因为我本地只装了MySQL,所以用就连接了本地MySQL,接口大体
- 前言本文的文字及图片来源于网络,仅供学习、交流使用,不具有任何商业用途,如有问题请及时联系我们以作处理。以下文章来源于Python进击者 ,
- 本篇文章主要介绍了python OpenCV学习笔记之绘制直方图的方法,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来
- SQL SERVER支持的字符串函数内容:LEN(string)函数LOWER(string)函数UPPER (string)函数LTRIM
- Python的文件类型介绍:.py python的源代码文件.pyc Python源代码import后,编译生成的字节码.pyo Pytho
- 今天写项目的时候用到ant design中的日期组件,但是由于用ant design日期组件取得的值是moment类型,而往数据库中保存需要
- 如何做一个可以让人家申请使用的计数器? 好了,我们来做一个与页面分离的计数器,是文本型的啦。这也很简单,
- 目前有好几种方法可以将python文件打包成exe应用程序文件,例如py2exe,pyinstaller等,比较下来,还是觉得pyinsta
- golang并没有像C语言一样提供三元表达式。三元表达式的好处是可以用一行代码解决原本需要多行代码才能完成的功能,让冗长的代码瞬间变得简洁。
- 在使用python通过open()函数来打开文件的时候,传递绝对路径给open()的时候,发现路径参数的内容与想象中的有所出入:由于wind
- 设计与开发之间本有一线界限,但当时代步入又一个十年,这个线变得更加模糊甚至感觉不到它的存在。使用PS设计网页版面,足矣?或许五年前是吧!现在
- 1,filesize()函数返回错误的值。 使用curl将某个页面下载到本地时,需要将下载到的临时文件tmpHtml.txt的内容读取到一个
- 前言matplotlib是基于Python语言的开源项目,旨在为Python提供一个数据绘图包。在使用Python matplotlib库绘
- 一、map 1.基本介绍map 是 key-value 数据结构,又称为字段或者关联数组。类似其它编程语言的集合, 在编程中是经常
- 首先让我们来看看有关 Perl 面向对象编程的三个基本定义:1. 一个“对象”是指一个“有办法知道它是属于哪个类”的简单引用。(
- 前言相信大家在网上经常看到有人秀出各种各样的字符画,对于这个五彩斑斓的世界来说,我们日常看到的都是一些高清的彩色的图片,偶尔来个粗糙的黑白的
- 查询游戏历史成绩最高分前100Sql代码SELECT ps.* FROM cdb_playsgame ps WHERE ps.credits
- 本文将带领大家由浅入深的去窥探一下,这个装饰器到底是何方神圣,看完本篇,装饰器就再也不是难点了.一、什么是装饰器网上有人是这么评价装饰器的,
- datetime 时间包认识 datetime 时间包:date:日期;time:时间;所以 datetime 就是 日期与时间的结合体使用