在PyTorch中使用标签平滑正则化的问题
作者:deephub 发布时间:2021-12-26 17:46:54
什么是标签平滑?在PyTorch中如何去使用它?
在训练深度学习模型的过程中,过拟合和概率校准(probability calibration)是两个常见的问题。一方面,正则化技术可以解决过拟合问题,其中较为常见的方法有将权重调小,迭代提前停止以及丢弃一些权重等。另一方面,Platt标度法和isotonic regression法能够对模型进行校准。但是有没有一种方法可以同时解决过拟合和模型过度自信呢?
标签平滑也许可以。它是一种去改变目标变量的正则化技术,能使模型的预测结果不再仅为一个确定值。标签平滑之所以被看作是一种正则化技术,是因为它可以防止输入到softmax函数的最大logits值变得特别大,从而使得分类模型变得更加准确。
在这篇文章中,我们定义了标签平滑化,在测试过程中我们将它应用到交叉熵损失函数中。
标签平滑?
假设这里有一个多分类问题,在这个问题中,目标变量通常是一个one-hot向量,即当处于正确分类时结果为1,否则结果是0。
标签平滑改变了目标向量的最小值,使它为ε。因此,当模型进行分类时,其结果不再仅是1或0,而是我们所要求的1-ε和ε,从而带标签平滑的交叉熵损失函数为如下公式。
在这个公式中,ce(x)表示x的标准交叉熵损失函数,例如:-log(p(x)),ε是一个非常小的正数,i表示对应的正确分类,N为所有分类的数量。
直观上看,标记平滑限制了正确类的logit值,并使得它更接近于其他类的logit值。从而在一定程度上,它被当作为一种正则化技术和一种对抗模型过度自信的方法。
PyTorch中的使用
在PyTorch中,带标签平滑的交叉熵损失函数实现起来非常简单。首先,让我们使用一个辅助函数来计算两个值之间的线性组合。
deflinear_combination(x, y, epsilon):return epsilon*x + (1-epsilon)*y
下一步,我们使用PyTorch
中一个全新的损失函数:nn.Module
.
import torch.nn.functional as F
defreduce_loss(loss, reduction='mean'):return loss.mean() if reduction=='mean'else loss.sum() if reduction=='sum'else loss
classLabelSmoothingCrossEntropy(nn.Module):def__init__(self, epsilon:float=0.1, reduction='mean'):
super().__init__()
self.epsilon = epsilon
self.reduction = reduction
defforward(self, preds, target):
n = preds.size()[-1]
log_preds = F.log_softmax(preds, dim=-1)
loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
nll = F.nll_loss(log_preds, target, reduction=self.reduction)
return linear_combination(loss/n, nll, self.epsilon)
我们现在可以在代码中删除这个类。对于这个例子,我们使用标准的fast.ai pets example
.
from fastai.vision import *
from fastai.metrics import error_rate
# prepare the data
path = untar_data(URLs.PETS)
path_img = path/'images'
fnames = get_image_files(path_img)
bs = 64
np.random.seed(2)
pat = r'/([^/]+)_\d+.jpg$'
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=bs) \
.normalize(imagenet_stats)
# train the model
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.loss_func = LabelSmoothingCrossEntropy()
learn.fit_one_cycle(4)
最后将数据转换成模型可以使用的格式,选择ResNet架构并以带标签平滑的交叉熵损失函数作为优化目标。经过四轮循环后,其结果如下
我们所得结果的错误率仅为7.5%,这对于10行左右的代码来说是完全可以接受的,并且在模型中大多数参数还都选择的是默认设置。
因此,在模型中还有许多参数可以进行调整,从而使得模型的表现性能更好,例如:可以使用不同的优化器、超参数、模型架构等。
结论
在这篇文章中,我们了解了什么是标签平滑以及什么时候去使用它,并且我们还知道了如何在PyTorch中实现它。之后,我们训练了一个先进的计算机视觉模型,仅使用十行代码就识别出了不同品种的猫和狗。
模型正则化和模型校准是两个重要的概念。若想成为一个深度学习的资深玩家,就应该好好地去理解这些能够对抗过拟合和模型过度自信的工具。
作者简介: Dimitris Poulopoulos,是BigDataStack的一名机器学习研究员,同时也是希腊Piraeus大学的博士。曾为欧盟委员会、欧盟统计局、国际货币基金组织、欧洲央行等客户设计过与AI相关的软件。
来源:https://blog.csdn.net/m0_46510245/article/details/105267655


猜你喜欢
- 打开终端输入以下命令 --> 回车 -->输入密码 -->回车 -->结束:sudo rm -rf /usr/loc
- SQL Server数据库快捷键:书签:清除所有书签。 CTRL-SHIFT-F2书签:插入或删除书签(切换)。 CTRL+F2书签:移动到
- 我遇到的一个小需求,就是希望通过判断pandas dataframe中一列的值在两个条件范围(比如下面代码中所描述的逻辑,取小于u-3ε和大
- 今天在帮前端准备数据的时候,需要把数据格式转成json格式,说实话,涉及到中文有时候真的是很蛋疼,除非对Python的编码规则比较了解,不然
- 本文实例讲述了Python 25行代码实现的RSA算法。分享给大家供大家参考,具体如下:网络上很多关于RSA算法的原理介绍,但是翻来翻去就是
- 目录1. 关联规则1.1 基本概念1.2 关联规则Apriori算法2. mlxtend实战关联规则2.1 安装2.2 简单的例子3. 总结
- 创建随机数 ①自JavaScript产生后,好多浏览器中都有内置的随机数发生方法。例如: var number = Math.random(
- 一、异常检测简介异常检测是通过数据挖掘方法发现与数据集分布不一致的异常数据,也被称为离群点、异常值检测等等。1.1 异常检测适用的场景异常检
- <?php /******************************************** *&nb
- 关于高性能的分布式内存对象缓存系统Memcached,我们在另一篇文章中有提到过“在windows系统下如何安装memcached的讲解”,
- 一个图形化的交互式运行环境,对于编程语言的学习和开发,特别是可视化方面,提供了极大的便利。比如在window上使用R语言进行绘图,在R语言自
- PYTHON首先要安装scapy模块PY3的安装scapy-python3,使用PIP安装就好了,注意,PY3无法使用pyinstaller
- 这篇文章主要介绍了windows环境中利用celery实现简单任务队列过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定
- 什么是formset我们知道forms组件是用来做表单验证,更准确一点说,forms组件是用来做数据库表中一行记录的验证。有forms组件不
- PHP下载图片后文件打开显示损坏问题用php写个图片下载方法,测试发现下载的图片大小都没问题,但是无法打开文件。解决方法如下:首先打开文件下
- 前言Python爬虫要经历爬虫、爬虫被限制、爬虫反限制的过程。当然后续还要网页爬虫限制优化,爬虫再反限制的一系列道高一尺魔高一丈的过程。爬虫
- Python 列表理解及使用方法列表是最常用的Python最常用的数据类型,它和其它序列一样,可以进行包括索引,切片,加,乘,检查成员的操作
- 本文实例讲述了python提取字典key列表的方法。分享给大家供大家参考。具体如下:这段代码可以把字典的所有key输出为一个数组d2 = {
- 线性回归是一种常见的机器学习算法,也是人工智能中常用的算法。它是一种用于预测数值型输出变量与一个或多个自变量之间线性关系的方法。例如,你可以
- 目录wsgi 相关概念CGIWSGIASGIcgi 示例cgi脚本cgi服务实现wsgirefwsgi 小结小技巧python web开发中