Pytorch通过保存为ONNX模型转TensorRT5的实现
作者:小关学长 发布时间:2023-10-22 13:45:27
标签:Pytorch,ONNX,TensorRT5
1 Pytorch以ONNX方式保存模型
def saveONNX(model, filepath):
'''
保存ONNX模型
:param model: 神经网络模型
:param filepath: 文件保存路径
'''
# 神经网络输入数据类型
dummy_input = torch.randn(self.config.BATCH_SIZE, 1, 28, 28, device='cuda')
torch.onnx.export(model, dummy_input, filepath, verbose=True)
2 利用TensorRT5中ONNX解析器构建Engine
def ONNX_build_engine(onnx_file_path):
'''
通过加载onnx文件,构建engine
:param onnx_file_path: onnx文件路径
:return: engine
'''
# 打印日志
G_LOGGER = trt.Logger(trt.Logger.WARNING)
with trt.Builder(G_LOGGER) as builder, builder.create_network() as network, trt.OnnxParser(network, G_LOGGER) as parser:
builder.max_batch_size = 100
builder.max_workspace_size = 1 << 20
print('Loading ONNX file from path {}...'.format(onnx_file_path))
with open(onnx_file_path, 'rb') as model:
print('Beginning ONNX file parsing')
parser.parse(model.read())
print('Completed parsing of ONNX file')
print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))
engine = builder.build_cuda_engine(network)
print("Completed creating Engine")
# 保存计划文件
# with open(engine_file_path, "wb") as f:
# f.write(engine.serialize())
return engine
3 构建TensorRT运行引擎进行预测
def loadONNX2TensorRT(filepath):
'''
通过onnx文件,构建TensorRT运行引擎
:param filepath: onnx文件路径
'''
# 计算开始时间
Start = time()
engine = self.ONNX_build_engine(filepath)
# 读取测试集
datas = DataLoaders()
test_loader = datas.testDataLoader()
img, target = next(iter(test_loader))
img = img.numpy()
target = target.numpy()
img = img.ravel()
context = engine.create_execution_context()
output = np.empty((100, 10), dtype=np.float32)
# 分配内存
d_input = cuda.mem_alloc(1 * img.size * img.dtype.itemsize)
d_output = cuda.mem_alloc(1 * output.size * output.dtype.itemsize)
bindings = [int(d_input), int(d_output)]
# pycuda操作缓冲区
stream = cuda.Stream()
# 将输入数据放入device
cuda.memcpy_htod_async(d_input, img, stream)
# 执行模型
context.execute_async(100, bindings, stream.handle, None)
# 将预测结果从从缓冲区取出
cuda.memcpy_dtoh_async(output, d_output, stream)
# 线程同步
stream.synchronize()
print("Test Case: " + str(target))
print("Prediction: " + str(np.argmax(output, axis=1)))
print("tensorrt time:", time() - Start)
del context
del engine
补充知识:Pytorch/Caffe可以先转换为ONNX,再转换为TensorRT
近来工作,试图把Pytorch用TensorRT运行。折腾了半天,没有完成。github中的转换代码,只能处理pytorch 0.2.0的功能(也明确表示不维护了)。和同事一起处理了很多例外,还是没有通过。吾以为,实际上即使勉强过了,能不能跑也是问题。
后来有高手建议,先转换为ONNX,再转换为TensorRT。这个思路基本可行。
是不是这样就万事大吉?当然不是,还是有严重问题要解决的。这只是个思路。
来源:https://blog.csdn.net/qq_38003892/article/details/89314108
0
投稿
猜你喜欢
- Broadcast广播是numpy对不同形状(shape)的数组进行数值计算的方式,对数组的算术运算通常在相应的元素上进行。如果两个数组a和
- 调用JSON.stringify将对象转为对应的字符串时,如果包含时间对象,时间对象会被转换为国家标准时间(ISO),而不是当前国家区域的时
- 正在看的ORACLE教程是:ORACLE8的分区管理。摘要:本篇文章介绍了ORACLE数据库的新特性—分区管理,并用例子说明使用方法。 关键
- 首先来描述下环境,在机器上有很多个JAVA程序,我们在每个JAVA程序里都配置了一个启动|停止|重启的脚本举个例子:我们现在要同时运行这些脚
- 前言博主学习python有个几年了,对于python的掌握越来越深,很多时候,希望自己能掌握python越来越多的知识,但是,也意识很多时候
- 1983年1月19日,苹果公司发布乔布斯领导研制的新一代电脑Lisa,当时Lisa电脑的设计人员就认为,必须将立即执行的命令和需要用户附加输
- 首先说明一下,在python中是没有&&及||这两个运算符的,取而代之的是英文and和or。其他运算符没有变动。接着重点要说
- 首先你要确定错误的原因: 让IE显示详细的出错信息: 菜单--工具--Internet选项--高级--显示友好的HTTP错误信息,去掉这个选
- 作用: 构建一些简单的SQL语句,结合在提交表单时使用,可以较方便<%@LANGUAGE="VBSCRIPT&
- 可匹配结构:今天~前天, 几天前, 分钟秒前等 | 2017-1-4 12:10 | 2017/1/4 12:10 | 2018年4月2日
- 写过PHP+MySQL的程序员都知道有时间差,UNIX时间戳和格式化日期是我们常打交道的两个时间表示形式,Unix时间戳存储、处理方便,但是
- 方法1:pythonw xxx.py方法2:将.py改成.pyw (这个其实就是使用脚本解析程序pythonw.exe)跟 python.e
- 在刚进公司的时候,要写一个需求,使用django的admin站点管理,实现一个二级联动的功能,因为要用到django自带的页面,因为不是自定
- 01、介绍在编程语言中,字符串是一种重要的数据结构。在 Golang 语言中,因为字符串只能被访问,不能被修改,所以,如果我们在 Golan
- 看一个例子d={'test':1}d_test=dd_test['test']=2print d如果你在命令
- python函数的参数类型和返回类型默认为int。如果需要传递一个float值给dll,那么需要指定参数的类型。如果需要返回一个flaot值
- 需求:我要查询百度域名的到期时间或者开始时间思路分析:如果在linux系统中直接使用下面命令即可:echo | openssl s_clie
- 所谓定时器,是指间隔特定时间执行特定任务的机制。几乎所有的编程语言,都有定时器的实现。比如,Java有util.Timer和util.Tim
- BP神经网络是最简单的神经网络模型了,三层能够模拟非线性函数效果。难点:如何确定初始化参数?如何确定隐含层节点数量?迭代多少次?如何更快收敛
- UNIX时间戳转换为日期用函数FROM_UNIXTIME()select FROM_UNIXTIME(1156219870);日期