PyTorch基础之torch.nn.CrossEntropyLoss交叉熵损失
作者:gy笨瓜 发布时间:2022-01-12 19:39:19
torch.nn.CrossEntropyLoss交叉熵损失
本文只考虑基本情况,未考虑加权。
torch.nnCrossEntropyLosss使用的公式
目标类别采用one-hot编码
其中,class表示当前样本类别在one-hot编码中对应的索引(从0开始),
x[j]表示预测函数的第j个输出
公式(1)表示先对预测函数使用softmax计算每个类别的概率,再使用log(以e为底)计算后的相反数表示当前类别的损失,只表示其中一个样本的损失计算方式,非全部样本。
每个样本使用one-hot编码表示所属类别时,只有一项为1,因此与基本的交叉熵损失函数相比,省略了其它值为0的项,只剩(1)所表示的项。
sample
torch.nn.CrossEntropyLoss使用流程
torch.nn.CrossEntropyLoss为一个类,并非单独一个函数,使用到的相关简单参数会在使用中说明,并非对所有参数进行说明。
首先创建类对象
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="none")
参数reduction默认为"mean",表示对所有样本的loss取均值,最终返回只有一个值
参数reduction取"none",表示保留每一个样本的loss
计算损失
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: tensor([1.5210, 0.6247]) # 与上述【sample】计算一致
实际计算损失值调用函数时,传入pred预测值与class_index类别索引
在传入每个类别时,class_index应为一维,长度为样本个数,每个元素表示对应样本的类别索引,非one-hot编码方式传入
测试torch.nn.CrossEntropyLoss的reduction参数为默认值"mean"
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: loss_function = nn.CrossEntropyLoss(reduction="mean")
In [4]: pred = torch.tensor([[0.0541,0.1762,0.9489],[-0.0288,-0.8072,0.4909]], dtype=torch.float32)
In [5]: class_index = torch.tensor([0, 2], dtype=torch.int64)
In [6]: loss_value = loss_function(pred, class_index)
In [7]: loss_value
Out[7]: 1.073 # 与上述【sample】计算一致
交叉熵损失nn.CrossEntropyLoss()的真正计算过程
对于多分类损失函数Cross Entropy Loss,就不过多的解释,网上的博客不计其数。在这里,讲讲对于CE Loss的一些真正的理解。
首先大部分博客给出的公式如下:
其中p为真实标签值,q为预测值。
在低维复现此公式,结果如下。在此强调一点,pytorch中CE Loss并不会将输入的target映射为one-hot编码格式,而是直接取下标进行计算。
import torch
import torch.nn as nn
import math
import numpy as np
#官方的实现
entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
#输出 tensor(1.1142)
#自己实现
input=np.array(input)
target = np.array(target)
def cross_entorpy(input, target):
output = 0
length = len(target)
for i in range(length):
hou = 0
for j in input[i]:
hou += np.log(input[i][target[i]])
output += -hou
return np.around(output / length, 4)
print(cross_entorpy(input, target))
#输出 3.8162
我们按照官方给的CE Loss和根据公式得到的答案并不相同,说明公式是有问题的。
正确公式
实现代码如下
import torch
import torch.nn as nn
import math
import numpy as np
entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
#输出 tensor(1.1142)
#%%
input=np.array(input)
target = np.array(target)
def cross_entorpy(input, target):
output = 0
length = len(target)
for i in range(length):
hou = 0
for j in input[i]:
hou += np.exp(j)
output += -input[i][target[i]] + np.log(hou)
return np.around(output / length, 4)
print(cross_entorpy(input, target))
#输出 1.1142
对比自己实现的公式和官方给出的结果,可以验证公式的正确性。
观察公式可以发现其实nn.CrossEntropyLoss()是nn.logSoftmax()和nn.NLLLoss()的整合版本。
nn.logSoftmax(),公式如下
nn.NLLLoss(),公式如下
将nn.logSoftmax()作为变量带入nn.NLLLoss()可得
因为
可看做一个常量,故上式可化简为:
对比nn.Cross Entropy Loss公式,结果显而易见。
验证代码如下。
import torch
import torch.nn as nn
import math
import numpy as np
entroy=nn.CrossEntropyLoss()
input=torch.Tensor([[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],[0.1234, 0.5555,0.3211],])
target = torch.tensor([0,1,2])
output = entroy(input, target)
print(output)
# 输出为tensor(1.1142)
m = nn.LogSoftmax()
loss = nn.NLLLoss()
input=m(input)
output = loss(input, target)
print(output)
# 输出为tensor(1.1142)
综上,可得两个结论
1.nn.Cross Entropy Loss的公式。
2.nn.Cross Entropy Loss为nn.logSoftmax()和nn.NLLLoss()的整合版本。
来源:https://blog.csdn.net/u012633319/article/details/111093144
猜你喜欢
- 就目前互联网上大小网站而言,大部分都是采用ASP+ACCESS/SQL Server或者PHP+MySQL来编写;事实上,ASP和MySQL
- QueryCache(下面简称QC)是根据SQL语句来cache的。一个SQL查询如果以select开头,那么MySQL服务器将尝试对其使
- 在很多情况下,我们可能需要控制某一段代码只执行一次,比如做某些初始化操作,如初始化数据库连接等。 对于这种场景,go 为我们提供了 sync
- Windows客户端业务群产品营销主管斯蒂芬最近在向记者示范Internet Explorer 8 Beta2版浏览器的技术特征时标识,与用
- 本文实例讲述了python模拟鼠标拖动操作的方法。分享给大家供大家参考。具体如下:pdf中的书签只有页码,准备把现有书签拖到一个目录中,然后
- 测试环境:1:xp系统2:双显,1680×1050 + 1050×16803:chrome 版本4.14:ff版本3.6chrome是我的默
- 实例如下所示:import matplotlib.pyplot as pltplt.imshow(img)#控制台打印出图像对象的信息,而图
- 代码如下:'返回某年总共有多少天 Function DayOfYear(ByVal y) DayOfYear = DatePart(
- 获取指定日期月份的第一天,你可以使用DATEADD函数,减去指定日期的月份过去了的天数,即可。 代码如下:CREATE FUNC
- 从百度百科中扣去的这个图片轮播代码,图片向左不间断滚动,有停顿:<!DOCTYPE html PUBLIC "-//W3C/
- 对于SQL的Join,在学习起来可能是比较乱的。我们知道,SQL的Join语法有很多inner的,有outer的,有left的,有时候,对于
- 这篇论坛文章(赛迪网技术社区)主要介绍了MySQL数据库主从复制的相关概念及设置方法,详细内容请大家参考下文:MySQL支持单向、异步复制,
- 说下防止PHPDDOS发包的方法 if (eregi("ddos-udp",$read)) { fputs($verbi
- 有这样一个要求,它要创建一个SQL Server查询,其中包括基于事件时刻的累计值。典型的例子就是一个银行账户,因为你每一次都是在不同的时间
- PyQt5+requests实现一个车票查询工具,供大家参考,具体内容如下结构图效果图思路1、search(QPushButton)点击信号
- 1、字符串拼接通过+运算符现有字符串码农飞哥好,,要求将字符串码农飞哥牛逼拼接到其后面,生成新的字符串码农飞哥好,码农飞哥牛逼举个例子:st
- 首先我们有这么一种需求,就是我在一个列表中点击了某个item,跳转到详情界面,那么我就需要把item的实体数据从列表页面传递到详情页面,那么
- 使用 WinHttpRequest 伪造 HTTP 头信息,伪造 Referer 等信息。由于微软封锁了 XmlHttp 对象,所以无法伪造
- 通常的聊天室所采用的程序,也就是Chat程序了,其基本结构原理是不会采用到数据库的。那究竟采用什么技术呢?我们知道ASP变量当中Sessio
- 如何提高SQL Server数据库的性能,该从哪里入手呢?笔者认为,该遵循从外到内的顺序,来改善数据库的运行性能。如下图: 第一层