Pytorch通过保存为ONNX模型转TensorRT5的实现
作者:小关学长 发布时间:2023-10-22 13:45:27
标签:Pytorch,ONNX,TensorRT5
1 Pytorch以ONNX方式保存模型
def saveONNX(model, filepath):
'''
保存ONNX模型
:param model: 神经网络模型
:param filepath: 文件保存路径
'''
# 神经网络输入数据类型
dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
torch.onnx.export(model, dummy_input, filepath, verbose=True)
2 利用TensorRT5中ONNX解析器构建Engine
def ONNX_build_engine(onnx_file_path):
'''
通过加载onnx文件,构建engine
:param onnx_file_path: onnx文件路径
:return: engine
'''
# 打印日志
G_LOGGER = trt.Logger(trt.Logger.WARNING)
with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
builder.max_batch_size = 100
builder.max_workspace_size = 1 << 20
print('Loading ONNX file from path {}...'.format(onnx_file_path))
with open(onnx_file_path, 'rb') as model:
print('Beginning ONNX file parsing')
parser.parse(model.read())
print('Completed parsing of ONNX file')
print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
engine = builder.build_cuda_engine(network)
print("Completed creating Engine")
# 保存计划文件
# with open(engine_file_path, "wb") as f:
# f.write(engine.serialize())
return engine
3 构建TensorRT运行引擎进行预测
def loadONNX2TensorRT(filepath):
'''
通过onnx文件,构建TensorRT运行引擎
:param filepath: onnx文件路径
'''
# 计算开始时间
Start = time()
engine = self.ONNX_build_engine(filepath)
# 读取测试集
datas = DataLoaders()
test_loader = datas.testDataLoader()
img, target = next(iter(test_loader))
img = img.numpy()
target = target.numpy()
img = img.ravel()
context = engine.create_execution_context()
output = np.empty((100, 10), dtype=np.float32)
# 分配内存
d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
bindings = [int(d_input), int(d_output)]
# pycuda操作缓冲区
stream = cuda.Stream()
# 将输入数据放入device
cuda.memcpy_htod_async(d_input, img, stream)
# 执行模型
context.execute_async(100, bindings, stream.handle, None)
# 将预测结果从从缓冲区取出
cuda.memcpy_dtoh_async(output, d_output, stream)
# 线程同步
stream.synchronize()
print("Test Case: " + str(target))
print("Prediction: " + str(np.argmax(output, axis=1)))
print("tensorrt time:", time() - Start)
del context
del engine
补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT
近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。
后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。
是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。
来源:https://blog.csdn.net/qq_38003892/article/details/89314108


猜你喜欢
- 一.页面样式二.数据库 三.前端页面代码 <template> <el-tree :props="pr
- 本文实例讲述了JS模拟简易滚动条效果的方法。分享给大家供大家参考,具体如下:使用Js模拟滚动条。简易模式,类似手机上常见的滚动条。效果如下:
- 单线程执行python的内置模块提供了两个内置模块:thread和threading,thread是源生模块,threading是扩展模块,
- 死锁的原理非常简单,用一句话就可以描述完。就是当多线程访问多个锁的时候,不同的锁被不同的线程持有,它们都在等待其他线程释放出锁来,于是便陷入
- 似乎讨论分页的人很少,难道大家都沉迷于limit m,n?在有索引的情况下,limit m,n速度足够,可是在复杂条件搜索时,where s
- 代码代码很简单,主要是为了熟悉Selenium这个库的函数,为后续的短信轰炸做个铺垫from selenium import webdriv
- 开发环境解释器版本: python 3.8代码编辑器: pycharm 2021.2第三方模块requests: pip install r
- ConfigParser模块在python中用来读取配置文件,配置文件的格式跟windows下的ini配置文件相似,可以包含一个或多个节(s
- 1、启动SQL Server Management Studio,以Windows身份验证方式登录。2、在对象资源管理器窗口中,右键单击服务
- 本节内容:1.前言2.相关概念3.Python中的默认编码4.Python2与Python3中对字符串的支持5.字符编码转换一、前言Pyth
- 什么是集合1.集合是一个可变容器2.集合内的数据对象都是唯一的(不能重复)3.集合是无序的存储结构,集合内的数据没有先后关系4.集合是可迭代
- 根据国务院文件,5.19-5.21为全国哀悼日,在此期间,全国和各驻外机构下半旗志哀,停止公共娱乐活动,外交部和我国驻外使领馆设立吊唁簿。5
- 1、pyecharts绘制时间轮播柱形图from random import randintfrom pyecharts import op
- node有一个模块叫n(这名字可够短的。。。),是专门用来管理node.js的版本的。首先安装n模块:npm install -g n第二步
- 1.过滤器的使用1.过滤器和测试器在Python中,如果需要对某个变量进行处理,我们可以通过函数来实现。在模板中,我们则是通过过滤器来实现的
- 本文将展示一个开源JavaScript库,该脚本库给AJAX应用程序带来了书签和后退按钮支持。在学习完这个教程后,开发人员将能够获得对一个A
- 又一个js加密工具:js混淆,完整源代码如下,有点长呵呵:<HTML><HEAD><TITLE>Cunf
- ThinkPHP支持多种php模板引擎,可以根据个人需要加以配置。下面我们以Smarty模板引擎为例,给大家说说具体的操作流程!首先去Sma
- 1、mysql 导出文件:SELECT `pe2e_user_to_company`.company_name, `pe2e_user_to
- 1,判断图像清晰度,明暗,原理,Laplacian算法。偏暗的图片,二阶导数小,区域变化小;偏亮的图片,二阶导数大,区域变化快。import