pytorch模型转onnx模型的方法详解
作者:挣扎的笨鸟 发布时间:2021-07-20 06:36:37
学习目标
1.掌握pytorch模型转换到onnx模型
2.顺利运行onnx模型
3.比对onnx模型和pytorch模型的输出结果
学习大纲
pytorch模型转换onnx模型
运行onnx模型
onnx模型输出与pytorch模型比对
学习内容
前提条件:需要安装onnx 和 onnxruntime,可以通过 pip install onnx 和 pip install onnxruntime 进行安装
1 . pytorch 转 onnx
pytorch 转 onnx 只需要一个函数 torch.onnx.export
torch.onnx.export(model, args, path, export_params, verbose, input_names, output_names, do_constant_folding, dynamic_axes, opset_version)
参数说明:
model——需要导出的pytorch模型
args——模型的输入参数,满足输入层的shape正确即可。
path——输出的onnx模型的位置。例如‘yolov5.onnx’。
export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。
verbose——是否打印模型转换信息。default=False。
input_names——输入节点名称。default=None。
output_names——输出节点名称。default=None。
do_constant_folding——是否使用常量折叠(不了解),默认即可。default=True。
dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道。
格式如下 :
1)仅list(int) dynamic_axes={‘input’:[0,2,3],‘output’:[0,1]}
2)仅dict<int, string> dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:{0:‘batch’,1:‘c’}}
3)mixed dynamic_axes={‘input’:{0:‘batch’,2:‘height’,3:‘width’},‘output’:[0,1]}opset_version——opset的版本,低版本不支持upsample等操作。
import torch
import torch.nn
import onnx
model = torch.load('best.pt')
model.eval()
input_names = ['input']
output_names = ['output']
x = torch.randn(1,3,32,32,requires_grad=True)
torch.onnx.export(model, x, 'best.onnx', input_names=input_names, output_names=output_names, verbose='True')
2 . 运行onnx模型
检查onnx模型,并使用onnxruntime运行。
import onnx
import onnxruntime as ort
model = onnx.load('best.onnx')
onnx.checker.check_model(model)
session = ort.InferenceSession('best.onnx')
x=np.random.randn(1,3,32,32).astype(np.float32) # 注意输入type一定要np.float32!!!!!
# x= torch.randn(batch_size,chancel,h,w)
outputs = session.run(None,input = { 'input' : x })
参数说明:
output_names: default=None
用来指定输出哪些,以及顺序
若为None,则按序输出所有的output,即返回[output_0,output_1]
若为[‘output_1’,‘output_0’],则返回[output_1,output_0]
若为[‘output_0’],则仅返回[output_0:tensor]input:dict
可以通过session.get_inputs().name获得名称
其中key值要求与torch.onnx.export中设定的一致
3.onnx模型输出与pytorch模型比对
import numpy as np
np.testing.assert_allclose(torch_result[0].detach().numpu(),onnx_result,rtol=0.0001)
如前所述,经验表明,ONNX 模型的运行效率明显优于原 PyTorch 模型,这似乎是源于 ONNX 模型生成过程中的优化,这也导致了模型的生成过程比较耗时,但整体效率依旧可观。
此外,根据对 ONNX 模型和 PyTorch 模型运行结果的统计分析(误差的均值和标准差),可以看出 ONNX 模型的运行结果误差很小、基本可靠。
内容参考:https://zhuanlan.zhihu.com/p/422290231
来源:https://blog.csdn.net/weixin_38989668/article/details/123840882
猜你喜欢
- 一、介绍QQ空间相册的个性化利器,能对照片进行效果的优化、文字编辑等等。从设计上使用了创新的手法,尽量减少用户的思考。其中,通过界面的特殊表
- import osfrom PIL import Image#批量剪切目录下图片for j in range(10,121): &
- 一、Python解释器 安装Windows平台下载地址 https://www.python.org/ftp/python/3.9.5/py
- 使用ASP生成图片彩色校验码49行代码,三个文件 Asp文件:Co
- 本文介绍的MySQL数据库的出错代码表,依据MySQL数据库头文件mysql/include/mysqld_error.h整理而成。详细内容
- 1 、据说python3就没有这个问题了2 、u'字符串' 代表是unicode格式的数据,路径最好写成这个格式,别直接跟字
- 实例如下所示:from pandas import *from random import *df = DataFrame(columns=
- Timeloop是一个库,可用于运行多周期任务。这是一个简单的库,使用decorator模式在线程中运行标记函数。首先安装timeloop库
- 生成HTML方法主要步骤只有两个:一、获取要生成的html文件的内容二、将获取的html文件内容保存为html文件我在这里主要说明的只是第一
- 模板引擎说明:模板文件就是按照一定的规则书写的展示效果的HTML文件 模板引擎就是负责按照指定规则进行替换的工具模板引擎选择jinja2一、
- 1. 新建.py文件# pip install kafka-pythonfrom kafka import KafkaConsumerimp
- 1.连接测试连接是否成功:import redisr = redis.Redis(host='192.168.136.102'
- 原文地址:30 Days of Mootools 1.2 Tutorials - Day 15 - SlidersMooTools 1.2的
- 本文实例讲述了Python保存最后N个元素的方法。分享给大家供大家参考,具体如下:问题:希望在迭代或是其他形式的处理过程中对最后几项记录做一
- 如题,在控制台运行python manage.py startapp sales 建立一个应用报错异常1.应用名不能包含下划线等字符 所以a
- TensorFlow官网给的cifar-10教程,是卷积神经网络入门的好例子,有时想直接拿这个模型来跑自己的数据,却发现他的数据类型不是常见
- jsonp方式一:指定返回方法# 后端def view(request): callback = request.GET.get
- 这最近在PJ的function库里看到的这个函数,感觉思路差了点,不过相对比较完美,只是闭合标签时的顺序问题,呵呵 修改一下数组arrTag
- 从过往MySQL数据库生产环境的维护工作中,总结的一些小经验和知识,未必有多深奥,但是对我们消除隐患,确保MySQL数据库生产环境四个9的作
- Python有自己内置的标准GUI库--Tkinter,只要安装好Python就可以调用。今天学习到了图形界面设计的问题,刚开始就卡住了。为