PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失
作者:gy笨瓜 发布时间:2022-01-12 19:39:19
torch.nn.CrossEntropyLoss交叉熵损失
本文只考虑基本情况,未考虑加权。
torch.nnCrossEntropyLosss使用的公式
目标类别采用one-hot编码
其中,class表示当前样本类别在one-hot编码中对应的索引(从0开始),
x[j]表示预测函数的第j个输出
公式(1)表示先对预测函数使用softmax计算每个类别的概率,再使用log(以e为底)计算后的相反数表示当前类别的损失,只表示其中一个样本的损失计算方式,非全部样本。
每个样本使用one-hot编码表示所属类别时,只有一项为1,因此与基本的交叉熵损失函数相比,省略了其它值为0的项,只剩(1)所表示的项。
sample
torch.nn.CrossEntropyLoss使用流程
torch.nn.CrossEntropyLoss为一个类,并非单独一个函数,使用到的相关简单参数会在使用中说明,并非对所有参数进行说明。
首先创建类对象
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="none")
参数reduction默认为"mean",表示对所有样本的loss取均值,最终返回只有一个值
参数reduction取"none",表示保留每一个样本的loss
计算损失
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: tensor([1.5210, 0.6247]) # 与上述【sample】计算一致
实际计算损失值调用函数时,传入pred预测值与class_index类别索引
在传入每个类别时,class_index应为一维,长度为样本个数,每个元素表示对应样本的类别索引,非one-hot编码方式传入
测试torch.nn.CrossEntropyLoss的reduction参数为默认值"mean"
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="mean")
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: 1.073 # 与上述【sample】计算一致
交叉熵损失nn.CrossEntropyLoss()的真正计算过程
对于多分类损失函数Cross Entropy Loss,就不过多的解释,网上的博客不计其数。在这里,讲讲对于CE Loss的一些真正的理解。
首先大部分博客给出的公式如下:
其中p为真实标签值,q为预测值。
在低维复现此公式,结果如下。在此强调一点,pytorch中CE Loss并不会将输入的target映射为one-hot编码格式,而是直接取下标进行计算。
import torch
import torch.nn as nn
import math
import numpy as np
#官方的实现
entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
#输出 tensor(1.1142)
#自己实现
input=np.array(input)
target = np.array(target)
def cross_entorpy(input, target):
output = 0
length = len(target)
for i in range(length):
hou = 0
for j in input[i]:
hou += np.log(input[i][target[i]])
output += -hou
return np.around(output / length, 4)
print(cross_entorpy(input, target))
#输出 3.8162
我们按照官方给的CE Loss和根据公式得到的答案并不相同,说明公式是有问题的。
正确公式
实现代码如下
import torch
import torch.nn as nn
import math
import numpy as np
entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
#输出 tensor(1.1142)
#%%
input=np.array(input)
target = np.array(target)
def cross_entorpy(input, target):
output = 0
length = len(target)
for i in range(length):
hou = 0
for j in input[i]:
hou += np.exp(j)
output += -input[i][target[i]] + np.log(hou)
return np.around(output / length, 4)
print(cross_entorpy(input, target))
#输出 1.1142
对比自己实现的公式和官方给出的结果,可以验证公式的正确性。
观察公式可以发现其实nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合版本。
nn.logSoftmax(),公式如下
nn.NLLLoss(),公式如下
将nn.logSoftmax()作为变量带入nn.NLLLoss()可得
因为
可看做一个常量,故上式可化简为:
对比nn.Cross Entropy Loss公式,结果显而易见。
验证代码如下。
import torch
import torch.nn as nn
import math
import numpy as np
entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
# 输出为tensor(1.1142)
m = nn.LogSoftmax()
loss = nn.NLLLoss()
input=m(input)
output = loss(input, target)
print(output)
# 输出为tensor(1.1142)
综上,可得两个结论
1.nn.Cross Entropy Loss的公式。
2.nn.Cross Entropy Loss为nn.logSoftmax()和nn.NLLLoss()的整合版本。
来源:https://blog.csdn.net/u012633319/article/details/111093144


猜你喜欢
- 1.循环删除 #这个是我选中其中的一个分支进行右键清空操作时进行的处理for i in range(self.tree.currentIte
- 1.Python是如何进行内存管理的?答:从三个方面来说,一对象的引用计数机制,二垃圾回收机制,三内存池机制一、对象的引用计数机制Pytho
- 本文介绍了VUE预渲染及遇到的坑,分享给大家,具体如下:npm install -D prerender-spa-plugin修改webpa
- 本文实例讲述了Python Excel到CSV的转换程序。分享给大家供大家参考,具体如下:题目如下:利用第十二章的openpyxl模块,编程
- 2007年1月,国务院 * 了中央“一号文件”,文件中对加快农业信息化建设有了更明确的部署,为新农村建
- 一、性能度量性能度量目的是对学习期的泛华能力进行评估,性能度量反映了任务需求,在对比不同算法的泛华能力时,使用不同的性能度量往往会导致不同的
- Cloudinary提供了一个API,用于将图像、视频和任何其他类型的文件上传到云端。上传到Cloudinary的文件通过安全备份和修订历史
- js中自动清除ie缓存方法 — 常用 对于动态文件,比如 index.asp?id=... 或者 index.aspx?id=... 相信有
- 逐步回归的基本思想是将变量逐个引入模型,每引入一个解释变量后都要进行F检验,并对已经选入的解释变量逐个进行t检验,当原来引入的解释变量由于后
- 本文实例讲述了python监控网站运行异常并发送邮件的方法。分享给大家供大家参考。具体如下:这是一个简单的python开发的监控程序,当指定
- 有些时候我们不得已要利用values来反向查询key,有没有简单的方法呢?下面我给大家列举一些方法,方便大家使用python3>>
- 前段时间我编写了一个工业控制的软件,在使用中一直存在一个问题,就是当软件检索设备时,因为这个功能执行的时间比较长,导致GUI界面假死,让用户
- 前言我的 App 项目的 API 部分是使用 Django REST Framework 来搭建的,它可以像搭积木一样非常方便地搭出 API
- 在pytorch的CNN代码中经常会看到x.view(x.size(0), -1)首先,在pytorch中的view()函数就是用来改变te
- Go语言也称 Golang,兼具效率、性能、安全、健壮等特性。Go语言从底层原生支持并发,无须第三方库、开发者的编程技巧和开发经验就可以轻松
- 今天谈一下关于python中input的一些基本用法(写给新手入门之用,故只谈比较实用的部分)。首先,我们可以看一下官方文档给我们的解释(在
- Windows 10 x64macOS Sierra 10.12.4Python 2.7准备好装哔~了么,来吧,做个真正意义上的绿色小软件W
- 如果一个应用程序需要登录,则它必须知道当前用户执行了什么操作。因此ASP.NET在展示层提供了一套自己的SESSION会话对象,而ABP则提
- 1. mysql_where子句_聚合函数# ### part 单表查询""" select ... from
- 实例如下所示:from xml.etree.cElementTree import ElementTree,Elementimport xl