Python api构建tensorrt加速模型的步骤详解
作者:居然c 发布时间:2022-03-01 17:21:19
一、创建TensorRT有以下几个步骤:
1.用TensorRT中network模块定义网络模型
2.调用TensorRT构建器从网络创建优化的运行时引擎
3.采用序列化和反序列化操作以便在运行时快速重建
4.将数据喂入engine中进行推理
二、Python api和C++ api在实现网络加速有什么区别?
个人看法
1.python比c++更容易读并且已经有很多包装很好的科学运算库(numpy,scikit等),
2.c++是接近硬件的语言,运行速度比python快很多很多,因为python是解释性语言c++是编译型语言
三、构建TensorRT加速模型
3.1 加载tensorRT
1.import tensorrt as trt
2.为tensorrt实现日志报错接口方便报错,在下面的代码我们只允许警告和错误消息才打印,TensorRT中包含一个简单的日志记录器Python绑定。
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
3.2 创建网络
简单来说就是用tensorrt的语言来构建模型,如果自己构建的话,主要是灵活但是工作量so large,一般还是用tensorrt parser来构建
(1)Caffe框架的模型可以直接用tensorrt内部解释器构建
(2)除caffe,TF模型以外其他框架,先转成ONNX通用格式,再用ONNX parser来解析
(3)TF可以直接通过tensorrt内部的UFF包来构建,但是tensorrt uff包中并支持所有算子
(4)自己将wts放入自己构建的模型中,工作量so large,但是很灵活。
3.3 ONNX构建engine
因为博主用的ONNXparser来构建engine的,下面就介绍以下ONNX构建engine,步骤如下:
(1)导入tensorrt
import tensorrt as trt
(2)创建builder,network和相应模型的解释器,这里是onnxparser
EXPLICIT_BATCH = 1 << (int)
(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
with builder = trt.Builder(TRT_LOGGER) as builder,
builder.create_network(EXPLICIT_BATCH) as network, trt.OnnxParser(network,
TRT_LOGGER) as parser:
with open(model_path, 'rb') as model:
parser.parse(model.read())
这个代码的主要意思是,构建报错日志,创建build,network和onnxparser,然后用parser读取onnx权重文件。
3.3.1 builder介绍
builder功能之一是搜索cuda内核目录,找到最快的cuda以求获得最快的实现,因此有必要使用相同的GPU进行构建(相同的操作,算子进行融合,减少IO操作),engine就是在此基础上运行的,builder还可以控制网络以什么精度运行(FP32,FP16,INT8),还有两个特别重要的属性是最大批处理大小和最大工作空间大小。
builder.max_batch_size = max_batch_size
builder.max_workspace_size = 1 << 20
3.3.2序列化模型
序列化和反序列化模型的主要是因为network和定义创建engine很耗时,因此可以通过序列化一次并在推理时反序列化一次来避免每次应用程序重新运行时重新构建引擎。
note:序列化引擎不能跨平台或TensorRT版本移植。引擎是特定于它们所构建的GPU模型(除了平台和TensorRT版本)
代码如下:
#序列化模型到模型流
serialized_engine = engine.serialize()
#反序列化模型流去执行推理,反序列化需要创建一个运行时对象
with trt.Runtime(TRT_LOGGER) as runtime:
engine = runtime.deserialize_cuda_engine(serialized_engine)
#也可以将序列化模型write
with open(“sample.engine”, “wb”) as f:
f.write(engine.serialize())
#然后再读出来进行反序列化
with open(“sample.engine”, “rb”) as f, trt.Runtime(TRT_LOGGER) as runtime:
engine = runtime.deserialize_cuda_engine(f.read())
3.3.3执行推理过程
note:下面过程的前提是已经创建好了engine
# 为输入和输出分配一些主机和设备缓冲区:
#确定尺寸并创建页面锁定内存缓冲区
h_input = cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(0)),dtype=np.float32)
h_output =cuda.pagelocked_empty(trt.volume(engine.get_binding_shape(1)),dtype=np.float32)
#为输入和输出分配设备内存
d_input = cuda.mem_alloc(h_input.nbytes)
d_output = cuda.mem_alloc(h_output.nbytes)
#创建一个流,在其中复制输入/输出并运行推断
stream = cuda.Stream()
# 创建一些空间来存储中间激活值,因为engine保存了network定义和训练时的参数,这些都是构建的上下文执行的。
with engine.create_execution_context() as context:
# 输入数据传入GPU
cuda.memcpy_htod_async(d_input, h_input, stream)
# 执行推理.
context.execute_async(bindings=[int(d_input), int(d_output)],
stream_handle=stream.handle)
# 将推理后的预测结果从GPU上返回.
cuda.memcpy_dtoh_async(h_output, d_output, stream)
# 同步流
stream.synchronize()
# 返回主机输出
return h_output
note:一个engine可以有多个执行上下文,允许一组权值用于多个重叠推理任务。例如,可以使用一个引擎和一个上下文在并行CUDA流中处理图像。每个上下文将在与引擎相同的GPU上创建。
来源:https://blog.csdn.net/weixin_45074568/article/details/120000583
猜你喜欢
- 近日,2018年最具就业前景的7大编程语言排行榜出炉了。这次的编程语言排行榜是由CodingDojo(编码道场)发布。在此次的最有“钱”途的
- 本教程使用python来生成随机漫步数据,再使用matplotlib将数据呈现出来开发环境操作系统: Windows10 IDE: Pych
- 问题:windows环境下新建或编辑文本文件,保存时会在头部加上BOM。使用ftp上传到linux下,在执行时第一行即报错。以下方法可以去除
- 一、什么是requirements.txt文件及作用requirements.txt 文件是项目的依赖包及其对应版本号的信息列表,即记载你这
- 这十则CSS技巧汇编于网络,作为老手已经司空见惯了,也没有什么新意,但温故而知新,或许阅读一遍也有一定的启发,本文主要面对CSS新手朋友,有
- 这篇文章主要介绍了如何使用Python抓取网页tag操作,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的
- 前言: 这篇文章主要介绍RMAN的常用方法,其中包含了作者一些自己的经验,里面的实验也基本全在WIN 2K和ORACLE 8.1.6环境下测
- 花了两周时间,利用工作间隙时间,开发了一个基于Django的项目任务管理Web应用。项目计划的实时动态,可以方便地被项目成员查看(^_^又重
- 说完了理论,我们来做点实事。这篇文章将介绍使用 Javascript 实现的动画组件。下面记录下当时编写这个组件的考虑的些问题,对技术细节感
- 1.API接口:hello world 案例from flask import Flaskfrom flask_restful import
- 什么是爬虫?网络爬虫(又被称为网页蜘蛛,网络机器人,在FOAF社区中间,更经常的称为网页追逐者),是一种按照一定的规则,自动地抓取万维网信息
- 所谓类属性的延迟计算就是将类的属性定义成一个property,只在访问的时候才会计算,而且一旦被访问后,结果将会被缓存起来,不用每次都计算。
- 【原文地址】Tip/Trick: Url Rewriting with ASP.NET 【原文发表日期】 Monday, February
- 数据可以帮助我们描述这个世界、阐释自己的想法和展示自己的成果,但如果只有单调乏味的文本和数字,我们却往往能难抓住观众的眼球。而很多时候,一张
- 设置MySQL数据同步(单向&双向)由于公司的业务需求,需要网通和电信的数据同步,就做了个MySQL的双向同步,记下过程,以后用得到
- 本文实例讲述了Python实现的KMeans聚类算法。分享给大家供大家参考,具体如下:菜鸟一枚,编程初学者,最近想使用Python3实现几个
- 一维插值插值不同于拟合。插值函数经过样本点,拟合函数一般基于最小二乘法尽量靠近所有样本点穿过。常见插值方法有拉格朗日插值法、分段插值法、样条
- 问题描述:使用 SQL 2005 w/ SP2 的汇出汇入精灵将数据从 Access 汇入到 SQL2005 发生了错误,但使用在SQL 2
- 一、先来看看Python星空图代码绘制成品1 两个人的星空星空下,欲执子之手,相倚长青树。看皎洁月色,闻乡间气息,赏佳人芳心。2 明月相伴的
- 本文实例讲述了pytorch制作自己的LMDB数据操作。分享给大家供大家参考,具体如下:前言记录下pytorch里如何使用lmdb的code