pytorch 实现计算 kl散度 F.kl_div()
作者:Answerlzd 发布时间:2023-04-03 20:16:18
先附上官方文档说明:https://pytorch.org/docs/stable/nn.functional.html
torch.nn.functional.kl_div(input, target, size_average=None, reduce=None, reduction='mean')
Parameters
input – Tensor of arbitrary shape
target – Tensor of the same shape as input
size_average (bool, optional) – Deprecated (see reduction). By default, the losses are averaged over each loss element in the batch. Note that for some losses, there multiple elements per sample. If the field size_average is set to False, the losses are instead summed for each minibatch. Ignored when reduce is False. Default: True
reduce (bool, optional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: True
reduction (string, optional) – Specifies the reduction to apply to the output: 'none' | 'batchmean' | 'sum' | 'mean'. 'none': no reduction will be applied 'batchmean': the sum of the output will be divided by the batchsize 'sum': the output will be summed 'mean': the output will be divided by the number of elements in the output Default: 'mean'
然后看看怎么用:
第一个参数传入的是一个对数概率矩阵,第二个参数传入的是概率矩阵。这里很重要,不然求出来的kl散度可能是个负值。
比如现在我有两个矩阵X, Y。因为kl散度具有不对称性,存在一个指导和被指导的关系,因此这连个矩阵输入的顺序需要确定一下。
举个例子:
如果现在想用Y指导X,第一个参数要传X,第二个要传Y。就是被指导的放在前面,然后求相应的概率和对数概率就可以了。
import torch
import torch.nn.functional as F
# 定义两个矩阵
x = torch.randn((4, 5))
y = torch.randn((4, 5))
# 因为要用y指导x,所以求x的对数概率,y的概率
logp_x = F.log_softmax(x, dim=-1)
p_y = F.softmax(y, dim=-1)
kl_sum = F.kl_div(logp_x, p_y, reduction='sum')
kl_mean = F.kl_div(logp_x, p_y, reduction='mean')
print(kl_sum, kl_mean)
>>> tensor(3.4165) tensor(0.1708)
补充:pytorch中的kl散度,为什么kl散度是负数?
F.kl_div()或者nn.KLDivLoss()是pytroch中计算kl散度的函数,它的用法有很多需要注意的细节。
输入
第一个参数传入的是一个对数概率矩阵,第二个参数传入的是概率矩阵。并且因为kl散度具有不对称性,存在一个指导和被指导的关系,因此这连个矩阵输入的顺序需要确定一下。如果现在想用Y指导X,第一个参数要传X,第二个要传Y。就是被指导的放在前面,然后求相应的概率和对数概率就可以了。
所以,一随机初始化一个tensor为例,对于第一个输入,我们需要先对这个tensor进行softmax(确保各维度和为1),然后再取log;对于第二个输入,我们需要对这个tensor进行softmax。
import torch
import torch.nn.functional as F
a = torch.tensor([[0,0,1.1,2,0,10,0],[0,0,1,2,0,10,0]])
log_a =F.log_softmax(a)
b = torch.tensor([[0,0,1.1,2,0,7,0],[0,0,1,2,0,10,0]])
softmax_b =F.softmax(b,dim=-1)
kl_mean = F.kl_div(log_a, softmax_b, reduction='mean')
print(kl_mean)
为什么KL散度计算出来为负数
先确保对第一个输入进行了softmax+log操作,对第二个参数进行了softmax操作。不进行softmax操作就可能为负。
然后查看自己的输入是否是小数点后有很多位,当小数点后很多位的时候,pytorch下的softmax会产生各维度和不为1的现象,导致kl散度为负,如下所示:
a = torch.tensor([[0.,0,0.000001,0.0000002,0,0.0000007,0]])
log_a =F.log_softmax(a,dim=-1)
print("log_a:",log_a)
b = torch.tensor([[0.,0,0.000001,0.0000002,0,0.0000007,0]])
softmax_b =F.softmax(b,dim=-1)
print("softmax_b:",softmax_b)
kl_mean = F.kl_div(log_a, softmax_b,reduction='mean')
print("kl_mean:",kl_mean)
输出如下,我们可以看到softmax_b的各维度和不为1:
来源:https://blog.csdn.net/Answer3664/article/details/106265132


猜你喜欢
- Object 类型的对象虽然有 toString 方法,但结果却是 [Object Object] 让人没法理解的字符。比如简单的对象:{n
- watch介绍vue中watch用来监听数据的响应式变化.获取数据变化前后的值watch的完整入参watch(监听的数据,副作用函数,配置对
- 本文实例为大家分享了python定时提取实时日志的具体代码,供大家参考,具体内容如下这是一个定时读取 实时日志文件的程序。目标文件是targ
- 前言首先,先说明我只是初步接触yolov7,写这篇文章的主要目的是可以让大家快速应用自己的数据集进行训练。没有接触过yolov5也没有关系,
- 使用tensorflow过程中,训练结束后我们需要用到模型文件。有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练。这时候我
- 前言好记性不如烂笔头!最近在接口测试,以及爬虫相关,需要用到Python中的requests库,之前用过,但是好久没有用又忘了,这次就把这块
- 很多时候我们需要让main函数不退出,让它在后台一直执行,例如:func main() { for i := 0;
- 最早大家都没有给链接加title的习惯,后来因为w3c标准普及,又集体加上了title。从一个极端走到另个极端,于是出现很多怪异现象。两方面
- 前言最近在爬行 nosec.org 的数据,看了下需要模拟登录拿到cookie后才能访问想抓的数据,重要的是 nosec.org 的登录页面
- 列表推导与生成器表达式当我们创建了一个列表的时候,就创建了一个可以迭代的对象:>>> squares=[n*n for n
- 目录实例演示1. axios上传普通文件:2. 大文件导入:结语这次我要讲述的是在React-Flask框架上开发上传组件的技巧。我目前主要
- 文章介绍OpenCV 库中包含很多运算函数,这里着重介绍按位运算的基本原理并举例说明。本篇文章中主要涉及到的函数有:按位与:bitwise_
- 如果你的PHP网站换了空间,必定要对Mysql数据库进行转移,一般的转移的方法,是备份再还原,有点繁琐,而且由于数据库版本的不一样会导致数据
- 和以往的总监会议一样,在某个新功能的总监级别讨论会上,很多人再次又说出了同样的看法:“我们网站的界面设计太烂了,不好看、不好用、而且很乱”。
- Pattern.split方法详解/** * 测试Pattern.split方法 */ @Test public void testPatt
- Python中的单元测试我们先来回顾一下Python中的单元测试方法。下面是一个 Python的单元测试简单的例子:假如我们开发
- 大家好!我是 Sergey Kamardin,是 Mail.Ru 的一名工程师。本文主要介绍如何使用 Go 开发高负载的 WebSocket
- 如下所示:#coding=utf8import csv import logginglogging.basicConfig(level=lo
- 实现思路和详细解读1. 获取 Fashion 数据、处理数据(1)本次实践项目用到的是 Fashion 数据集,包含 10 个类别的服饰灰度
- 前言:我们想要在爬虫中使用xpath、beautifulsoup、正则表达式,css选择器等来提取想要的数据,但是因为scrapy是一个比较