解决pytorch中的kl divergence计算问题
作者:jingxian 发布时间:2023-11-12 11:02:00
偶然从pytorch讨论论坛中看到的一个问题,KL divergence different results from tf,kl divergence 在TensorFlow中和pytorch中计算结果不同,平时没有注意到,记录下
一篇关于KL散度、JS散度以及交叉熵对比的文章
kl divergence 介绍
KL散度( Kullback–Leibler divergence),又称相对熵,是描述两个概率分布 P 和 Q 差异的一种方法。计算公式:
可以发现,P 和 Q 中元素的个数不用相等,只需要两个分布中的离散元素一致。
举个简单例子:
两个离散分布分布分别为 P 和 Q
P 的分布为:{1,1,2,2,3}
Q 的分布为:{1,1,1,1,1,2,3,3,3,3}
我们发现,虽然两个分布中元素个数不相同,P 的元素个数为 5,Q 的元素个数为 10。但里面的元素都有 “1”,“2”,“3” 这三个元素。
当 x = 1时,在 P 分布中,“1” 这个元素的个数为 2,故 P(x = 1) = 2/5 = 0.4,在 Q 分布中,“1” 这个元素的个数为 5,故 Q(x = 1) = 5/10 = 0.5
同理,
当 x = 2 时,P(x = 2) = 2/5 = 0.4 ,Q(x = 2) = 1/10 = 0.1
当 x = 3 时,P(x = 3) = 1/5 = 0.2 ,Q(x = 3) = 4/10 = 0.4
把上述概率带入公式:
至此,就计算完成了两个离散变量分布的KL散度。
pytorch 中的 kl_div 函数
pytorch中有用于计算kl散度的函数 kl_div
torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')
计算 D (p||q)
1、不用这个函数的计算结果为:
与手算结果相同
2、使用函数:
(这是计算正确的,结果有差异是因为pytorch这个函数中默认的是以e为底)
注意:
1、函数中的 p q 位置相反(也就是想要计算D(p||q),要写成kl_div(q.log(),p)的形式),而且q要先取 log
2、reduction 是选择对各部分结果做什么操作,默认为取平均数,这里选择求和
好别扭的用法,不知道为啥官方把它设计成这样
补充:pytorch 的KL divergence的实现
看代码吧~
import torch.nn.functional as F
# p_logit: [batch, class_num]
# q_logit: [batch, class_num]
def kl_categorical(p_logit, q_logit):
p = F.softmax(p_logit, dim=-1)
_kl = torch.sum(p * (F.log_softmax(p_logit, dim=-1)
- F.log_softmax(q_logit, dim=-1)), 1)
return torch.mean(_kl)
来源:https://blog.csdn.net/wwyy2018/article/details/101599862


猜你喜欢
- Windows环境下一、开启 Imagick 扩展1、安装PHP扩展:Imagick,下载地址 https://pecl.php.net/p
- 1. 原始网站https://www.rbsp-ect.lanl.gov/data_pub/rbspa/2. 算法说明进入需要下载的数据所在
- 花了两个多钟在看 ThinkPHP 框架,不想太过深入的知道它的所有高深理论。单纯想知道怎么可以用起来,可以快捷的搭建一个网站。所以是有选择
- 1.首先注册应用,获取 appkey、appsecretapi_url = "https://oapi.dingtalk.com/
- 主要作用为指定图片像素:matplotlib.rcParams[‘figure.figsize']#图片像素 matplotlib.
- 互斥锁在Golang中,互斥锁(Mutex)是一种基本的同步原语,用于实现对共享资源的互斥访问。互斥锁通过在代码中标记临界区来控制对共享资源
- 在ORACLE中,我们可以通过file_id(file#)与block_id(block#)去定位一个数据库对象(object)。例如,我们
- 本文实例讲述了vue组件定义,全局、局部组件,配合模板及动态组件功能。分享给大家供大家参考,具体如下:一、定义一个组件定义一个组件:1. 全
- 应用场景1.需要将大型MP3文件切割成较小的部分以便上传或发送。2.需要从MP3文件中提取特定的音频片段,以便用于其他目的。3.需要快速制作
- 现在网上有很多python2写的爬虫抓取网页图片的实例,但不适用新手(新手都使用python3环境,不兼容python2),所以我用Pyth
- 本文实例为大家分享了python微信好友删除的具体代码,供大家参考,具体内容如下#weixin.py#coding:utf-8# !/usr
- 为了防止某些别有用心的人从外部访问数据库,盗取数据库中的用户姓名、密码、信用卡号等其他重要信息,在我们创建数据库驱动的解决方案时,我们首先需
- System.Data.OleDb.OleDbDataAdapter与System.Data.OleDb.OleDbDataReader的区
- 下面都是我收集的一些比较常用的正则表达式,因为平常可能在表单验证的时候,用到的比较多。特发出来,让各位朋友共同使用。呵呵。匹配中文字符的正则
- 不知道大家有没有一种感觉,每次当使用numpy数组的时候坐标轴总是傻傻分不清楚,然后就会十分的困惑,每次运算都需要去尝试好久才能得出想要的结
- 一、安装约定 mysql安装路径: /usr/local/mysql
- 本文实例讲述了Python实现字典按照value进行排序的方法。分享给大家供大家参考,具体如下:先说几个解决的方法,具体的有时间再细说d =
- 数据增强卷积神经网络非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法。2012 年 AlexNet 在 ImageN
- 表单在网页中主要负责数据采集功能。一个表单有三个基本组成部分: 表单标签:这里面包含了处理表单数据所用CGI程序的URL以及数
- 本文实例为大家分享了梅尔倒谱系数实现代码,供大家参考,具体内容如下""" @author: zoutai@fi