Pytorch实现常用乘法算子TensorRT的示例代码
作者:极智视界 发布时间:2021-08-17 17:49:47
本文介绍一下 Pytorch 中常用乘法的 TensorRT 实现。
pytorch 用于训练,TensorRT 用于推理是很多 AI 应用开发的标配。大家往往更加熟悉 pytorch 的算子,而不太熟悉 TensorRT 的算子,这里拿比较常用的乘法运算在两种框架下的实现做一个对比,可能会有更加直观一些的认识。
1.乘法运算总览
先把 pytorch 中的一些常用的乘法运算进行一个总览:
torch.mm:用于两个矩阵 (不包括向量) 的乘法,如维度 (m, n) 的矩阵乘以维度 (n, p) 的矩阵;
torch.bmm:用于带 batch 的三维向量的乘法,如维度 (b, m, n) 的矩阵乘以维度 (b, n, p) 的矩阵;
torch.mul:用于同维度矩阵的逐像素点相乘,也即点乘,如维度 (m, n) 的矩阵点乘维度 (m, n) 的矩阵。该方法支持广播,也即支持矩阵和元素点乘;
torch.mv:用于矩阵和向量的乘法,矩阵在前,向量在后,如维度 (m, n) 的矩阵乘以维度为 (n) 的向量,输出维度为 (m);
torch.matmul:用于两个张量相乘,或矩阵与向量乘法,作用包含 torch.mm、torch.bmm、torch.mv;
@:作用相当于 torch.matmul;
*:作用相当于 torch.mul;
如上进行了一些具体罗列,可以归纳出,常用的乘法无非两种:矩阵乘 和 点乘,所以下面分这两类进行介绍。
2.乘法算子实现
2.1矩阵乘算子实现
先来看看矩阵乘法的 pytorch 的实现 (以下实现在终端):
>>> import torch
>>> # torch.mm
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99, 88)
>>> c = torch.mm(a, b)
>>> c.shape
torch.size([66, 88])
>>>
>>> # torch.bmm
>>> a = torch.randn(3, 66, 99)
>>> b = torch.randn(3, 99, 77)
>>> c = torch.bmm(a, b)
>>> c.shape
torch.size([3, 66, 77])
>>>
>>> # torch.mv
>>> a = torch.randn(66, 99)
>>> b = torch.randn(99)
>>> c = torch.mv(a, b)
>>> c.shape
torch.size([66])
>>>
>>> # torch.matmul
>>> a = torch.randn(32, 3, 66, 99)
>>> b = torch.randn(32, 3, 99, 55)
>>> c = torch.matmul(a, b)
>>> c.shape
torch.size([32, 3, 66, 55])
>>>
>>> # @
>>> d = a @ b
>>> d.shape
torch.size([32, 3, 66, 55])
来看 TensorRT 的实现,以上乘法都可使用 addMatrixMultiply
方法覆盖,对应 torch.matmul,先来看该方法的定义:
//!
//! \brief Add a MatrixMultiply layer to the network.
//!
//! \param input0 The first input tensor (commonly A).
//! \param op0 The operation to apply to input0.
//! \param input1 The second input tensor (commonly B).
//! \param op1 The operation to apply to input1.
//!
//! \see IMatrixMultiplyLayer
//!
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new matrix multiply layer, or nullptr if it could not be created.
//!
IMatrixMultiplyLayer* addMatrixMultiply(
ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept
{
return mImpl->addMatrixMultiply(input0, op0, input1, op1);
}
可以看到这个方法有四个传参,对应两个张量和其 operation
。来看这个算子在 TensorRT 中怎么添加:
// 构造张量 Tensor0
nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0);
// 构造张量 Tensor1
nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1);
// 添加矩阵乘法
nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type);
// 获取输出
matmulOutput = Matmul_layer->getOputput(0);
2.2点乘算子实现
再来看看点乘的 pytorch 的实现 (以下实现在终端):
>>> import torch
>>> # torch.mul
>>> a = torch.randn(66, 99)
>>> b = torch.randn(66, 99)
>>> c = torch.mul(a, b)
>>> c.shape
torch.size([66, 99])
>>> d = 0.125
>>> e = torch.mul(a, d)
>>> e.shape
torch.size([66, 99])
>>> # *
>>> f = a * b
>>> f.shape
torch.size([66, 99])
来看 TensorRT 的实现,以上乘法都可使用 addScale
方法覆盖,这在图像预处理中十分常用,先来看该方法的定义:
//!
//! \brief Add a Scale layer to the network.
//!
//! \param input The input tensor to the layer.
//! This tensor is required to have a minimum of 3 dimensions in implicit batch mode
//! and a minimum of 4 dimensions in explicit batch mode.
//! \param mode The scaling mode.
//! \param shift The shift value.
//! \param scale The scale value.
//! \param power The power value.
//!
//! If the weights are available, then the size of weights are dependent on the ScaleMode.
//! For ::kUNIFORM, the number of weights equals 1.
//! For ::kCHANNEL, the number of weights equals the channel dimension.
//! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input.
//!
//! \see addScaleNd
//! \see IScaleLayer
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new Scale layer, or nullptr if it could not be created.
//!
IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept
{
return mImpl->addScale(input, mode, shift, scale, power);
}
可以看到有三个模式:
kUNIFORM:weights 为一个值,对应张量乘一个元素;
kCHANNEL:weights 维度和输入张量通道的 c 维度对应,可以做一些以通道为基准的预处理;
kELEMENTWISE:weights 维度和输入张量的 c、h、w 对应,不考虑 batch,所以是输入的后三维;
再来看这个算子在 TensorRT 中怎么添加:
// 构造张量 input
nvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value);
// scalemode选择,kUNIFORM、kCHANNEL、kELEMENTWISE
scalemode = kUNIFORM;
// 构建 Weights 类型的 shift、scale、power,其中 volume 为元素数量
nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume };
// !! 注意这里还需要对 shift、scale、power 的 values 进行赋值,若只是乘法只需要对 scale 进行赋值就行
// 添加张量乘法
nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower);
// 获取输出
scaleOutput = Scale_layer->getOputput(0);
有一点你可能会比较疑惑,既然是点乘,那么输入只需要两个张量就可以了,为啥这里有 input、shift、scale、power 四个张量这么多呢。解释一下,input 不用说,就是输入张量,而 shift 表示加法参数、scale 表示乘法参数、power 表示指数参数,说到这里,你应该能发现,这个函数除了我们上面讲的点乘外还有其他更加丰富的运算功能。
来源:https://blog.csdn.net/weixin_42405819/article/details/125070931


猜你喜欢
- 问题:我想每日从数据库里导出一些数据,内容基本上都是一样的,只是时间不同,比如导出一张表wjzcreate table wjz(id int
- 例如“I am a boy”,逆序排放后为“boy a am I”所有单词之间用一个空格隔开,语句中除了英文字母外,不再包含其他字符。lis
- 现在向大家介绍mysql命令行下,从数据库的建立到表数据的删除全过程,希望对大家有所帮助。登陆mysql打cmd命令终端,如果已经添加了my
- 看看这个logo,有些像python的小蛇吧 。这次介绍的数据库codernityDB是纯python开发的。先前用了下tinyDB这个本地
- 在写移动端页面的时候,弹出遮罩层后,我们仍然可以滚动页面。vue中提供 @touchmove.prevent 方法可以完美解决这个问题<
- 本文实例为大家分享了python递归全排列的实现方法,供大家参考,具体内容如下排列:从n个元素中任取m个元素,并按照一定的顺序进行排列,称为
- 什么是Python元类?Python元类是与Python的面向对象编程概念相关的高级功能之一。它确定类的行为,并进一步帮助其修改。用Pyth
- Go反射的实现和 interface 和 unsafe.Pointer 密切相关。如果对golang的 interface 底层实现还没有理
- 最近有个朋友问我关于 Node.js 下使用 ECDSA 的问题,主要是使用 Node.js 的 Crypto 模块无法校验网络传输过来的签
- 懒加载是一种编程范式,它推迟加载操作,直到不得不这样做。通常,当操作开销很大,需要耗费大量时间或空间时,惰性求值是首选实现。例如,在 Pyt
- 使用matplotlib绘图时,在弹出的窗口中默认是有工具栏的,那么这些工具栏是如何定义的呢?工具栏的三种模式matplotlib的基础配置
- 一、前言CRITIC权重法是一种比熵权法和标准离差法更好的客观赋权法:它是基于评价指标的对比强度和指标之间的冲突性来综合衡量指标的客观权重。
- 1. 欧几里德算法欧几里德算法又称辗转相除法, 用于计算两个整数a, b的最大公约数。其计算原理依赖于下面的定理:定理: gcd(a, b)
- 装饰器这东西我看了一会儿才明白,在函数外面套了一层函数,感觉和java里的aop功能很像;写了2个装饰器日志的例子,第一个是不带参数的装饰器
- 我页面上有控制了只能输入数字的控件,禁止了输入法切换的,但是搜狗的云输入却控制不了,有没有办法在页面里面禁止它运行啊?发现这玩意儿真的很讨厌
- 折腾好半天的数据库连接,由于之前未安装 pip ,而且自己用的python 版本为3.6. 只能用 pymysql 来连接数据库,下边 简单
- 在生产环境上,一般会使用比较健壮的Web服务器,如Apache来运行我们的应用。如果我们的Web应用是采用Python开发,而且符合WSGI
- 今天先聊一聊在windows/mac iOS系统下用venv搭建python轻量级虚拟环境的问题。使用venv搭建的虚拟环境同virtual
- javascript 跨域问题以及解决办法什么是跨域问题?跨域这个问题是由于浏览器的同源策略引起的,请求的URL地址,必须与浏览器的URL是
- PyCharm是Python著名的Python集成开发环境(IDE)conda有Miniconda和Anaconda,前者应该是类似最小化版