pytorch中交叉熵损失(nn.CrossEntropyLoss())的计算过程详解
作者:aift 发布时间:2021-06-03 09:28:09
标签:pytorch,交叉熵损失,nn.CrossEntropyLoss
公式
首先需要了解CrossEntropyLoss的计算过程,交叉熵的函数是这样的:
其中,其中yi表示真实的分类结果。这里只给出公式,关于CrossEntropyLoss的其他详细细节请参照其他博文。
测试代码(一维)
import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)
label = torch.empty(1, dtype=torch.long).random_(5)
loss = criterion(output, label)
print("网络输出为5类:")
print(output)
print("要计算label的类别:")
print(label)
print("计算loss的结果:")
print(loss)
first = 0
for i in range(1):
first = -output[i][label[i]]
second = 0
for i in range(1):
for j in range(5):
second += math.exp(output[i][j])
res = 0
res = (first + math.log(second))
print("自己的计算结果:")
print(res)
测试代码(多维)
import torch
import torch.nn as nn
import math
criterion = nn.CrossEntropyLoss()
output = torch.randn(3, 5, requires_grad=True)
label = torch.empty(3, dtype=torch.long).random_(5)
loss = criterion(output, label)
print("网络输出为3个5类:")
print(output)
print("要计算loss的类别:")
print(label)
print("计算loss的结果:")
print(loss)
first = [0, 0, 0]
for i in range(3):
first[i] = -output[i][label[i]]
second = [0, 0, 0]
for i in range(3):
for j in range(5):
second[i] += math.exp(output[i][j])
res = 0
for i in range(3):
res += (first[i] + math.log(second[i]))
print("自己的计算结果:")
print(res/3)
nn.CrossEntropyLoss()中的计算方法
注意:在计算CrossEntropyLosss时,真实的label(一个标量)被处理成onehot编码的形式。
在pytorch中,CrossEntropyLoss计算公式为:
CrossEntropyLoss带权重的计算公式为(默认weight=None):
来源:https://blog.csdn.net/ft_sunshine/article/details/92074842
0
投稿
猜你喜欢
- format函数实现字符串格式化的功能基本语法为:通过 : 和 {} 来控制字符串的操作一、对字符串进行操作1. 不设置指定位置,按默认顺序
- 只能输入中文/** * 22.验证汉字 * 表达式 ^[\u4e00-\u9fa5]{0,}$ * 描述 只能汉字 * 匹配的例子 清清月儿
- 最近帮伙计做了一个从网页抓取股票信息并把相应信息存入MySQL中的程序。使用环境:Python 2.5 for WindowsMySQLdb
- python查找多层嵌套字典的值def find_dic(item, key): if isinstance(it
- 废话不多说了,直接给大家贴代码了,具体代码如下所示:<!DOCTYPE html> <html> <head&
- 一、1、图形显示图素法像素法图素法---矢量图:以图形对象为基本元素组成的图形,如矩形、 圆形像素法---标量图:以像素点为基本单位形成图形
- 如何最准确地统计在线用户数?我们推荐的这个程序据说是目前最好的在线用户数量统计程序。代码如下:'首先要设置好global.asa&n
- 前言python数据类型是不允许改变的,这就意味着如果改变 Number 数据类型的值,将重新分配内存空间。下面话不多说,来看看详细的介绍吧
- 前言golang实现定时任务很简单,只须要简单几步代码即可以完成,最近在做了几个定时任务,想研究一下它内部是怎么实现的,所以将源码过了一遍,
- js表单验证只能是写限定的东西大收集 代码如下:ENTER键可以让光标移到下一个输入框<input onkeydown=&q
- 从这一章开始进入正式的算法学习。首先我们学习经典而有效的分类算法:决策树分类算法。1、决策树算法决策树用树形结构对样本的属性进行分类,是最直
- 本文实例讲述了php+mysqli使用面向对象方式更新数据库的方法,分享给大家供大家参考。具体实现方法如下:<?php//第一步:创建
- 如今WEB的安全问题影响着整个安全界,SQL注入,跨站脚本攻击等攻击受到了关注。 网络安全问题日益变的更加重要,国内依然有很多主机受到此类安
- Python的五个标准数据类型数字字符串列表元组字典一、数字不可变数据类型,存储值为数值1.创建对象,分配数值例:>>>
- 要达到二级名的效果,必须一下条件以及流程:1、必须有一个顶级域名,而且此域名必须做好泛解析并做好指向。2、必须有一台属于你的独立的服务器。泛
- 此文刊登在《程序员》2009年5月期:SQL全名是结构化查询语言(Structured Query Language),一直是后台开发者用来
- 前言最近在学习python 爬虫方面的知识,网上有一博客专栏专门写爬虫方面的,看到用urllib请求有道翻译接口获取翻译结果。发现接口变化很
- 阅读上一篇:交互设计模式(二)-Pagination(分页,标记页数) Tagging(标签)问题摘要用户往往想通过流行或最详尽的主题来浏览
- 1. Python的数据类型上一遍博文已经详细地介绍了Python的数据类型,详见链接Python的变量命名及数据类型。在这里总结一下Pyt
- js浮点数计算有时是不准确的,比如7*0.8 == 7*8/10的值为false,因为7*0.8=5.6000000000000005,乘出