PyTorch中topk函数的用法详解
作者:咆哮的阿杰 发布时间:2022-11-08 13:55:52
标签:PyTorch,topk
听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。
用法
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
input:一个tensor数据
k:指明是得到前k个数据以及其index
dim: 指定在哪个维度上排序, 默认是最后一个维度
largest:如果为True,按照大到小排序; 如果为False,按照小到大排序
sorted:返回的结果按照顺序返回
out:可缺省,不要
topk最常用的场合就是求一个样本被网络认为前k个最可能属于的类别。我们就用这个场景为例,说明函数的使用方法。
假设一个,N是样本数目,一般等于batch size, D是类别数目。我们想知道每个样本的最可能属于的那个类别,其实可以用torch.max得到。如果要使用topk,则k应该设置为1。
import torch
pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print(indices)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=1, keepdim=True)
print(indices_max == indices)
# pred
tensor([[-0.1480, -0.9819, -0.3364, 0.7912, -0.3263],
[-0.8013, -0.9083, 0.7973, 0.1458, -0.9156],
[-0.2334, -0.0142, -0.5493, 0.0673, 0.8185],
[-0.4075, -0.1097, 0.8193, -0.2352, -0.9273]])
# indices, shape为 【4,1】,
tensor([[3], #【0,0】代表 第一个样本最可能属于第一类别
[2], # 【1, 0】代表第二个样本最可能属于第二类别
[4],
[2]])
# indices_max等于indices
tensor([[True],
[True],
[True],
[True]])
现在在尝试一下k=2
import torch
pred = torch.randn((4, 5))
print(pred)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True) # k=2
print(indices)
# pred
tensor([[-0.2203, -0.7538, 1.8789, 0.4451, -0.2526],
[-0.0413, 0.6366, 1.1155, 0.3484, 0.0395],
[ 0.0365, 0.5158, 1.1067, -0.9276, -0.2124],
[ 0.6232, 0.9912, -0.8562, 0.0148, 1.6413]])
# indices
tensor([[2, 3],
[2, 1],
[2, 1],
[4, 1]])
可以发现indices的shape变成了【4, k】,k=2。
其中indices[0] = [2,3]。其意义是说明第一个样本的前两个最大概率对应的类别分别是第3类和第4类。
大家可以自行print一下values。可以发现values的shape和indices的shape是一样的。indices描述了在values中对应的值在pred中的位置。
来源:https://blog.csdn.net/qq_34914551/article/details/103738160
0
投稿
猜你喜欢
- 目录一、环境准备二、问题分析三、spider四、item五、setting六、pipelines七、middlewares八、使用jupyt
- 图片的间隙Q:我有一张大图片,把它切割后在Dreamweaver中进行拼接,可是总是有间隙,不知为什么?A:不知你是否把表格的边距、间距和边
- 这样做的好处是:利用表格来装载数据,不言而喻是最好的,你可以很灵活的为每个单元格定义样式。下面是具体的做法首先在photoshop设计一个效
- 因为要用到过滤一组中重复的数据,使之变成没有重复的一组数据的功能,百度了一下,居然有朋友乱写,而且比较多,都没有认真测试过,只对字符可以,但
- 越来越多的网站在logo中添加叶子元素,而此类logo又常常使用绿色,这可以给人希望、清新、健康的感觉,从而很容易被接受和认可。今天我们又收
- 本文实例为大家分享了python实现登录与注册系统的具体代码,供大家参考,具体内容如下实现功能1.调用文本文件里的用户信息2.可以将注册信息
- 前言FlashText 算法是由 Vikash Singh 于2017年发表的大规模关键词替换算法,这个算法的时间复杂度仅由文本长度(N)决
- 1.定义在某些情况下,一个类的对象是有限且固定的,比如季节类,它只有 4 个对象;再比如行星类,目前只有 8 个对象。这种实例有限且固定的类
- 1 Video介绍引用我翻译文档《在HTML5页面中嵌入音频和视频》中的介绍文字:“当今,在网页上嵌入视频且所有用户不管使用任何浏览器或者操
- <!DOCTYPE HTML PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN&
- 这个技巧将教你如何用css做出漂亮的文本按钮,有活力的按钮将节省你很多制作图片的时间,也能让你一天的工作中成为一个快乐的人,让我们一起看看效
- 前言说到幻影坦克,我就想起红色警戒里的……幻影坦克(Mirage Tank),《红色警戒2》以及《尤里的复仇》中盟军的一款伪装坦克,盟军王牌
- 定义行为要定义行为,通过继承 yii\base\Behavior 或其子类来建立一个类。如:namespace app\components
- <script language="vbscript" runat="s
- 初学者可以看看。在的img标签有两个属性分别为alt和title,对于很多初学者而言对这两个属性的正确使用都还抱有迷惑,当然这其中一部分原因
- EM算法实例通过实例可以快速了解EM算法的基本思想,具体推导请点文末链接。图a是让我们预热的,图b是EM算法的实例。这是一个抛硬币的例子,H
- 利用Python正则表达式匹配字符串中的http链接。主要难点是用正则表示出http 链接的模式。import repattern = re
- 索引是以表列为基础的数据库对象。索引中保存着表中排序的索引列,并且纪录了索引列在数据库表中的物理存储位置,实现了表中数据的逻辑排序。通过索引
- 磁盘搜索是性能的很大瓶颈。这个问题在数据大量增长以至于无法使用有效的缓存时尤为明显。或多或少随即访问大数据库时,就必然会有至少一次磁盘搜索来
- Logistic Regression Classifier逻辑回归主要思想就是用最大似然概率方法构建出方程,为最大化方程,利用牛顿梯度上升