pytorch交叉熵损失函数的weight参数的使用
作者:Nick Blog 发布时间:2021-02-27 15:52:31
首先
必须将权重也转为Tensor的cuda格式;
然后
将该class_weight作为交叉熵函数对应参数的输入值。
class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()
补充:关于pytorch的CrossEntropyLoss的weight参数
首先这个weight参数比想象中的要考虑的多
你可以试试下面代码
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,1,1])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.4803)
这里的手动计算是:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803
加权呢?
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,0,0,0,1])
outputs = torch.LongTensor([0,1])
inputs = inputs.view((1,3,2))
outputs = outputs.view((1,2))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE)
loss = ce(inputs,outputs)
print(loss)
tensor(1.6075)
手算发现,并不是单纯的那权重相乘:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113
而是
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075
发现了么,加权后,除以的是权重的和,不是数目的和。
我们再验证一遍:
import torch
import torch.nn as nn
inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5])
outputs = torch.LongTensor([0,1,2,2])
inputs = inputs.view((1,3,4))
outputs = outputs.view((1,4))
weight_CE = torch.FloatTensor([1,2,3])
ce = nn.CrossEntropyLoss(weight=weight_CE)
# ce = nn.CrossEntropyLoss(ignore_index=255)
loss = ce(inputs,outputs)
print(loss)
tensor(1.5472)
手算:
loss1 = 0 + ln(e0 + e0 + e0) = 1.098
loss2 = 0 + ln(e1 + e0 + e1) = 1.86
loss3 = 0 + ln(e2 + e0 + e0) = 2.2395
loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943
求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472
可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明
来源:https://niecongchong.blog.csdn.net/article/details/86594621


猜你喜欢
- 来做一个快速测验-以下代码输出什么?vals := make([]int, 5)for i := 0; i < 5; i++ { va
- 最近在抓取http://skell.sketchengine.eu网页时,发现用requests无法获得网页的全部内容,所以我就用selen
- 对于经常需要表格头部不东,而列表可以滚动,多用于数据比较多的情况,方便查看<!DOCTYPE HTML PUBLIC "-/
- 前面的话分页导航几乎在每个网站都可见,好的分页能给用户带来好的用户体验。本文将详细介绍Bootstrap分页概述在Bootstrap框架中提
- 因为使用python+selenium有时候需要获取当前文件的上一级目录,找了一段时间找到了,在此记录下来;os.path.dirname(
- python中xmltodict使用xml转换成OrderedDict代码 :import xmltodictfrom pprin
- 前言在laravel项目开发中,经常使用到公共函数,那如何在laravel配置全局公共函数呢??下面话不多说了,来一起看看详细的介绍吧方法如
- var str = "pig cat fish、dog horse monkey bear、lion、fox&quo
- 1、问题描述在使用v-model指令实现输入框数据双向绑定,输入值时对应的这个变量的值也随着变化;但是这里不允许使用v-model,需要写一
- 要想从命令行启动mysqld服务器,你应当启动控制台窗口(或“DOS window”)并输入命令:C
- 本次做一个最简单的贪食蛇雏形游戏,就是一个小蛇在画面上移动,我们可以控制蛇的移动方向,但是不能吃东西,蛇不会长大。但是基础的有了,完整版的贪
- tzset()方法重置所使用的库例程的时间转换规则。环境变量TZ指定如何完成此操作。TZ环境变量的标准格式(空格为清楚起见而加的
- 1、通过探测Flash Player的版本,来决定显示Flash内容还是替换内容,避免了过时的Flash插件影响Flash内容的正常显示。2
- vue数据变化被watch监听处理监听当前vue文件数据例如,当前的vue文件的data中有如下属性:data() {  
- 轮播图的根本其实就是缓动函数的封装,如果说轮播图是一辆跑动的汽车,那么缓动函数就是它的发动机,今天本文章就带大家由简入繁,封装属于自己的缓动
- 在学习python的时候,会有一些梗非常不适应,在此列举列表删除和多重循环退出的例子:列表删除里面的坑比如我们有一个列表里面有很多相同的值,
- 使用场景批量合并相同格式的Exce,给DataFrame添加行,给DataFrame添加列使用说明:1.使用某种合并方式(inner/out
- 本文实例讲述了Python求解平方根的方法。分享给大家供大家参考。具体如下:主要通过SICP的内容改写而来。基于newton method求
- Vue学习笔记-3 前言Vue 2.x相比较Vue 1.x而言,升级变化除了实现了Virtual-Dom以外,给使用者最大不适就是移除的组件
- matplotlib窗口图标默认是matplotlib的标志,如果想修改怎么改呢?由于我选择的matplotlib后端是PyQT5,直接查看