pytorch交叉熵损失函数的weight参数的使用
作者:Nick Blog 发布时间:2021-02-27 15:52:31
首先
必须将权重也转为Tensor的cuda格式;
然后
将该class_weight作为交叉熵函数对应参数的输入值。
class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()
补充:关于pytorch的CrossEntropyLoss的weight参数
首先这个weight参数比想象中的要考虑的多
你可以试试下面代码
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)
这里的手动计算是:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803
加权呢?
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)
手算发现,并不是单纯的那权重相乘:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113
而是
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075
发现了么,加权后,除以的是权重的和,不是数目的和。
我们再验证一遍:
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)
手算:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
loss3 = 0 + ln(e2 + e0 + e0) = 2.2395
loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943
求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472
可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明
来源:https://niecongchong.blog.csdn.net/article/details/86594621
猜你喜欢
- 从大规模数据集中寻找物品间的隐含关系被称作关联分析或关联规则学习。过程分为两步:1.提取频繁项集。2.从频繁项集中抽取出关联规则。 频繁项集
- 函数:len()1:作用:返回字符串、列表、字典、元组等长度2:语法:len(str)3:参数:str:要计算的字符串、列表、字典、元组等4
- 1、time模块(※※※※)import time #导入时间模块print(time.time()) #返回当前时间的时间戳,可用于计算程
- 网络安全问题很重要,尤其是保证数据安全,遇到很多在写接口的程序员直接都是明文数据传输,在我看来这是很不专业的。本人提倡经过接口的数据都要进行
- CSS+DIV是网站标准(或称“WEB标准”)中常用的术语之一,通常为了说明与HTML网页设计语言中的表格(table)定位方式的区别,因为
- 1.1. 前言众所周知,安服工程师又叫做Word工程师,在打工或者批量SRC的时候,如果产出很多,又需要一个一个的写报告的情况下会非常的折磨
- 上一篇的DOCTYPE声明好以后,接下来的代码是:<html xmlns="xhtml" ta
- pytorch transform数据处理转c++python推理代码转c++ sdk过程遇到pytorch数据处理的转换1.python代
- 无论是对于刚接触编程的初学者,还是已经工作的程序员,哪一门编程语言更火,更有价值和前景,似乎是永远有争议的话题。下面来对比说以下python
- 数据加密是一种保护数据安全的技术,通过对数据进行编码,使得未经授权的用户无法读取或改动数据。加密是通过使用加密算法和密钥实现的。加密算法是一
- 1.zip用法简介在python 3.x系列中,zip方法返回的为一个zip object可迭代对象。class zip(object):&
- php redis断线重连,pconnect连接失败问题介绍在swoole ,workerman等cli长连接模式下,遇到Redis异常断开
- 上节我们提到解决赋值中等号两边参数不一致的方法可以通过切片,但在Python3中我们可以利用特定的语法更加方便的处理这种情况,如下示例。当带
- 最近遇到一个情景,就是定期生成并发送服务器使用情况报表,按照不同维度统计,涉及python对excel的操作,上网搜罗了一番,大多大同小异,
- 这是我使用python写的第一个类(也算是学习面向对象语言以来正式写的第一个解耦的类),记录下改进的过程。分析需求最初,因为使用time模块
- 前言我在使用mac安装virtualwrapper的时候遇到了问题,搞了好长时间,才弄好,在这里总结一下分享出来,供遇到相同的问题的朋友使用
- 使用工具:Python2.7 点我下载scrapy框架sublime text3一。搭建python(Windows版本) 1.安
- 前言问题需求:拥有两个文件夹,一个保存图片image,一个保存标签文件,要求把标签文件中的标注提取出来,并在图片中画出来相应的思路首先提出各
- 笔者日积月累了许多精彩、实用的Web特效的制作,这些特效几乎都是比较常用的网页特效。现在我就把这些经过
- 概述在绝大部分的开发语言中与实际开发过程中,Dictionary扮演着举足轻重的角色。从我们的数据模型到服务器返回的参数到数据库的应用等等,