PyTorch模型转换为ONNX格式实现过程详解
作者:实力 发布时间:2022-03-18 00:54:18
1. 安装依赖
将PyTorch模型转换为ONNX格式可以使它在其他框架中使用,如TensorFlow、Caffe2和MXNet
首先安装以下必要组件:
Pytorch
ONNX
ONNX Runtime(可选)
建议使用conda
环境,运行以下命令来创建一个新的环境并激活它:
conda create -n onnx python=3.8
conda activate onnx
接下来使用以下命令安装PyTorch和ONNX:
conda install pytorch torchvision torchaudio -c pytorch
pip install onnx
可选地,可以安装ONNX Runtime以验证转换工作的正确性:
pip install onnxruntime
2. 准备模型
将需要转换的模型导出为PyTorch模型的.pth
文件。使用PyTorch内置的函数加载它,然后调用eval()方法以保证close状态:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.onnx
import torchvision.transforms as transforms
import torchvision.datasets as datasets
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
PATH = './model.pth'
torch.save(net.state_dict(), PATH)
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()
3. 调整输入和输出节点
现在需要定义输入和输出节点,这些节点由导出的模型中的张量名称表示。将使用PyTorch内置的函数torch.onnx.export()
来将模型转换为ONNX格式。下面的代码片段说明如何找到输入和输出节点,然后传递给该函数:
input_names = ["input"]
output_names = ["output"]
dummy_input = torch.randn(batch_size, input_channel_size, input_height, input_width)
# Export the model
torch.onnx.export(model, dummy_input, "model.onnx", verbose=True,
input_names=input_names, output_names=output_names)
4. 运行转换程序
运行上述程序时可能遇到错误信息,其中包括一些与节点的名称和形状相关的警告,甚至还有Python版本、库、路径等信息。在处理完这些错误后,就可以转换PyTorch模型并立即获得ONNX模型了。输出ONNX模型的文件名是model.onnx
。
5. 使用后端框架测试ONNX模型
现在,使用ONNX模型检查一下是否成功地将其从PyTorch导出到ONNX,可以使用TensorFlow或Caffe2进行验证。以下是一个简单的示例,演示如何使用TensorFlow来加载和运行该模型:
import onnxruntime as rt
import numpy as np
sess = rt.InferenceSession('model.onnx')
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
np.random.seed(123)
X = np.random.randn(batch_size, input_channel_size, input_height, input_width).astype(np.float32)
res = sess.run([output_name], {input_name: X})
这应该可以顺利地运行,并且输出与原始PyTorch模型具有相同的形状(和数值)。
6. 核对结果
最好的方法是比较PyTorch模型与ONNX模型在不同框架中推理的结果。如果结果完全匹配,则几乎可以肯定地说PyTorch到ONNX转换已经成功。以下是通过PyTorch和ONNX检查模型推理结果的一个小程序:
# Test the model with PyTorch
model.eval()
with torch.no_grad():
Y = model(torch.from_numpy(X)).numpy()
# Test the ONNX model with ONNX Runtime
sess = rt.InferenceSession('model.onnx')
res = sess.run(None, {input_name: X})[0]
# Compare the results
np.testing.assert_allclose(Y, res, rtol=1e-6, atol=1e-6)
来源:https://juejin.cn/post/7221777883160985661
猜你喜欢
- 本文实例讲述了python实现生成Word、docx文件的方法。分享给大家供大家参考,具体如下:http://python-docx.rea
- 1、Introduction之前写过2篇文章,分别是:Mysql主从同步的原理 Myql主从同步实战 基于此,我们再实
- OCR of Hand-written Data using kNNOCR of Hand-written Digits我们的目标是构建一个
- 写了一个小巧的jquery拾色工具,代码简单得不得了,只有这么几行:(function($){ $.fn.pickColor=fu
- 文件内容:excel内容:代码:import xlrdimport jsonimport operatordef read_xlsx(fil
- 在使用AJAX开发网站时,经常有朋友遇到乱码的问题,而且一下子难以找到解决方法。其实解决AJAX中文乱码问题很简单。1、服务端程序:<
- PHP get_html_translation_table() 函数实例输出 htmlspecialchars 函数使用的翻译表:<
- 今天在日常维护一个网站时,发现该网站的留言程序没有经过严格的验证过滤,导致了将近十万条垃圾数据。而其中又不乏重要信息,需要清理数据,以及增加
- 之前刚开始做爬虫的时候遇到过登录验证码问题,看过很多帖子都没有解决我的问题,发现大多数帖子都是治标不治本,于是想分享一下自己的解决方案。本次
- 例如:我们在百度中搜索 词典网,则网址后面的参数就是http://www.baidu.com/s?cl=3&wd=%B4%CA%B5
- 我们知道numpy.ndarray.reshape()是用来改变numpy数组的形状的,但是它的参数会有一些特殊的用法,这里我们进一步说明一
- 前言损失函数在机器学习中用于表示预测值与真实值之间的差距。一般而言,大多数机器学习模型都会通过一定的优化器来减小损失函数从而达到优化预测机器
- 如果说哪个开源程序不需要介绍大家就认识,那一定是phpMyAdmin,一款流行的MySQL数据库的Web管理界面。MySQL是全球最流行的W
- 本文实例为大家分享了python多线程http压力测试的具体代码,供大家参考,具体内容如下#coding=utf-8import sysim
- 图片提取为了方便技术展示,我们选取素材为演员杨紫的一段演讲视频,用例仅为技术交流演示使用,不针对任何指定人。为达到我们AI换脸的目的,我们首
- 和网友们讨论了数组取交集的方法,下面是两个实现arr1=["1","5","6"
- 背景今天有人问我 “为什么数据库中有人推荐使用 int 类型来保存 IP 地址?”。现在(2020年)来看这个东西已经有点过时了,一方面是磁
- 本文实例为大家分享了python生成圆形图片的具体代码,供大家参考,具体内容如下# -*- coding: utf-8 -*- "
- 原文地址:30 Days of Mootools 1.2 Tutorials - Day 8 - Input Filtering Part
- 前言记录CS2000设备使用串口连接以及相关控制。CS2000是一台分光辐射亮度计,也就是可以测量光源的亮度。详细的规格网址参考CS2000