Pytorch中的gather使用方法
作者:SY_curry 发布时间:2021-11-22 06:11:49
官方说明
gather可以对一个Tensor进行聚合,声明为:torch.gather(input, dim, index, out=None) → Tensor
一般来说有三个参数:输入的变量input、指定在某一维上聚合的dim、聚合的使用的索引index,输出为Tensor类型的结果(index必须为LongTensor类型)。
#参数介绍:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
#当输入为三维时的计算过程:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
#样例:
t = torch.Tensor([[1,2],[3,4]])
torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
# 1 1
# 4 3
#[torch.FloatTensor of size 2x2]
实验
用下面的代码在二维上做测试,以便更好地理解
t = torch.Tensor([[1,2,3],[4,5,6]])
index_a = torch.LongTensor([[0,0],[0,1]])
index_b = torch.LongTensor([[0,1,1],[1,0,0]])
print(t)
print(torch.gather(t,dim=1,index=index_a))
print(torch.gather(t,dim=0,index=index_b))
输出为:
>>tensor([[1., 2., 3.],
[4., 5., 6.]])
>>tensor([[1., 1.],
[4., 5.]])
>>tensor([[1., 5., 6.],
[4., 2., 3.]])
由于官网给的计算过程不太直观,下面给出较为直观的解释:
对于index_a,dim为1表示在第二个维度上进行聚合,索引为列号,[[0,0],[0,1]]表示结果的第一行取原数组第一行列号为[0,0]的数,也就是[1,1],结果的第二行取原数组第二行列号为[0,1]的数,也就是[4,5],这样就得到了输出的结果[[1,1],[4,5]]。
对于index_b,dim为0表示在第一个维度上进行聚合,索引为行号,[[0,1,1],[1,0,0]]表示结果的第一行第d(d=0,1,2)列取原数组第d列行号为[0,1,1]的数,也就是[1,5,6],类似的,结果的第二行第d列取原数组第d列行号为[1,0,0]的数,也就是[4,2,3],这样就得到了输出的结果[[1,5,6],[4,2,3]]
接下来以index_a为例直接用官网的式子计算一遍加深理解:
output[0,0] = input[0,index[0,0]] #1 = input[0,0]
output[0,1] = input[0,index[0,1]] #1 = input[0,0]
output[1,0] = input[1,index[1,0]] #4 = input[1,0]
output[1,1] = input[1,index[1,1]] #5 = input[1,1]
注
以下两种写法得到的结果是一样的:
r1 = torch.gather(t,dim=1,index=index_a)
r2 = t.gather(1,index_a)
补充:Pytorch中的torch.gather函数的个人理解
最近在学习pytorch时遇到gather函数,开始没怎么理解,后来查阅网上相关资料后大概明白了原理。
gather()函数
在pytorch中,gather()函数的作用是将数据从input中按index提出,我们看gather函数的的官方文档说明如下:
torch.gather(input, dim, index, out=None) → Tensor
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
Parameters:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
Example:
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
可以看出,在gather函数中我们用到的主要有三个参数:
1)input:输入
2)dim:维度,常用的为0和1
3)index:索引位置
贴一段代码举例说明:
a=t.arange(0,16).view(4,4)
print(a)
index_1=t.LongTensor([[3,2,1,0]])
b=a.gather(0,index_1)
print(b)
index_2=t.LongTensor([[0,1,2,3]]).t()#tensor转置操作:(a)T=a.t()
c=a.gather(1,index_2)
print(c)
输出如下:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
tensor([[12, 9, 6, 3]])tensor([[ 0],
[ 5],
[10],
[15]])
在gather中,我们是通过index对input进行索引把对应的数据提取出来的,而dim决定了索引的方式。
在上面的例子中,a是一个4×4矩阵:
1)当维度dim=0,索引index_1为[3,2,1,0]时,此时可将a看成1×4的矩阵,通过index_1对a每列进行行索引:第一列第四行元素为12,第二列第三行元素为9,第三列第二行元素为6,第四列第一行元素为3,即b=[12,9,6,3];
2)当维度dim=1,索引index_2为[0,1,2,3]T时,此时可将a看成4×1的矩阵,通过index_1对a每行进行列索引:第一行第一列元素为0,第二行第二列元素为5,第三行第三列元素为10,第四行第四列元素为15,即c=[0,5,10,15]T;
来源:https://blog.csdn.net/qq_34392457/article/details/90206220
猜你喜欢
- 阅读上一篇:Freshow工具使用方法一. eval加密是在网马解密中最常见的,eval在jscript脚本中实际上是一个函数,简单可以理解
- Tensorflow二维、三维、四维矩阵运算(矩阵相乘,点乘,行/列累加)1. 矩阵相乘 根据矩阵相乘的匹配原则,左乘矩阵的列数要等于右乘矩
- 本文实例讲述了Python切片工具pillow用法。分享给大家供大家参考,具体如下:切片:使用切片将源图像分成许多的功能区域因为要对图片进行
- 如下所示:# -*- coding:utf-8 -*-import xlrdimport sysimport reimport jsondi
- 新浪微博需要登录才能爬取,这里使用m.weibo.cn这个移动端网站即可实现简化操作,用这个访问可以直接得到的微博id。分析新浪微博的评论获
- 浏览器:IE ,不支持firefoxfilter视觉滤镜的种类:Alpha(透明度) Blur(模糊) Chroma(指定颜色透明) Dro
- .xls格式 Office2003及以下版本 .xlsx格式Offi
- TF-IDF(term frequency–inverse document frequency)是一种用于信息检索与数据挖掘的常用加权技术
- OpenCV图像处理一、图像入门1.读取图像使用 cv.imread() 函数读取一张图像,图片应该在工作目录中,或者应该提供完整的图像路径
- # 比较两个字符串,如果不同返回第一个不相同的位置# 如果相同返回0def cmpstr(str1, str2): &
- 在工作中出于某些原因,我们可能需要将变量保存下来,这样下次就可以直接去赋值而不用重新执行某些重复耗时的操作了,这里我们用到了Python的p
- 前言很多前人曾说过,深度学习好比炼丹,框架就是丹炉,网络结构及算法就是单方,而数据集则是原材料,为了能够炼好丹,首先需要一个使用称手的丹炉,
- 当然这应该属于正常过滤手法,而还有一种过滤HTML标签的最终极手法,则是将一对尖括号及尖括号中的所有字符均替换不显示,该方法对于内容中必须描
- js 对url进行编码和解码三种编码和解码函数encodeURI和 decodeURI它着眼于对整个URL进行编码,因此除了常见的符号以外,
- 动画效果如下:GIF看起来可能会有点卡wxml<view class="confirm bubble">确定
- asyncio介绍熟悉c#的同学可能知道,在c#中可以很方便的使用 async 和 await 来实现异步编程,那么在p
- 一、使用python3做webervice接口测试的第三方库选择suds-jurko库,可以直接pip命令直接下载,也可以在pypi官网下载
- 简单的合并,本例是横向合并,纵向合并可以自行调整。import xlrd import xlwtimport shutil from xlu
- 从毕业实习算起,从事可用性方面的工作到现在已经5年了。在此记录笔者的一些所见所想,和大家讨论分享一下。用户研究在“以用户为中心”的界面设计方
- 一般事件事件浏览器支持描述onClickIE3|N2|O3鼠标点击事件,多用在某个对象控制的范围内的鼠标点击onDblClickIE4|N4