Pytorch损失函数torch.nn.NLLLoss()的使用
作者:Jeremy_lf 发布时间:2021-02-07 16:08:57
标签:Pytorch,损失函数,torch.nn.NLLLoss()
Pytorch损失函数torch.nn.NLLLoss()
在各种深度学习框架中,我们最常用的损失函数就是交叉熵(torch.nn.CrossEntropyLoss),熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真是数据之间的相近程度。
交叉熵越小,表示数据越接近真实样本。
交叉熵计算公式
就是我们预测的概率的对数与标签的乘积,当qk->1的时候,它的损失接近零。
nn.NLLLoss
官方文档中介绍称:
nn.NLLLoss输入是一个对数概率向量和一个目标标签,它与nn.CrossEntropyLoss的关系可以描述为:softmax(x)+log(x)+nn.NLLLoss====>nn.CrossEntropyLoss
CrossEntropyLoss()=log_softmax() + NLLLoss()
其中softmax函数又称为归一化指数函数,它可以把一个多维向量压缩在(0,1)之间,并且它们的和为1.
计算公式
示例代码:
import math
z = [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0]
z_exp = [math.exp(i) for i in z]
print(z_exp) # Result: [2.72, 7.39, 20.09, 54.6, 2.72, 7.39, 20.09]
sum_z_exp = sum(z_exp)
print(sum_z_exp) # Result: 114.98
softmax = [round(i / sum_z_exp, 3) for i in z_exp]
print(softmax) # Result: [0.024, 0.064, 0.175, 0.475, 0.024, 0.064, 0.175]
log_softmax
log_softmax是指在softmax函数的基础上,再进行一次log运算,此时结果有正有负,log函数的值域是负无穷到正无穷,当x在0—1之间的时候,log(x)值在负无穷到0之间。
nn.NLLLoss
此时,nn.NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,再去掉负号,再求均值。
代码示例:
import torch
input=torch.randn(3,3)
soft_input = torch.nn.Softmax(dim=0)
soft_input(input)
Out[20]:
tensor([[0.7284, 0.7364, 0.3343],
[0.1565, 0.0365, 0.0408],
[0.1150, 0.2270, 0.6250]])
#对softmax结果取log
torch.log(soft_input(input))
Out[21]:
tensor([[-0.3168, -0.3059, -1.0958],
[-1.8546, -3.3093, -3.1995],
[-2.1625, -1.4827, -0.4701]])
假设标签是[0,1,2],第一行取第0个元素,第二行取第1个,第三行取第2个,去掉负号,即[0.3168,3.3093,0.4701],求平均值,就可以得到损失值。
(0.3168+3.3093+0.4701)/3
Out[22]: 1.3654000000000002
#验证一下
loss=torch.nn.NLLLoss()
target=torch.tensor([0,1,2])
loss(input,target)
Out[26]: tensor(0.1365)
nn.CrossEntropyLoss
loss=torch.nn.NLLLoss()
target=torch.tensor([0,1,2])
loss(input,target)
Out[26]: tensor(-0.1399)
loss =torch.nn.CrossEntropyLoss()
input = torch.tensor([[ 1.1879, 1.0780, 0.5312],
[-0.3499, -1.9253, -1.5725],
[-0.6578, -0.0987, 1.1570]])
target = torch.tensor([0,1,2])
loss(input,target)
Out[30]: tensor(0.1365)
以上为全部实验验证两个loss函数之间的关系!!!
来源:https://blog.csdn.net/Jeremy_lf/article/details/102725285
0
投稿
猜你喜欢
- 不到40天,ChatGPT的日活量已突破千万!而当年同样引起轰动的Instagram达到这一成就足足花了355天。这代表着我们正在广泛且快速
- 本文实例讲述了php实现搜索一维数组元素并删除二维数组对应元素的方法。分享给大家供大家参考。具体如下:定义一个一维数组一个二维数组如下$fr
- 10线程同时操作,频繁出现插入同样数据的问题。虽然在插入数据的时候使用了: insert inti tablename(fields....
- 刚刚解决了这个问题,现在记录下来问题描述当使用lambda层加入自定义的函数后,训练没有bug,载入保存模型则显示Nonetype has
- 使用python爬虫其实就是方便,它会有各种工具类供你来使用,很方便。Java不可以吗?也可以,使用httpclient工具、还有一个大神写
- 最后罗嗦一句,本人录入这篇文章用的机器上没有 ASP 环境,所以提供的代码未能进行测试,对这一点本人深表歉意。如果大家发现了代码中的任何问题
- 概述 -------------------------------------------------------------------
- 类的特殊成员之call#!/usr/bin/env python# _*_coding:utf-8 _*_class SpecialMemb
- python3中str默认为Unicode的编码格式Unicode是一32位编码格式,不适合用来传输和存储,所以必须转换成utf-8,gbk
- tkinter库:Python的标准Tk GUI工具包的接口示例:from tkinter import *root = Tk()#你的ui
- CTE(Common Table Expressions)是从SQL Server 2005以后版本才有的。指定的临时命名结果集,这些结果集
- 之前一直用python自带的IDLE写python程序,后来发现有一些限制啥的,于是下载了pycharm作为IDE去处理python新建项目
- 引入:Python中有个logging模块可以完成相关信息的记录,在debug时用它往往事半功倍一、日志级别(从低到高):DEBUG :详细
- pyd文件生成安装easycython库pip install easycythontest.pydef test(): pri
- 前文介绍了Oracle 中实现数据透视表的几种方法,今天我们来看看在 MySQL/MariaDB 中如何实现相同的功能。本文使用的示例数据可
- 数据去重可以使用duplicated()和drop_duplicates()两个方法。DataFrame.duplicated(subset
- golang的单引号转义如题,golang中有时候需要将一个字符串中的单引号再转义一次,比如在两个单引号之间包含一个含有单引号的字符串的情形
- 一、Go语言通道基础概念1.channel产生背景 线程之间进行通信的时候,会因为资源的争夺而产生竟态问
- 这个程序将记数器的数字放在ACCESS数据库中,当然你也能用你希望其它的ODBC数据源.这个程序从URL中读取记数信息.如下:< IM
- sys.argv[]说白了就是一个从程序外部获取参数的桥梁,这个“外部”很关键,因为我们从外部取得的参数可以是多个,所以获得的是一个列表(l