在PyTorch中使用标签平滑正则化的问题
作者:deephub 发布时间:2021-12-26 17:46:54
什么是标签平滑?在PyTorch中如何去使用它?
在训练深度学习模型的过程中,过拟合和概率校准(probability calibration)是两个常见的问题。一方面,正则化技术可以解决过拟合问题,其中较为常见的方法有将权重调小,迭代提前停止以及丢弃一些权重等。另一方面,Platt标度法和isotonic regression法能够对模型进行校准。但是有没有一种方法可以同时解决过拟合和模型过度自信呢?
标签平滑也许可以。它是一种去改变目标变量的正则化技术,能使模型的预测结果不再仅为一个确定值。标签平滑之所以被看作是一种正则化技术,是因为它可以防止输入到softmax函数的最大logits值变得特别大,从而使得分类模型变得更加准确。
在这篇文章中,我们定义了标签平滑化,在测试过程中我们将它应用到交叉熵损失函数中。
标签平滑?
假设这里有一个多分类问题,在这个问题中,目标变量通常是一个one-hot向量,即当处于正确分类时结果为1,否则结果是0。
标签平滑改变了目标向量的最小值,使它为ε。因此,当模型进行分类时,其结果不再仅是1或0,而是我们所要求的1-ε和ε,从而带标签平滑的交叉熵损失函数为如下公式。
在这个公式中,ce(x)表示x的标准交叉熵损失函数,例如:-log(p(x)),ε是一个非常小的正数,i表示对应的正确分类,N为所有分类的数量。
直观上看,标记平滑限制了正确类的logit值,并使得它更接近于其他类的logit值。从而在一定程度上,它被当作为一种正则化技术和一种对抗模型过度自信的方法。
PyTorch中的使用
在PyTorch中,带标签平滑的交叉熵损失函数实现起来非常简单。首先,让我们使用一个辅助函数来计算两个值之间的线性组合。
deflinear_combination(x, y, epsilon):return epsilon*x + (1-epsilon)*y
下一步,我们使用PyTorch
中一个全新的损失函数:nn.Module
.
import torch.nn.functional as F
defreduce_loss(loss, reduction='mean'):return loss.mean() if reduction=='mean'else loss.sum() if reduction=='sum'else loss
classLabelSmoothingCrossEntropy(nn.Module):def__init__(self, epsilon:float=0.1, reduction='mean'):
super().__init__()
self.epsilon = epsilon
self.reduction = reduction
defforward(self, preds, target):
n = preds.size()[-1]
log_preds = F.log_softmax(preds, dim=-1)
loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
nll = F.nll_loss(log_preds, target, reduction=self.reduction)
return linear_combination(loss/n, nll, self.epsilon)
我们现在可以在代码中删除这个类。对于这个例子,我们使用标准的fast.ai pets example
.
from fastai.vision import *
from fastai.metrics import error_rate
# prepare the data
path = untar_data(URLs.PETS)
path_img = path/'images'
fnames = get_image_files(path_img)
bs = 64
np.random.seed(2)
pat = r'/([^/]+)_\d+.jpg$'
data = ImageDataBunch.from_name_re(path_img, fnames, pat, ds_tfms=get_transforms(), size=224, bs=bs) \
.normalize(imagenet_stats)
# train the model
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.loss_func = LabelSmoothingCrossEntropy()
learn.fit_one_cycle(4)
最后将数据转换成模型可以使用的格式,选择ResNet架构并以带标签平滑的交叉熵损失函数作为优化目标。经过四轮循环后,其结果如下
我们所得结果的错误率仅为7.5%,这对于10行左右的代码来说是完全可以接受的,并且在模型中大多数参数还都选择的是默认设置。
因此,在模型中还有许多参数可以进行调整,从而使得模型的表现性能更好,例如:可以使用不同的优化器、超参数、模型架构等。
结论
在这篇文章中,我们了解了什么是标签平滑以及什么时候去使用它,并且我们还知道了如何在PyTorch中实现它。之后,我们训练了一个先进的计算机视觉模型,仅使用十行代码就识别出了不同品种的猫和狗。
模型正则化和模型校准是两个重要的概念。若想成为一个深度学习的资深玩家,就应该好好地去理解这些能够对抗过拟合和模型过度自信的工具。
作者简介: Dimitris Poulopoulos,是BigDataStack的一名机器学习研究员,同时也是希腊Piraeus大学的博士。曾为欧盟委员会、欧盟统计局、国际货币基金组织、欧洲央行等客户设计过与AI相关的软件。
来源:https://blog.csdn.net/m0_46510245/article/details/105267655
猜你喜欢
- 有时我们有很多文件(如图片),我们需要对每一个文件进行操作。 我们还需要一份文件的名字来进行遍历,这时我们首先需要建立一份文件名单,有时还会
- 前言谷歌出了一个开源的、跨平台的、可定制化的机器学习解决方案工具包,给在线流媒体(当然也可以用于普通的视频、图像等)提供了机器学习解决方案。
- 我们将在下面的例子中使用这个 XML 文档。<?xml version="1.0" encod
- 一、hook在PyTorch中,提供了一个专用的接口使得网络在前向传播过程中能够获取到特征图,这个接口的名称非常形象,叫做hook。可以想象
- 准备篇:1、配置防火墙,开启80端口、3306端口vi /etc/sysconfig/iptables-A INPUT -m state -
- 调用数据库存储过程见下:<%Set Dataconn = Server.CreateObject(&qu
- 查看Tensor尺寸及查看数据类型Tensor尺寸查看命令:x.shape例子:input = torch.randn(20,16,50,3
- SQL Server 2000 清理日志精品教程SQL Server 2000 数据库日志太大!如何清理SQL Server 2000的日志
- (1)OracleServiceSID 数据库服务,这个服务会自动地启动和停止数据库。如果安装了一个数据库,它的缺省启动类型为自动。服务进程
- 01直接生成这类方法是利用基本程序软件包numpy的随机数产生方法来生成各类用于聚类算法数据集合,也是自行制作轮子的生成方法。一、基础类型1
- 官方文档https://developers.weixin.qq.com/miniprogram/dev/framework/open-ab
- 1.字符串的驻留机制字符串:在Python中字符串是基本的数据类型,是一个不可变的字符序列2.什么叫字符串的驻留机制仅保存一份相同且不可变字
- select for update 这个是行级锁 当 commit或者rollback时,锁释放 记得打开事务,比如jdbc里面 setAu
- 教你用Python批量查询关键词微信指数。前期准备安装好Python开发环境及Fiddler抓包工具。前期准备安装好Python开发环境及F
- 神经网络模型一般用来做分类,回归预测模型不常见,本文基于一个用来分类的BP神经网络,对它进行修改,实现了一个回归模型,用来做室内定位。模型主
- 译注:前两天看到一篇不错的英文文章,叫做 How browsers work,该文概要的介绍了浏览器从头到尾的工作机制,包括HTML等的解析
- 前言近来chatGPT挺火的,也试玩了一下,确实挺有意思。这里记录一下在Python中如何去使用chatGPT。本篇文章的实现100%基于
- <%dim total(7,1) total(1,0)="中国经营报"
- paramiko是用python语言写的一个模块,遵循SSH2协议,支持以加密和认证的方式,进行远程服务器的连接。paramiko支持Lin
- HTTP_X_FORWARDED_FOR与REMOTE_ADDR的区别.在Request.ServerVariables中并没有HTTP_X