Python机器学习pytorch交叉熵损失函数的深刻理解
作者:Ezail_xdu 发布时间:2021-12-11 06:09:40
说起交叉熵损失函数「Cross Entropy Loss」,脑海中立马浮现出它的公式:
我们已经对这个交叉熵函数非常熟悉,大多数情况下都是直接拿来使用就好。但是它是怎么来的?为什么它能表征真实样本标签和预测概率之间的差值?上面的交叉熵函数是否有其它变种?
1.交叉熵损失函数的推导
我们知道,在二分类问题模型:例如逻辑回归「Logistic Regression」、神经网络「Neural Network」等,真实样本的标签为 [0,1],分别表示负类和正类。模型的最后通常会经过一个 Sigmoid 函数,输出一个概率值,这个概率值反映了预测为正类的可能性:概率越大,可能性越大。
Sigmoid 函数的表达式和图形如下所示:
其中 s 是模型上一层的输出,Sigmoid 函数有这样的特点:s = 0 时,g(s) = 0.5;s >> 0 时, g ≈ 1,s << 0 时,g ≈ 0。显然,g(s) 将前一级的线性输出映射到 [0,1] 之间的数值概率上。这里的 g(s) 就是交叉熵公式中的模型预测输出 。
我们说了,预测输出即 Sigmoid 函数的输出表征了当前样本标签为 1 的概率:
很明显,当前样本标签为 0 的概率就可以表达成:
重点来了,如果我们从极大似然性的角度出发,把上面两种情况整合到一起:
不懂极大似然估计也没关系。我们可以这么来看:
当真实样本标签 y = 0 时,上面式子第一项就为 1,概率等式转化为:
当真实样本标签 y = 1 时,上面式子第二项就为 1,概率等式转化为:
两种情况下概率表达式跟之前的完全一致,只不过我们把两种情况整合在一起了。
重点看一下整合之后的概率表达式,我们希望的是概率 P(y|x) 越大越好。首先,我们对 P(y|x) 引入 log 函数,因为 log 运算并不会影响函数本身的单调性。则有:
我们希望 log P(y|x) 越大越好,反过来,只要 log P(y|x) 的负值 -log P(y|x) 越小就行了。那我们就可以引入损失函数,且令 Loss = -log P(y|x)即可。则得到损失函数为:
非常简单,我们已经推导出了单个样本的损失函数,是如果是计算 N 个样本的总的损失函数,只要将 N 个 Loss 叠加起来就可以了:
这样,我们已经完整地实现了交叉熵损失函数的推导过程。
2. 交叉熵损失函数的直观理解
我已经知道了交叉熵损失函数的推导过程。但是能不能从更直观的角度去理解这个表达式呢?而不是仅仅记住这个公式。好问题!接下来,我们从图形的角度,分析交叉熵函数,加深理解。
首先,还是写出单个样本的交叉熵损失函数:
我们知道,当 y = 1 时
这时候,L 与预测输出的关系如下图所示:
看了 L 的图形,简单明了!横坐标是预测输出,纵坐标是交叉熵损失函数 L。显然,预测输出越接近真实样本标签 1,损失函数 L 越小;预测输出越接近 0,L 越大。因此,函数的变化趋势完全符合实际需要的情况。
当 y = 0 时:
这时候,L 与预测输出的关系如下图所示:
同样,预测输出越接近真实样本标签 0,损失函数 L 越小;预测函数越接近 1,L 越大。函数的变化趋势也完全符合实际需要的情况。
从上面两种图,可以帮助我们对交叉熵损失函数有更直观的理解。无论真实样本标签 y 是 0 还是 1,L 都表征了预测输出与 y 的差距。
另外,重点提一点的是,从图形中我们可以发现:预测输出与 y 差得越多,L 的值越大,也就是说对当前模型的 “ 惩罚 ” 越大,而且是非线性增大,是一种类似指数增长的级别。这是由 log 函数本身的特性所决定的。这样的好处是模型会倾向于让预测输出更接近真实样本标签 y。
3. 交叉熵损失函数的其它形式
什么?交叉熵损失函数还有其它形式?没错!我刚才介绍的是一个典型的形式。接下来我将从另一个角度推导新的交叉熵损失函数。
这种形式下假设真实样本的标签为 +1 和 -1,分别表示正类和负类。有个已知的知识点是Sigmoid 函数具有如下性质:
这个性质我们先放在这,待会有用。
好了,我们之前说了 y = +1 时,下列等式成立:
如果 y = -1 时,并引入 Sigmoid 函数的性质,下列等式成立:
重点来了,因为 y 取值为 +1 或 -1,可以把 y 值带入,将上面两个式子整合到一起:
这个比较好理解,分别令 y = +1 和 y = -1 就能得到上面两个式子。
接下来,同样引入 log 函数,得到:
要让概率最大,反过来,只要其负数最小即可。那么就可以定义相应的损失函数为:
还记得 Sigmoid 函数的表达式吧?将 g(ys) 带入:
好咯,L 就是我要推导的交叉熵损失函数。如果是 N 个样本,其交叉熵损失函数为:
接下来,我们从图形化直观角度来看。当 y = +1 时:
这时候,L 与上一层得分函数 s 的关系如下图所示:
横坐标是 s,纵坐标是 L。显然,s 越接近正无穷,损失函数 L 越小;s 越接近负无穷,L 越大。
另一方面,当 y = -1 时:
这时候,L 与上一层得分函数 s 的关系如下图所示:
同样,s 越接近负无穷,损失函数 L 越小;s 越接近正无穷,L 越大。
4.总结
本文主要介绍了交叉熵损失函数的数学原理和推导过程,也从不同角度介绍了交叉熵损失函数的两种形式。第一种形式在实际应用中更加常见,例如神经网络等复杂模型;第二种多用于简单的逻辑回归模型。
需要注意的是:第一个公式中的变量是sigmoid输出的值,第二个公式中的变量是sigmoid输入的值。
来源:https://blog.csdn.net/weixin_38526306/article/details/87831201


猜你喜欢
- 前言题目如下:给定一个仅包含大小写字母和空格 ’ ’ 的字符串 s,返回其最后一个单词的长度。如果字
- 一、Scrapy是什么Scrapy 是一个基于 Twisted 的异步处理框架,是纯 Python 实现的爬虫框架,其架构清晰,模块之间的耦
- 大家知道,在js里encodeURIComponent 方法是一个比较常用的编码方法,但因工作需要,在asp里需用到此方法,查了好多资料,没
- 详解Python 模拟实现生产者消费者模式的实例散仙使用python3.4模拟实现的一个生产者与消费者的例子,用到的知识有线程,队列,循环等
- 通过一个简单的实例,来让大家了解一下golang flag包的一个简单的用法package mainimport ( "
- 随着网站访问量的加大,每次从数据库读取都是以效率作为代价的,很多用ACCESS作数据库的更会深有体会,静态页加在搜索时,也会被优先考虑。互联
- 本文实例讲述了Python常用特殊方法。分享给大家供大家参考,具体如下:1 __init__和__new____init__方法用来初始化类
- 一、演示效果b站:虎年烟花演示二、python代码import pygamefrom math import *from pygame.lo
- 很多的朋友一而再,再而三的在Server.Mappath上卡壳,cnbruce也是一遍两遍地重复,还是不能全部解决,所以通过下面的举例,希望
- 我们都知道 vue-router 的动态路由匹配 对组件是原地复用的策略,需要我们在组件中根据不同的 $route 参数展示不同的数据,这在
- 自己写的用js读取配置文件的程序 D:\Useful Stuff\Javascript\mytest.txt 文件内容如下 [plugin_
- 最近需要训练一个生成对抗网络模型,然后开发接口,不得不在一台有显卡的远程linux服务器上进行,所以,趁着这个机会研究了下怎么使用vscod
- 一、数据引擎简介在MySQL 5.1中,MySQL AB引入了新的插件式存储引擎体系结构,允许将存储引擎加载到正在运新的MySQL
- 一、概念路由指的是客户端的请求与服务器处理函数之间的映射关系Express中的路由分3部分组成,分别是请求的类型、请求的URL地址、处理函数
- 这是一个非常简单的解决方案,柱状图中每一条柱都是一个 div,数据的大小呈现在 div 的宽或高上。 查看演示 例子下载实现的原理
- 用法: 按住鼠标左键拖拽一个框后释放洗洗睡了<!DOCTYPE html public "-//W3C//DTD XHTML
- 1.单独使用Pillow包时,图片会弹出新窗口显示:from Pillow import Imageimg = Image.open(
- phpinfo函数phpinfo函数 PHP中提供了PHPInfo()函数,该函数返回 PHP 的所有信息,包括了 PHP 的编译选项及扩充
- 本文采用OpenCV3和Python3 来实现静态图片的人脸识别,采用的是Haar文件级联。 首先需要将OpenCV3源代码中找到data文
- 第二种方法通常是在load一个batch数据时, 在collate_fn中进行补齐的.以下给出两种思路:第一种思路是比较容易想到的, 就是对