解决pytorch中的kl divergence计算问题
作者:jingxian 发布时间:2023-11-12 11:02:00
偶然从pytorch讨论论坛中看到的一个问题,KL divergence different results from tf,kl divergence 在TensorFlow中和pytorch中计算结果不同,平时没有注意到,记录下
一篇关于KL散度、JS散度以及交叉熵对比的文章
kl divergence 介绍
KL散度( Kullback–Leibler divergence),又称相对熵,是描述两个概率分布 P 和 Q 差异的一种方法。计算公式:
可以发现,P 和 Q 中元素的个数不用相等,只需要两个分布中的离散元素一致。
举个简单例子:
两个离散分布分布分别为 P 和 Q
P 的分布为:{1,1,2,2,3}
Q 的分布为:{1,1,1,1,1,2,3,3,3,3}
我们发现,虽然两个分布中元素个数不相同,P 的元素个数为 5,Q 的元素个数为 10。但里面的元素都有 “1”,“2”,“3” 这三个元素。
当 x = 1时,在 P 分布中,“1” 这个元素的个数为 2,故 P(x = 1) = 2/5 = 0.4,在 Q 分布中,“1” 这个元素的个数为 5,故 Q(x = 1) = 5/10 = 0.5
同理,
当 x = 2 时,P(x = 2) = 2/5 = 0.4 ,Q(x = 2) = 1/10 = 0.1
当 x = 3 时,P(x = 3) = 1/5 = 0.2 ,Q(x = 3) = 4/10 = 0.4
把上述概率带入公式:
至此,就计算完成了两个离散变量分布的KL散度。
pytorch 中的 kl_div 函数
pytorch中有用于计算kl散度的函数 kl_div
torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')
计算 D (p||q)
1、不用这个函数的计算结果为:
与手算结果相同
2、使用函数:
(这是计算正确的,结果有差异是因为pytorch这个函数中默认的是以e为底)
注意:
1、函数中的 p q 位置相反(也就是想要计算D(p||q),要写成kl_div(q.log(),p)的形式),而且q要先取 log
2、reduction 是选择对各部分结果做什么操作,默认为取平均数,这里选择求和
好别扭的用法,不知道为啥官方把它设计成这样
补充:pytorch 的KL divergence的实现
看代码吧~
import torch.nn.functional as F
# p_logit: [batch, class_num]
# q_logit: [batch, class_num]
def kl_categorical(p_logit, q_logit):
p = F.softmax(p_logit, dim=-1)
_kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1)
- F.log_softmax(q_logit, dim=-1)), 1)
return torch.mean(_kl)
来源:https://blog.csdn.net/wwyy2018/article/details/101599862
猜你喜欢
- 话说用了就要有点产出,要不然过段时间又忘了,所以在这里就记录一下试用Kafka的安装过程和php扩展的试用。实话说,如果用于队列的话,跟PH
- ASPJPEG组件是Persits出品的共享软件,试用期为30天,您可以在这里下载:http://www.persits.com/aspjp
- Scrapy 结构概述:一、下载器中间件(Downloader Middleware)如上图标号4、5处所示,下载器中间件用于处理scrap
- 当需要远程办公时,使用pycharm远程连接服务器时必要的。PyCharm提供两种远程调试(Remote Debugging)的方式:配置远
- pandas删除部分数据后重新索引在使用pandas时,由于隔行读取删除了部分数据,导致删除数据后的索引不连续:原数据删除部分数据后在绑定p
- 场景可能是你用不到,但是我遇到了这样一个问题,就是我想详细了解我的竞争对手的网站(电商类)销售情况和新品上架情况,但是我总不至于像盯盘一样,
- 一、问题描述当用JS调用form的方法submit直接提交form的时候,submit事件不响应。为什么?知道的请回复。类比一下,我用inp
- 一、背景最近学校校园网不知道是什么情况,总出现掉线的情况。每次掉线都需要我手动打开web浏览器重新进行账号密码输入,重新进行登录。系统的问题
- 毫无疑问,我们生活在编辑器的最好年代,Vim是仅在Vi之下的神级编辑器,而脱胎于Vim的NeoVim则是这个时代最好的编辑器,没有之一。异步
- 假如页面上有很多条记录,很多情况下,对这些信息按照字母表降序排序会比传统的升序排序显示效率更高。采用你熟悉的ORDER BY 子句,你可以很
- 这次主要是爬了京东上一双鞋的相关评论:将数据保存到excel中并可视化展示相应的信息主要的python代码如下:文件1#将excel中的数据
- 使用 str.join() 方法打印不带括号的元组,例如 result = ','.join(my_tuple)。 str.
- 介绍Zmail 使得在python3中发送和接受邮件变得更简单。你不需要手动添加服务器地址、端口以及适合的协议,zmail会帮你完成。此外,
- 一、安装第三方库是可能出现如下错误提示:二、解决办法:最好的解决办法可以通过“Pycharm”左下角
- 一、作业回顾1、格式化输出与%百分号以下结果中,可以正常输出“50%及格”语句是(B)A、print
- 前记上一遍文章《Python中Async语法协程的实现》介绍了Python是如何以生成器来实现协程的以及Python Asyncio通过Fu
- 需求描述有时候我们会基于已有数据生成一列在表格中,类似于下面的class BaseSchema(models.Model): ... def
- 概述今天我们要来做一个进阶的花分类问题. 不同于之前做过的鸢尾花, 这次我们会分析 102 中不同的花. 是不是很上头呀.预处理导包常规操作
- 使用pandas下的cumsum函数cumsum:计算轴向元素累积加和,返回由中间结果组成的数组.重点就是返回值是"由中间结果组成
- YOLOv5的Backbone设计在上一篇文章《YOLOV5的anchor设定》中我们讨论了anchor的产生原理和检测过程,对YOLOv5