网络编程
位置:首页>> 网络编程>> Python编程>> pytorch 实现二分类交叉熵逆样本频率权重

pytorch 实现二分类交叉熵逆样本频率权重

作者:*小呆  发布时间:2021-04-29 00:25:29 

标签:pytorch,交叉熵,样本,权重

通常,由于类别不均衡,需要使用weighted cross entropy loss平衡。


def inverse_freq(label):
"""
输入label [N,1,H,W],1是channel数目
"""
   den = label.sum() # 0
   _,_,h,w= label.shape
   num = h*w
   alpha = den/num # 0
   return torch.tensor([alpha, 1-alpha]).cuda()
# train
...
loss1 = F.cross_entropy(out1, label.squeeze(1).long(), weight=inverse_freq(label))

补充:Pytorch踩坑记之交叉熵(nn.CrossEntropy,nn.NLLLoss,nn.BCELoss的区别和使用)

在Pytorch中的交叉熵函数的血泪史要从nn.CrossEntropyLoss()这个损失函数开始讲起。

从表面意义上看,这个函数好像是普通的交叉熵函数,但是如果你看过一些Pytorch的资料,会告诉你这个函数其实是softmax()和交叉熵的结合体。

然而如果去官方看这个函数的定义你会发现是这样子的:

pytorch 实现二分类交叉熵逆样本频率权重

哇,竟然是nn.LogSoftmax()和nn.NLLLoss()的结合体,这俩都是什么玩意儿啊。再看看你会发现甚至还有一个损失叫nn.Softmax()以及一个叫nn.nn.BCELoss()。

我们来探究下这几个损失到底有何种关系。

nn.Softmax和nn.LogSoftmax

首先nn.Softmax()官网的定义是这样的:

pytorch 实现二分类交叉熵逆样本频率权重

嗯...就是我们认识的那个softmax。那nn.LogSoftmax()的定义也很直观了:

pytorch 实现二分类交叉熵逆样本频率权重

果不其然就是Softmax取了个log。可以写个代码测试一下:


import torch
import torch.nn as nn

a = torch.Tensor([1,2,3])
#定义Softmax
softmax = nn.Softmax()
sm_a = softmax=nn.Softmax()
print(sm)
#输出:tensor([0.0900, 0.2447, 0.6652])

#定义LogSoftmax
logsoftmax = nn.LogSoftmax()
lsm_a = logsoftmax(a)
print(lsm_a)
#输出tensor([-2.4076, -1.4076, -0.4076]),其中ln(0.0900)=-2.4076

nn.NLLLoss

上面说过nn.CrossEntropy()是nn.LogSoftmax()和nn.NLLLoss的结合,nn.NLLLoss官网给的定义是这样的:

The negative log likelihood loss. It is useful to train a classification problem with C classes

pytorch 实现二分类交叉熵逆样本频率权重

负对数似然损失 ,看起来好像有点晦涩难懂,写个代码测试一下:


import torch
import torch.nn

a = torch.Tensor([[1,2,3]])
nll = nn.NLLLoss()
target1 = torch.Tensor([0]).long()
target2 = torch.Tensor([1]).long()
target3 = torch.Tensor([2]).long()

#测试
n1 = nll(a,target1)
#输出:tensor(-1.)
n2 = nll(a,target2)
#输出:tensor(-2.)
n3 = nll(a,target3)
#输出:tensor(-3.)

看起来nn.NLLLoss做的事情是取出a中对应target位置的值并取负号,比如target1=0,就取a中index=0位置上的值再取负号为-1,那这样做有什么意义呢,要结合nn.CrossEntropy往下看。

nn.CrossEntropy

看下官网给的nn.CrossEntropy()的表达式:

pytorch 实现二分类交叉熵逆样本频率权重

看起来应该是softmax之后取了个对数,写个简单代码测试一下:


import torch
import torch.nn as nn

a = torch.Tensor([[1,2,3]])
target = torch.Tensor([2]).long()
logsoftmax = nn.LogSoftmax()
ce = nn.CrossEntropyLoss()
nll = nn.NLLLoss()

#测试CrossEntropyLoss
cel = ce(a,target)
print(cel)
#输出:tensor(0.4076)

#测试LogSoftmax+NLLLoss
lsm_a = logsoftmax(a)
nll_lsm_a = nll(lsm_a,target)
#输出tensor(0.4076)

看来直接用nn.CrossEntropy和nn.LogSoftmax+nn.NLLLoss是一样的结果。为什么这样呢,回想下交叉熵的表达式:

pytorch 实现二分类交叉熵逆样本频率权重

其中y是label,x是prediction的结果,所以其实交叉熵损失就是负的target对应位置的输出结果x再取-log。这个计算过程刚好就是先LogSoftmax()再NLLLoss()。

所以我认为nn.CrossEntropyLoss其实应该叫做softmaxloss更为合理一些,这样就不会误解了。

nn.BCELoss

你以为这就完了吗,其实并没有。还有一类损失叫做BCELoss,写全了的话就是Binary Cross Entropy Loss,就是交叉熵应用于二分类时候的特殊形式,一般都和sigmoid一起用,表达式就是二分类交叉熵:

pytorch 实现二分类交叉熵逆样本频率权重

直觉上和多酚类交叉熵的区别在于,不仅考虑了pytorch 实现二分类交叉熵逆样本频率权重的样本,也考虑了pytorch 实现二分类交叉熵逆样本频率权重的样本的损失。

来源:https://blog.csdn.net/qq_39575835/article/details/104336713

0
投稿

猜你喜欢

  • print() 函数使用以 % 开头的转换说明符对各种类型的数据进行格式化输出。转换说明符(Conversion Specifier)只是一
  • python中正则表达式中的匹配次数问题网上有很多解释,最多的就是*匹配0或者无数次,+匹配1次或无数次,?匹配0次或者1次。可是虽然这个文
  • 最近有个需求,用多线程比较合适,但是我需要每个线程的返回值,这就需要我在threading.Thread的基础上进行封装import thr
  • Requests具有完备的中英文文档, 能完全满足当前网络的需求, 它使用了urllib3, 拥有其所有的特性!最近在学python自动化,
  • XML(可扩展标记语言)已成为Web应用中数据表示和数据交换的标准,随着Internet的快速发展,尤其是电子商务,Web服务等应用的广泛使
  • 前言 绝大多数的Oracle数据库性能问题都是由于数据库设计不合理造成的,只有少部分问题根植于Database Buffer、Share P
  • 本文是小编针对JS删除数组里的某个元素这个大家经常遇到的经典问题整理了在各种情况下的函数写法以及遇到问题的分析,以下是全部内容:删除数组指定
  • cmake-2.8.3.tar.gzmysql-5.5.8.tar.gz一,cmake-2.8.3的安装:tar -zxf cmake-2.
  • 1、凸包检测与凸缺陷定义凸包是将最外层的点连接起来构成的凸多边形,它能包含点击中所有的点。物体的凸包检测常应用在物体识别、手势识别及边界检测
  • 模板继承是ThinkPHP3.1.2版本添加的一项更加灵活的模板布局方式,模板继承不同于模板布局,甚至来说,应该在模板布局的上层。模板继承其
  • 通过亲密性原则,我们可以将一个页面中的元素按照某种逻辑理解上的差异划分成不同的元素组合;再通过对齐原则,使这些不同的元素组合在视觉上看起来彼
  • 今天同事 明城 在项目中碰到一个 BUG,代码具体如下:<!DOCTYPE html PUBLIC "-//W3C//DTD
  • 在使用selenium去获取淘宝商品信息时会遇到登录界面这个登录界面处理的难度在于滑动验证的实现,有的人使用微博登录,避免了滑动验证,那可不
  • NumPy 比一般的 Python 序列提供更多的索引方式。除了之前看到的用整数和切片的索引外,数组可以由整数数组索引、布尔索引及花式索引。
  • ie的javascript失效了,不是设置的问题那么就可能是以下几点问题了~安装KAV可能会破坏系统的javascript关联,失javas
  • 去过新浪或者搜狐吗?虽然我们都不愿意看广告,但是它们做广告的技术我们却应该学到手,这不,又一种很流行的做法儿,做成那种两边对称的对联式的广告
  • 1.zip用法简介在python 3.x系列中,zip方法返回的为一个zip object可迭代对象。class zip(object):&
  • 我是这样来做DIV布局代码的.不知道说的清楚不清楚,凑和看吧我把class分为2种,布局class,风格class,布局class是骨架,风
  • 去年淘宝做了个“胖子”项目,就是把网页的默认宽度从780提升到了950。也就是说,基本放弃了800×600的用户(没有完全放弃,如果你仔细研
  • 1、如何认识可视化?需要指出的是,虽然不同绘图工具包的功能、效果会有差异,但在常用功能上相差并不是很大。与选择哪种绘图工具包相比,更重要的是
手机版 网络编程 asp之家 www.aspxhome.com