pytorch 使用半精度模型部署的操作
作者:treeswolf 发布时间:2022-04-17 21:33:36
背景
pytorch作为深度学习的计算框架正得到越来越多的应用.
我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上.
在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受.
但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍
具体方法
在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数转为16位浮点,所以在模型加载后,调用一下该方法即可达到模型切换的目的.接下来只需要在推断时把input的tensor切换为16位浮点即可
另外还有一个小的trick,在推理过程中模型输出的tensor自然会成为16位浮点,如果需要新创建tensor,最好调用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自动继承已有tensor的类型,这样就不需要到处增加代码判断是使用16位还是32位了,只需要针对input tensor切换.
补充:pytorch 使用amp.autocast半精度加速训练
准备工作
pytorch 1.6+
如何使用autocast?
根据官方提供的方法,
答案就是autocast + GradScaler。
如何在PyTorch中使用自动混合精度?
答案:autocast + GradScaler。
1.autocast
正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的
from torch.cuda.amp import autocast as autocast
# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
for input, target in data:
optimizer.zero_grad()
# 前向过程(model + loss)开启 autocast
with autocast():
output = model(input)
loss = loss_fn(output, target)
# 反向传播在autocast上下文之外
loss.backward()
optimizer.step()
2.GradScaler
GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。
因此PyTorch中经典的AMP使用方式如下:
from torch.cuda.amp import autocast as autocast
# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# 前向过程(model + loss)开启 autocast
with autocast():
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
3.nn.DataParallel
单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。
要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, input_data_c1):
with autocast():
# code
return
只要把forward里面的代码用autocast代码块方式运行就好啦!
自动进行autocast的操作
如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:
1、matmul
2、addbmm
3、addmm
4、addmv
5、addr
6、baddbmm
7、bmm
8、chain_matmul
9、conv1d
10、conv2d
11、conv3d
12、conv_transpose1d
13、conv_transpose2d
14、conv_transpose3d
15、linear
16、matmul
17、mm
18、mv
19、prelu
那么只有这些操作才能半精度吗?不是。其他操作比如rnn也可以进行半精度运行,但是需要自己手动,暂时没有提供自动的转换。
来源:https://blog.csdn.net/treeswolf/article/details/105748209


猜你喜欢
- 新下载了mysql,口令为空,如何修改root口令:首先登陆mysqluse mysql;update user set password=
- 索引 经常要查询的语句,则给它建一个索引 表连接 select T_Oders as o join T_Customers as C on
- Function closeUBB(strContent) '*************************
- 在统计学和数据分析领域中,我们常常需要比较两个或多个样本数据之间的差异。而带置信区间的折线图则是一种直观且常用的展示数据差异的方式。在这篇文
- 该语句的作用是:启用或禁用错误处理程序。一般用法如下:On Error Resume NextOn Error GoTo 0如果在您的代码中
- TNS是Oracle Net的一部分,是专门用来管理和配置Oracle数据库和客户端连接的一个工具,在大多数情况下客户端和数据库要通讯,就必
- 本文实例讲述了django框架模型层功能、组成与用法。分享给大家供大家参考,具体如下:Django models是Django框架自定义的一
- 对设计“以人为本”和“绿色设计”两个观点的反思——兼与设计界同仁商榷Reflection of Two Views: “People-ori
- 1.字典文本特征提取 DictVectorizer()1.1 one-hot编码创建一个字典,观察如下数据形式的变化:import pand
- 最近遇到一个问题,就是获取表单中的日期往后台通过json方式传的时候,遇到Date.parse(str)函数在ff下报错: NAN 找了些资
- 一直用的TensorFlow(keras)来完成一些工作,因许多论文中的模型用pytorch来实现,代码看不懂实在是不太应该。正好趁此假期,
- 如果一张表的数据达到上百万条,用游标的方法来删除简直是个噩梦,因为它会执行相当长的一段时间…… 开发人员的噩梦——删
- Python用input输入列表的方法使用input输入数据时,使用逗号隔开列表的每一项,再使用ast.literal_eval()方法转成
- 代码如下所示:import osimport requestsimport datetimefrom Crypto.Cipher impor
- 利用python,可以实现填充网页表单,从而自动登录WEB门户。(注意:以下内容只针对python3)环境准备:(1)安装python (2
- 一,啥是Block Formatting Context当涉及到可视化布局的时候,Block Formatting Context提供了一个
- 目录前言全局参数持久化写在最后总结参考文档:前言自从 2018 年发布第一版 MySQL 8.0.11 正式版至今,MySQL 版本已经更新
- MySQL的默认的调度策略可用总结如下:· 写入操作优先于读取操作。· 对某张数据表的写入操作某一时刻只能发生一次,写入请求按照它们到达的次
- 在利用Python进行系统管理的时候,特别是同时操作多个文件目录,或者远程控制多台主机,并行操作可以节约大量的时间。当 * 作对象数目不大时,
- 1、冒泡排序它反复访问要排序的元素列,并依次比较两个相邻的元素。如果顺序(如从大到小)错了,就交换它们。访问元素的工作是反复进行,直到没有相