Python api构建tensorrt加速模型的步骤详解
作者:居然c 发布时间:2022-03-01 17:21:19
一、创建TensorRT有以下几个步骤:
1.用TensorRT中network模块定义网络模型
2.调用TensorRT构建器从网络创建优化的运行时引擎
3.采用序列化和反序列化操作以便在运行时快速重建
4.将数据喂入engine中进行推理
二、Python api和C++ api在实现网络加速有什么区别?
个人看法
1.python比c++更容易读并且已经有很多包装很好的科学运算库(numpy,scikit等),
2.c++是接近硬件的语言,运行速度比python快很多很多,因为python是解释性语言c++是编译型语言
三、构建TensorRT加速模型
3.1 加载tensorRT
1.import tensorrt as trt
2.为tensorrt实现日志报错接口方便报错,在下面的代码我们只允许警告和错误消息才打印,TensorRT中包含一个简单的日志记录器Python绑定。
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
3.2 创建网络
简单来说就是用tensorrt的语言来构建模型,如果自己构建的话,主要是灵活但是工作量so large,一般还是用tensorrt parser来构建
(1)Caffe框架的模型可以直接用tensorrt内部解释器构建
(2)除caffe,TF模型以外其他框架,先转成ONNX通用格式,再用ONNX parser来解析
(3)TF可以直接通过tensorrt内部的UFF包来构建,但是tensorrt uff包中并支持所有算子
(4)自己将wts放入自己构建的模型中,工作量so large,但是很灵活。
3.3 ONNX构建engine
因为博主用的ONNXparser来构建engine的,下面就介绍以下ONNX构建engine,步骤如下:
(1)导入tensorrt
import tensorrt as trt
(2)创建builder,network和相应模型的解释器,这里是onnxparser
EXPLICIT_BATCH = 1 << (int)
(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
with builder = trt.Builder(TRT_LOGGER) as builder,
builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network,
TRT_LOGGER) as parser:
with open(model_path, 'rb') as model:
parser.parse(model.read())
这个代码的主要意思是,构建报错日志,创建build,network和onnxparser,然后用parser读取onnx权重文件。
3.3.1 builder介绍
builder功能之一是搜索cuda内核目录,找到最快的cuda以求获得最快的实现,因此有必要使用相同的GPU进行构建(相同的操作,算子进行融合,减少IO操作),engine就是在此基础上运行的,builder还可以控制网络以什么精度运行(FP32,FP16,INT8),还有两个特别重要的属性是最大批处理大小和最大工作空间大小。
builder.max_batch_size = max_batch_size
builder.max_workspace_size = 1 << 20
3.3.2序列化模型
序列化和反序列化模型的主要是因为network和定义创建engine很耗时,因此可以通过序列化一次并在推理时反序列化一次来避免每次应用程序重新运行时重新构建引擎。
note:序列化引擎不能跨平台或TensorRT版本移植。引擎是特定于它们所构建的GPU模型(除了平台和TensorRT版本)
代码如下:
#序列化模型到模型流
serialized_engine = engine.serialize()
#反序列化模型流去执行推理,反序列化需要创建一个运行时对象
with trt.Runtime(TRT_LOGGER) as runtime:
engine = runtime.deserialize_cuda_engine(serialized_engine)
#也可以将序列化模型write
with open(“sample.engine”, “wb”) as f:
f.write(engine.serialize())
#然后再读出来进行反序列化
with open(“sample.engine”, “rb”) as f, trt.Runtime(TRT_LOGGER) as runtime:
engine = runtime.deserialize_cuda_engine(f.read())
3.3.3执行推理过程
note:下面过程的前提是已经创建好了engine
# 为输入和输出分配一些主机和设备缓冲区:
#确定尺寸并创建页面锁定内存缓冲区
h_input = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(0)),dtype=np.float32)
h_output =cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(1)),dtype=np.float32)
#为输入和输出分配设备内存
d_input = cuda.mem_alloc(h_input.nbytes)
d_output = cuda.mem_alloc(h_output.nbytes)
#创建一个流,在其中复制输入/输出并运行推断
stream = cuda.Stream()
# 创建一些空间来存储中间激活值,因为engine保存了network定义和训练时的参数,这些都是构建的上下文执行的。
with engine.create_execution_context() as context:
# 输入数据传入GPU
cuda.memcpy_htod_async(d_input, h_input, stream)
# 执行推理.
context.execute_async(bindings=[int(d_input), int(d_output)],
stream_handle=stream.handle)
# 将推理后的预测结果从GPU上返回.
cuda.memcpy_dtoh_async(h_output, d_output, stream)
# 同步流
stream.synchronize()
# 返回主机输出
return h_output
note:一个engine可以有多个执行上下文,允许一组权值用于多个重叠推理任务。例如,可以使用一个引擎和一个上下文在并行CUDA流中处理图像。每个上下文将在与引擎相同的GPU上创建。
来源:https://blog.csdn.net/weixin_45074568/article/details/120000583


猜你喜欢
- 前言:线程是指进程内的一个执行单元,也是进程内的可调度实体.与进程的区别:(1) 地址空间:进程内的一个执行单元;进程至少有一个线程;它们共
- 目录range函数的使用第一种创建方式第二种创建方式第三种创建方式判断指定的数有没有在当前序列中循环结构总结range函数的使用作为循环遍历
- JavaScript: <script type="text/javascript"> var level1
- 迷宫生成1.随机PRIM思路:先让迷宫中全都是墙,不断从列表(最初只含有一个启始单元格)中选取一个单元格标记为通路,将其周围(上下左右)未访
- 1.安装好JDK下载并安装好jdk-12.0.1_windows-x64_bin.exe,配置环境变量:新建系统变量JAVA_HOME,值为
- XML是一个精简的SGML,它将SGML的丰富功能与HTML的易用性结合到Web的应用中。XML保留了SGML的可扩展功能,这使XML从根本
- 本文涉及:Windows操作系统,Python,PyQt5,Qt Designer,PyCharm一、自适应原理 &
- 本篇博文主要介绍在Python3中如何构造含参构造函数样例如下:class MyOdlHttp: username = '
- 下面展示了图像的加密和解密过程(左边是输入图像,中间是加密后的结果,右边是解密后的图像):1、加密算法要求(1)加密算法必须是可逆的,拥有配
- 命令行下能正常登陆MYSQL,navicat能正常连接MySQL,但是IDEA连接不上MySQL,emmm,什么情况。。。看了一下错误提示:
- sql语句查询数据库中的表名/列名/主键/自动增长值 ----查询数据库中用户创建的表 ----jsj01 为数据库名 select nam
- 本文实例讲述了Python实现的统计文章单词次数功能。分享给大家供大家参考,具体如下:题目是这样的:你有一个目录,放了你一个月的日记,都是
- 切片:方便截取list、tuple、字符串部分索引的内容正序切片语法:dlist = doList[0:3]表示,从索引0开始取,直到索引3
- 本文实例讲述了python实现根据主机名字获得所有ip地址的方法。分享给大家供大家参考。具体实现方法如下:# -*- coding: utf
- 本文实例讲述了PHP函数shuffle()取数组若干个随机元素的方法。分享给大家供大家参考,具体如下:有时候我们需要取数组中若干个随机元素(
- 问题:如何把具有相同字段的记录删除,只留下一条。 例如:表test里有id,name字段,如果有name相同的记录只留下一条,
- 目录1.部分转义字符2.slice 切片读取字符串3.调用split()方法分割字符串 ASCII字母4.与字母大小写有关方法5.搜索查找字
- v-model指令 所谓的“指令”其实就是扩展了HTML标签功能(属性)。先来一个组件,不用vue-model,正常父子通信<!--
- (5)SELECT (5-2) DISTINCT(5-3)TOP(<top_specification>)(5-1) <s
- 阅读上一篇:css基础教程属性篇 本篇主要介绍css对边框(border)的属性控制和链接(link)的伪类选择器.边框(border):