解决Pytorch半精度浮点型网络训练的问题
作者:阿刚的代码进阶之旅 发布时间:2021-10-13 17:56:45
用Pytorch1.0进行半精度浮点型网络训练需要注意下问题:
1、网络要在GPU上跑,模型和输入样本数据都要cuda().half()
2、模型参数转换为half型,不必索引到每层,直接model.cuda().half()即可
3、对于半精度模型,优化算法,Adam我在使用过程中,在某些参数的梯度为0的时候,更新权重后,梯度为零的权重变成了NAN,这非常奇怪,但是Adam算法对于全精度数据类型却没有这个问题。
另外,SGD算法对于半精度和全精度计算均没有问题。
还有一个问题是不知道是不是网络结构比较小的原因,使用半精度的训练速度还没有全精度快。这个值得后续进一步探索。
对于上面的这个问题,的确是网络很小的情况下,在1080Ti上半精度浮点型没有很明显的优势,但是当网络变大之后,半精度浮点型要比全精度浮点型要快。
但具体快多少和模型的大小以及输入样本大小有关系,我测试的是要快1/6,同时,半精度浮点型在占用内存上比较有优势,对于精度的影响尚未探究。
将网络再变大些,epoch的次数也增大,半精度和全精度的时间差就表现出来了,在训练的时候。
补充:pytorch半精度,混合精度,单精度训练的区别amp.initialize
看代码吧~
mixed_precision = True
try: # Mixed precision training https://github.com/NVIDIA/apex
from apex import amp
except:
mixed_precision = False # not installed
model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=1)
为了帮助提高Pytorch的训练效率,英伟达提供了混合精度训练工具Apex。号称能够在不降低性能的情况下,将模型训练的速度提升2-4倍,训练显存消耗减少为之前的一半。
文档地址是:https://nvidia.github.io/apex/index.html
该 工具 提供了三个功能,amp、parallel和normalization。由于目前该工具还是0.1版本,功能还是很基础的,在最后一个normalization功能中只提供了LayerNorm层的复现,实际上在后续的使用过程中会发现,出现问题最多的是pytorch的BN层。
第二个工具是pytorch的分布式训练的复现,在文档中描述的是和pytorch中的实现等价,在代码中可以选择任意一个使用,实际使用过程中发现,在使用混合精度训练时,使用Apex复现的parallel工具,能避免一些bug。
默认训练方式是 单精度float32
import torch
model = torch.nn.Linear(D_in, D_out)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
out = model(img)
loss = LOSS(out, label)
loss.backward()
optimizer.step()
optimizer.zero_grad()
半精度 model(img.half())
import torch
model = torch.nn.Linear(D_in, D_out).half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
out = model(img.half())
loss = LOSS(out, label)
loss.backward()
optimizer.step()
optimizer.zero_grad()
接下来是混合精度的实现,这里主要用到Apex的amp工具。
代码修改为:
加上这一句封装,
model, optimizer = amp.initialize(model, optimizer, opt_level=“O1”)
import torch
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
for img, label in dataloader:
out = model(img)
loss = LOSS(out, label)
# loss.backward()
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
optimizer.step()
optimizer.zero_grad()
实际流程为:调用amp.initialize按照预定的opt_level对model和optimizer进行设置。在计算loss时使用amp.scale_loss进行回传。
需要注意以下几点:
在调用amp.initialize之前,模型需要放在GPU上,也就是需要调用cuda()或者to()。
在调用amp.initialize之前,模型不能调用任何分布式设置函数。
此时输入数据不需要在转换为半精度。
在使用混合精度进行计算时,最关键的参数是opt_level。他一共含有四种设置值:‘00',‘01',‘02',‘03'。实际上整个amp.initialize的输入参数很多:
但是在实际使用过程中发现,设置opt_level即可,这也是文档中例子的使用方法,甚至在不同的opt_level设置条件下,其他的参数会变成无效。(已知BUG:使用‘01'时设置keep_batchnorm_fp32的值会报错)
概括起来:
00相当于原始的单精度训练。01在大部分计算时采用半精度,但是所有的模型参数依然保持单精度,对于少数单精度较好的计算(如softmax)依然保持单精度。02相比于01,将模型参数也变为半精度。
03基本等于最开始实验的全半精度的运算。值得一提的是,不论在优化过程中,模型是否采用半精度,保存下来的模型均为单精度模型,能够保证模型在其他应用中的正常使用。这也是Apex的一大卖点。
在Pytorch中,BN层分为train和eval两种操作。
实现时若为单精度网络,会调用CUDNN进行计算加速。常规训练过程中BN层会被设为train。Apex优化了这种情况,通过设置keep_batchnorm_fp32参数,能够保证此时BN层使用CUDNN进行计算,达到最好的计算速度。
但是在一些fine tunning场景下,BN层会被设为eval(我的模型就是这种情况)。此时keep_batchnorm_fp32的设置并不起作用,训练会产生数据类型不正确的bug。此时需要人为的将所有BN层设置为半精度,这样将不能使用CUDNN加速。
一个设置的参考代码如下:
def fix_bn(m):
classname = m.__class__.__name__
if classname.find('BatchNorm') != -1:
m.eval().half()
model.apply(fix_bn)
实际测试下来,最后的模型准确度上感觉差别不大,可能有轻微下降;时间上变化不大,这可能会因不同的模型有差别;显存开销上确实有很大的降低。
来源:https://www.cnblogs.com/yanxingang/p/10148712.html
猜你喜欢
- 本文实例为大家分享了Python实现简单层次聚类算法,以及可视化,供大家参考,具体内容如下基本的算法思路就是:把当前组间距离最小的两组合并成
- 装饰器一、介绍器:代表函数的意思。装饰器本质就是是函数功能:装饰其他函数,就是为其他函数添加附加功能 被装饰函数感受不到装饰器的存
- 正则表达式是处理字符串的强大工具。作为一个概念而言,正则表达式对于Python来说并不是独有的。但是,Python中的正则表达式在实际使用过
- 如果查询结果很多,服务器解释你的ASP script将花费大量的时间,因为有许多的Response.Write语句要处理. 如果你将输出的全
- matplotlib是功能十分强大的绘制二维图形的Python模块,它用Python语言实现了MATLAB画图函数的易用性,同时又有非常强大
- #/usr/bin/env python#-*- coding:utf-8 -*-"""1.解析 cronta
- 在实际工作中,无论是对数据库系统(DBMS),还是对数据库应用系统(DBAS),查询优化一直是一个热门话题。一个成功的数据库应用系统的开发,
- pip install psycopg2出现错误:Looking in indexes: https://pypi.tuna.tsinghu
- 1.设置Headers有些网站不会同意程序直接用上面的方式进行访问,如果识别有问题,那么站点根本不会响应,所以为了完全模拟浏览器的工作,我们
- 对于注入而言,错误提示是极其重要。所谓错误提示是指和正确页面不同的结果反馈,高手是很重视这个一点的,这对于注入点的精准判断至关重要。本问讨论
- 很多朋友说JavaScript的decodeURI函数也可以实现,但有bug所有呢,下面看下下面的函数,经过测试使用暂时没什么问题,我在之前
- 前言前面我们已经介绍了 python面向对象入门教程之从代码复用开始(一) ,这篇文章主要介绍的是关于Python面向对象之设置对
- 文件提交页面既已生成,下面任务就很明确了:将提交的文件内容保存到服务器上。 下面我们用两种方法来实现这个功能: 1. 用 PHP 来保存:
- 1、 Python中 sys.argv的用法解释:sys.argv可以让python脚本从程序外部获取参数,sys.argv是一个列表,可用
- 春节前在蓝色理想上发了个“雅虎口碑招聘前端工程师 ”的启事,节后收到很多简历,加之HR通过专业招聘网站得到的简历和朋友同事推荐的简历,数量上
- 本文介绍ThinkPHP的limit()方法的用法。limit方法可以用于对数据库操作的结果进行取指定范围的条数。即相当于是在mysql查询
- 项目简介鉴于项目保密的需要,不便透露太多项目的信息,因此,简单介绍一下项目存在的难点:海量数据:项目是对CSV文件中的数据进行处理,而特点是
- 公司客户在使用网站后台编辑添加修改内容时,经常是直接从word文档里复制内容到编辑器里后就提交。结果是在内容显示页面上是五花八门的样式,有时
- 判断服务器是否安装了某种asp组件,比较常用的代码如下:代码如下:<% '功能:检查是否存在系统组件或组件是否安装成功
- 问:把数据从MySQL迁移到Oracle需要注意些什么?答:以下是MySQL迁到Oracle需要掌握的注意事项,希望对你有所帮助。1.自动增