pytorch 使用半精度模型部署的操作
作者:treeswolf 发布时间:2022-04-17 21:33:36
背景
pytorch作为深度学习的计算框架正得到越来越多的应用.
我们除了在模型训练阶段应用外,最近也把pytorch应用在了部署上.
在部署时,为了减少计算量,可以考虑使用16位浮点模型,而训练时涉及到梯度计算,需要使用32位浮点,这种精度的不一致经过测试,模型性能下降有限,可以接受.
但是推断时计算量可以降低一半,同等计算资源下,并发度可提升近一倍
具体方法
在pytorch中,一般模型定义都继承torch.nn.Moudle,torch.nn.Module基类的half()方法会把所有参数转为16位浮点,所以在模型加载后,调用一下该方法即可达到模型切换的目的.接下来只需要在推断时把input的tensor切换为16位浮点即可
另外还有一个小的trick,在推理过程中模型输出的tensor自然会成为16位浮点,如果需要新创建tensor,最好调用已有tensor的new_zeros,new_full等方法而不是torch.zeros和torch.full,前者可以自动继承已有tensor的类型,这样就不需要到处增加代码判断是使用16位还是32位了,只需要针对input tensor切换.
补充:pytorch 使用amp.autocast半精度加速训练
准备工作
pytorch 1.6+
如何使用autocast?
根据官方提供的方法,
答案就是autocast + GradScaler。
如何在PyTorch中使用自动混合精度?
答案:autocast + GradScaler。
1.autocast
正如前文所说,需要使用torch.cuda.amp模块中的autocast 类。使用也是非常简单的
from torch.cuda.amp import autocast as autocast
# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
for input, target in data:
optimizer.zero_grad()
# 前向过程(model + loss)开启 autocast
with autocast():
output = model(input)
loss = loss_fn(output, target)
# 反向传播在autocast上下文之外
loss.backward()
optimizer.step()
2.GradScaler
GradScaler就是梯度scaler模块,需要在训练最开始之前实例化一个GradScaler对象。
因此PyTorch中经典的AMP使用方式如下:
from torch.cuda.amp import autocast as autocast
# 创建model,默认是torch.FloatTensor
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
# 在训练最开始之前实例化一个GradScaler对象
scaler = GradScaler()
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# 前向过程(model + loss)开启 autocast
with autocast():
output = model(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
3.nn.DataParallel
单卡训练的话上面的代码已经够了,亲测在2080ti上能减少至少1/3的显存,至于速度。。。
要是想多卡跑的话仅仅这样还不够,会发现在forward里面的每个结果都还是float32的,怎么办?
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
def forward(self, input_data_c1):
with autocast():
# code
return
只要把forward里面的代码用autocast代码块方式运行就好啦!
自动进行autocast的操作
如下操作中tensor会被自动转化为半精度浮点型的torch.HalfTensor:
1、matmul
2、addbmm
3、addmm
4、addmv
5、addr
6、baddbmm
7、bmm
8、chain_matmul
9、conv1d
10、conv2d
11、conv3d
12、conv_transpose1d
13、conv_transpose2d
14、conv_transpose3d
15、linear
16、matmul
17、mm
18、mv
19、prelu
那么只有这些操作才能半精度吗?不是。其他操作比如rnn也可以进行半精度运行,但是需要自己手动,暂时没有提供自动的转换。
来源:https://blog.csdn.net/treeswolf/article/details/105748209
猜你喜欢
- #!/usr/bin/env python3# -*- coding: utf-8 -*-# File Name : gt1.py# Pur
- 主程序mainaddfunc.pyfrom flask import Flask, render_template, request, ur
- AJAX应用因为它们的表现力的丰富、更加互动和更加迅速的响应得到了赞扬声;这些优点都是通过使用XMLHttpRequest对象来动态的载入数
- 具体内容如下所示:参考案例:import turtled=0for i in range(4): turtle.fd(200)
- 我们知道现实中的数据通常是杂乱无章的,需要大量的预处理才能使用。Pandas 是应用最广泛的数据分析和处理库之一,它提供了多种对原始数据进行
- 说到排序,很多人可能第一想到的就是sorted,但是你可能不知道python中其实还有还就中方法哟,并且好多种场景下效率都会比sorted高
- 在Dreamweaver 4.0中,我们就已接触了模板与库的概念,知道它们是批量生成风格类似的网页的好工具。如今在Dreamweaver M
- 1、介绍在爬虫中经常会遇到验证码识别的问题,现在的验证码大多分计算验证码、滑块验证码、识图验证码、语音验证码等四种。本文就是识图验证码,识别
- js汉字简繁转换源代码:<html> <head> <title>汉字简繁转换工具_asp之家</
- 2D坐标系1 修改全部坐标颜色import matplotlib.pyplot as pltimport numpy as np#显示静态图
- 停止mysql服务(以管理员身份,在cmd命令行下运行) net stop mysql或者在服务中停止mysql服务。使用 mysqld –
- 我们一般采用photoshop等做图工具制作电视扫描线效果图片:首先做一个黑白相间的图案,然后用这个图案进行填充,再调整图层的模式或者透明度
- 介绍:细处着手,巧处用功。高手和菜鸟之间的差别就是:高手什么都知道,菜鸟知道一些。电脑小技巧收集最新奇招高招,让你轻松踏上高手之路。摘要:
- asp之家注:如果你学习过asp,并且在网络公司上过班,一定会接触到网购系统,网购系统可以说是一个典型的程序类型,而其中最重要,也是最关键的
- 这是一个网页设计中经常会用到的图片特效,实现多个图片之间的轮换,并分别带有连接。以前的代码只能适用于IE,在FF下始终没有得到很好的解决今天
- 1. 非 matlab v7.3 files 读写import scipy.io as sioimport numpy# matFile 读
- 问题你想在使用范围内执行某个代码片段,并且希望在执行后所有的结果都不可见。解决方案为了理解这个问题,先试试一个简单场景。首先,在全局命名空间
- 这篇文章主要介绍了Python matplotlib以日期为x轴作图代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的
- 画星星程序2-7-7主要使用turtle.forward前进操作和turtle.left左转操作在屏幕上画星星。#!/usr/bin/env
- 学习前言最近在学目标检测……SSD的源码好复杂……看