Python torch.onnx.export用法详细介绍
作者:Kmaeii 发布时间:2022-04-28 22:07:33
函数原型
参数介绍
mode (torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction)
需要转换的模型,支持的模型类型有:torch.nn.Module, torch.jit.ScriptModule or torch.jit.ScriptFunction
args (tuple or torch.Tensor)
args可以被设置成三种形式
1.一个tuple
args = (x, y, z)
这个tuple应该与模型的输入相对应,任何非Tensor的输入都会被硬编码入onnx模型,所有Tensor类型的参数会被当做onnx模型的输入。
2.一个Tensor
args = torch.Tensor([1, 2, 3])
一般这种情况下模型只有一个输入
3.一个带有字典的tuple
args = (x,
{'y': input_y,
'z': input_z})
这种情况下,所有字典之前的参数会被当做“非关键字”参数传入网络,字典种的键值对会被当做关键字参数传入网络。如果网络中的关键字参数未出现在此字典中,将会使用默认值,如果没有设定默认值,则会被指定为None。
NOTE:
一个特殊情况,当网络本身最后一个参数为字典时,直接在tuple最后写一个字典则会被误认为关键字传参。所以,可以通过在tuple最后添加一个空字典来解决。
#错误写法:
torch.onnx.export(
model,
(x,
# WRONG: will be interpreted as named arguments
{y: z}),
"test.onnx.pb")
# 纠正
torch.onnx.export(
model,
(x,
{y: z},
{}),
"test.onnx.pb")
f
一个文件类对象或一个路径字符串,二进制的protocol buffer将被写入此文件
export_params (bool, default True)
如果为True则导出模型的参数。如果想导出一个未训练的模型,则设为False
verbose (bool, default False)
如果为True,则打印一些转换日志,并且onnx模型中会包含doc_string信息。
training (enum, default TrainingMode.EVAL)
枚举类型包括:
TrainingMode.EVAL - 以推理模式导出模型。
TrainingMode.PRESERVE - 如果model.training为False,则以推理模式导出;否则以训练模式导出。
TrainingMode.TRAINING - 以训练模式导出,此模式将禁止一些影响训练的优化操作。
input_names (list of str, default empty list)
按顺序分配给onnx图的输入节点的名称列表。
output_names (list of str, default empty list)
按顺序分配给onnx图的输出节点的名称列表。
operator_export_type (enum, default None)
默认为OperatorExportTypes.ONNX, 如果Pytorch built with DPYTORCH_ONNX_CAFFE2_BUNDLE,则默认为OperatorExportTypes.ONNX_ATEN_FALLBACK。
枚举类型包括:
OperatorExportTypes.ONNX - 将所有操作导出为ONNX操作。
OperatorExportTypes.ONNX_FALLTHROUGH - 试图将所有操作导出为ONNX操作,但碰到无法转换的操作(如onnx未实现的操作),则将操作导出为“自定义操作”,为了使导出的模型可用,运行时必须支持这些自定义操作。支持自定义操作方法见链接。
OperatorExportTypes.ONNX_ATEN - 所有ATen操作导出为ATen操作,ATen是Pytorch的内建tensor库,所以这将使得模型直接使用Pytorch实现。(此方法转换的模型只能被Caffe2直接使用)
OperatorExportTypes.ONNX_ATEN_FALLBACK - 试图将所有的ATen操作也转换为ONNX操作,如果无法转换则转换为ATen操作(此方法转换的模型只能被Caffe2直接使用)。例如:
# 转换前:
graph(%0 : Float):
%3 : int = prim::Constant[value=0]()
# conversion unsupported
%4 : Float = aten::triu(%0, %3)
# conversion supported
%5 : Float = aten::mul(%4, %0)
return (%5)
# 转换后:
graph(%0 : Float):
%1 : Long() = onnx::Constant[value={0}]()
# not converted
%2 : Float = aten::ATen[operator="triu"](%0, %1)
# converted
%3 : Float = onnx::Mul(%2, %0)
return (%3)
opset_version (int, default 9)
默认是9。值必须等于_onnx_main_opset或在_onnx_stable_opsets之内。具体可在torch/onnx/symbolic_helper.py中找到。例如:
_default_onnx_opset_version = 9
_onnx_main_opset = 13
_onnx_stable_opsets = [7, 8, 9, 10, 11, 12]
_export_onnx_opset_version = _default_onnx_opset_version
do_constant_folding (bool, default False)
是否使用“常量折叠”优化。常量折叠将使用一些算好的常量来优化一些输入全为常量的节点。
example_outputs (T or a tuple of T, where T is Tensor or convertible to Tensor, default None)
当需输入模型为ScriptModule 或 ScriptFunction时必须提供。此参数用于确定输出的类型和形状,而不跟踪(tracing )模型的执行。
dynamic_axes (dict<string, dict<python:int, string>> or dict<string, list(int)>, default empty dict)
通过以下规则设置动态的维度:
KEY(str) - 必须是input_names或output_names指定的名称,用来指定哪个变量需要使用到动态尺寸。
VALUE(dict or list) - 如果是一个dict,dict中的key是变量的某个维度,dict中的value是我们给这个维度取的名称。如果是一个list,则list中的元素都表示此变量的某个维度。
具体可参考如下示例:
class SumModule(torch.nn.Module):
def forward(self, x):
return torch.sum(x, dim=1)
# 以动态尺寸模式导出模型
torch.onnx.export(SumModule(), (torch.ones(2, 2),), "onnx.pb",
input_names=["x"], output_names=["sum"],
dynamic_axes={
# dict value: manually named axes
"x": {0: "my_custom_axis_name"},
# list value: automatic names
"sum": [0],
})
### 导出后的节点信息
##input
input {
name: "x"
...
shape {
dim {
dim_param: "my_custom_axis_name" # axis 0
}
dim {
dim_value: 2 # axis 1
...
##output
output {
name: "sum"
...
shape {
dim {
dim_param: "sum_dynamic_axes_1" # axis 0
...
keep_initializers_as_inputs (bool, default None)
NONE
custom_opsets (dict<str, int>, default empty dict)
NONE
Torch.onnx.export执行流程:
1、如果输入到torch.onnx.export的模型是nn.Module类型,则默认会将模型使用torch.jit.trace转换为ScriptModule
2、使用args参数和torch.jit.trace将模型转换为ScriptModule,torch.jit.trace不能处理模型中的循环和if语句
3、如果模型中存在循环或者if语句,在执行torch.onnx.export之前先使用torch.jit.script将nn.Module转换为ScriptModule
4、模型转换成onnx之后,预测结果与之前会有稍微的差别,这些差别往往不会改变模型的预测结果,比如预测的概率在小数点之后五六位有差别。
来源:https://blog.csdn.net/Dteam_f/article/details/122487634
猜你喜欢
- 本文给大家介绍PHP中Http协议post请求参数,具体内容如下所示:WEB开发中信息基本全是在POST与GET请求与响应中进行,GET因其
- SQL Server 联机帮助给出了详细说明。 -->目录 -->SQL Server架构 --&
- 遇到一个很奇怪的现象,在给页面添加“打印”按钮时,发现网页在IE6下居然不能打印,弹出一个对话框,遇到脚本错误。查看错误详细:定位到 url
- ASP.NET WEB FORMS 给开发者提供了极好的事件驱动开发模式。然而这种简单的应用程序开发模式却给我
- use 数据库 go EXEC sp_changeobjectowner ‘原表的所有者.表名',现在的所有者例如: exec sp
- 有这样一类文章标题,喜欢学习的人肯定见过:使用Google的7个技巧Web设计中9个常见的可用性错误Adobe Photoshop 75个技
- javascript var item = document.getElementById(""); var text
- 在利用python进行flask等开发过程中经常需要配置虚拟环境以方便针对不同的项目需求配置不同的生产环境。在python3.3之前,需要利
- 图片提取为了方便技术展示,我们选取素材为演员杨紫的一段演讲视频,用例仅为技术交流演示使用,不针对任何指定人。为达到我们AI换脸的目的,我们首
- 基于signal模块实现signal包负责在Python程序内部处理信号,典型的操作包括预设信号处理函数,暂停并等待信号,以及定时发出SIG
- <html> <head> <title>Untitled Document</title>
- 看了大神统计voc数据集标签框后,针对自己标注数据集,灵活应用 ,感谢!看代码吧~import reimport osimport xml.
- 原来是在系统上出了问题.是2003的IIS出现了问题,因为是2003的系统,它对ASP的上传文件做出了200K的限制,解决问题方法如下 :
- messageboxtkinter.messagebox中封装了多种消息框,其输入参数统一为title, message以及其他参数。其中t
- 导语九月初家里的熊孩子终于开始上学了!半个月过去了,小孩子每周都会带着一堆的数学作业回来,哈哈哈哈~真好,在家做作业就没时间打扰我写代码了。
- 集合特点:集合对象是一组无序排列的可哈希的值:集合成员可以做字典的键,与列表和元组不同,集合无法通过数字进行索引。此外,集合中的元素不能重复
- 前言:在软件测试中,为项目编写接口自动化用例已成为测试人员常驻的测试工作。本文以python为例,基于笔者曾使用过的三种用例数据读取方法:x
- 可变长参数GO语言允许一个函数把任意数量的值作为参数,GO语言内置了**...操作符,在函数的最后一个形参才能使用...**操作符,使用它必
- 发现问题一个作业报错,报错信息如下,从错误信息根本看不出为什么出错,手工运行作业又成功了。一时不清楚什么原因导致作业出错。MessageEx
- 说在前头最近在做毕设,题目是道路拥堵预测系统,学长建议我使用SVM算法进行预测,但是在此之前需要把Excel中的数据进行二次处理,原始数据不