pytorch中torch.topk()函数的快速理解
作者:Neo很努力 发布时间:2023-09-07 10:21:52
标签:pytorch,torch.topk(),函数
函数作用:
该函数的作用即按字面意思理解,topk:取数组的前k个元素进行排序。
通常该函数返回2个值,第一个值为排序的数组,第二个值为该数组中获取到的元素在原数组中的位置标号。
举个栗子:
import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader
####################准备一个数组#########################
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
[3,4,5,1,1,1,1,1,1,1,1],
[7,8,9,1,1,1,1,1,1,1,1],
[1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)
####################打印这个原数组#########################
print('tensor1:')
print(tensor1)
#################使用torch.topk()这个函数##################
print('使用torch.topk()这个函数得到:')
'''k=3代表从原数组中取得3个元素,dim=1表示从原数组中的第一维获取元素
(在本例中是分别从[10,1,2,1,1,1,1,1,1,1,10]、[3,4,5,1,1,1,1,1,1,1,1]、
[7,8,9,1,1,1,1,1,1,1,1]、[1,4,7,1,1,1,1,1,1,1,1]这四个数组中获取3个元素)
其中largest=True表示从大到小取元素'''
print(torch.topk(tensor1, k=3, dim=1, largest=True))
#################打印这个函数第一个返回值####################
print('函数第一个返回值topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])
#################打印这个函数第二个返回值####################
print('函数第二个返回值topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''
#######################运行结果##########################
tensor1:
tensor([[10., 1., 2., 1., 1., 1., 1., 1., 1., 1., 10.],
[ 3., 4., 5., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 7., 8., 9., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 1., 4., 7., 1., 1., 1., 1., 1., 1., 1., 1.]])
使用torch.topk()这个函数得到:
'得到的values是原数组dim=1的四组从大到小的三个元素值;
得到的indices是获取到的元素值在原数组dim=1中的位置。'
torch.return_types.topk(
values=tensor([[10., 10., 2.],
[ 5., 4., 3.],
[ 9., 8., 7.],
[ 7., 4., 1.]]),
indices=tensor([[ 0, 10, 2],
[ 2, 1, 0],
[ 2, 1, 0],
[ 2, 1, 0]]))
函数第一个返回值topk[0]如下
tensor([[10., 10., 2.],
[ 5., 4., 3.],
[ 9., 8., 7.],
[ 7., 4., 1.]])
函数第二个返回值topk[1]如下
tensor([[ 0, 10, 2],
[ 2, 1, 0],
[ 2, 1, 0],
[ 2, 1, 0]])
'''
该函数功能经常用来获取张量或者数组中最大或者最小的元素以及索引位置,是一个经常用到的基本函数。
实例演示
任务一:
取top1(最大值):
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
print(pred)
values, indices = pred.topk(1, dim=0, largest=True, sorted=True)
print(indices)
print(values)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=0, keepdim=True)
print(indices_max)
print(indices_max == indices)
输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
tensor([[1, 1, 1, 1, 1]])
tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])
tensor([[1, 1, 1, 1, 1]])
tensor([[True, True, True, True, True]])
任务二:
按行取出topk,将小于topk的置为inf:
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
print(pred)
top_k = 2 # 按行求出每一行的最大的前两个值
filter_value=-float('Inf')
indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]
print(indices_to_remove)
pred[indices_to_remove] = filter_value # 对于topk之外的其他元素的logits值设为负无穷
print(pred)
输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
tensor([[4],
[4],
[4],
[3]])
tensor([[0.4053],
[1.8823],
[1.7255],
[0.3849]])
tensor([[ True, False, True, True, False],
[ True, False, True, True, False],
[ True, True, False, True, False],
[ True, False, True, False, True]])
tensor([[ -inf, -0.3873, -inf, -inf, 0.4053],
[ -inf, 1.4164, -inf, -inf, 1.8823],
[ -inf, -inf, 1.2590, -inf, 1.7255],
[ -inf, 0.3041, -inf, 0.3849, -inf]])
任务三:
import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
[3,4,5,1,1,1,1,1,1,1,1],
[7,8,9,1,1,1,1,1,1,1,1],
[1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)
# tensor2=torch.tensor([[3,2,1],
# [6,5,4],
# [1,4,7],
# [9,8,7]],dtype=torch.float32)
#
print('tensor1:')
print(tensor1)
print('直接输出topk,会得到两个东西,我们需要的是第二个indices')
print(torch.topk(tensor1, k=3, dim=1, largest=True))
print('topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])
print('topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''
tensor1:
tensor([[10., 1., 2., 1., 1., 1., 1., 1., 1., 1., 10.],
[ 3., 4., 5., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 7., 8., 9., 1., 1., 1., 1., 1., 1., 1., 1.],
[ 1., 4., 7., 1., 1., 1., 1., 1., 1., 1., 1.]])
直接输出topk,会得到两个东西,我们需要的是第二个indices
torch.return_types.topk(
values=tensor([[10., 10., 2.],
[ 5., 4., 3.],
[ 9., 8., 7.],
[ 7., 4., 1.]]),
indices=tensor([[ 0, 10, 2],
[ 2, 1, 0],
[ 2, 1, 0],
[ 2, 1, 0]]))
topk[0]如下
tensor([[10., 10., 2.],
[ 5., 4., 3.],
[ 9., 8., 7.],
[ 7., 4., 1.]])
topk[1]如下
tensor([[ 0, 10, 2],
[ 2, 1, 0],
[ 2, 1, 0],
[ 2, 1, 0]])
'''
来源:https://blog.csdn.net/qq_45193872/article/details/119878804


猜你喜欢
- 抢票是并发执行多个进程可以访问同一个文件多个进程共享同一文件,我们可以把文件当数据库,用多个进程模拟多个人执行抢票任务db.tx
- urllib包和http包都是面向HTTP协议的。其中urllib主要用于处理 URL,使用urllib操作URL可以像使用和打开本地文件一
- 废话不多说,直接开干!抖音字符视频在今年火过一段时间。反正我是始终忘不了那段刘耕宏老师本草纲目的音乐…这一次自己也来实
- 一、图像处理1. 灰度图像灰度图像矩阵元素的取值范围通常为 [0,255] 。因此其数据类型一般为8位无符号整数的(in
- 1. 修改pip install默认安装路径一般使用Anaconda时会使用pip install ###来安装各类包,但默认安装路径在C盘
- 图像融合按照一定的比例将两张图片融合在一起addWeighted()方法:参数1第一张图片矩阵参数2第一张图片矩阵的权重参数3第二张图片矩阵
- 引言少年,你在怀着非法的心态看一篇简短的硬核科普!先抛问题:如何杀掉一个正在等待 TCP 连接的 Thread?由于众所周知的原因,在国内使
- 如果你在爬虫过程中有遇到“您的请求太过频繁,请稍后再试”,或者说代码完全正确,可是爬虫过程中突然就访问不了,那么恭喜你,你的爬虫被对方识破了
- 大数据一般是在“云”上玩的,但“云”都是要钱的,而且数据上上下下的也比较麻烦。所以,在本地电脑上快速处理数据的技能还是要的。pandas在比
- MySQL函数CONCAT、CONCAT_WS、GROUP_CONCAT1.concat()函数CONCAT 函数用于将两个字符串连接为一个
- DataTable dt = new DataTable(); dt = ds.Tables["All"].Clone(
- 介绍今天有个不正经的需求,就是要快速做一个restful api的性能测试,要求测试在海量作业数据的情况下客户端分页获取所有作业的性能。因为
- 前言turtle库是Python语言中一个很流行的绘制图像的函数库,可以轻松地绘制出精美的形状和图案,很适合用来引导孩子学习编程。turtl
- 如何用ADO批量更新记录?是的,ADO有这项功能,不过好像用的人不太多(不了解还是不会用呢?):<HTML> &nbs
- “%”的使用格式符描述%s字符串 (采用str()的显示)%r字符串 (采用repr()的显示)%c
- “Be conservative in what you send; be liberal in what you accept. &nbs
- vue3使用computed获取vuex里数据不再是vue2.0里什么mapGetter,mapState那些复杂的获取方式,vue3.0里
- 这篇文章主要介绍了Python统计时间内的并发数代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的
- 学习JQUERY就应该从最基本的学起,基本的就应该是语法了,在这里,我们有必要先温习一下JAVASCRIPT的一些知识。语法就不用说了,都是
- 前面我们给了Tkinter接管Python输入和输出的介绍,我们不难可以想到,能用Tkinter来开发自己的Python代码编辑器.例如可以