解析Pytorch中的torch.gather()函数
作者:xiaoliujun1999 发布时间:2023-01-29 23:44:40
标签:Pytorch,torch.gather(),函数
参数说明
以官方说明为例,gather()函数需要三个参数,输入input,维度dim,以及索引index
input必须为Tensor类型
dim为int类型,代表从哪个维度进行索引
index为LongTensor类型
举例说明
input=torch.tensor([[1,2,3],[4,5,6]]) #作为输入
index1=torch.tensor([[0,1,1],[0,1,1]]) #作为索引矩阵
# dim=0时,按列进行索引
print (torch.gather(input,dim=0,index=index1))
# dim=1时,按行进行索引
print (torch.gather(input,dim=1,index=index1))
结果如下图所示:
# 按列进行索引
tensor([[1, 5, 6],
[4, 2, 6]])
# 按行进行索引
tensor([[1, 2, 2],
[5, 4, 5]])
画图说明
官方文档
def gather(self, input, dim, index, *args, **kwargs):
For a 3-D tensor the output is specified by::
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
Args:
input (Tensor): the source tensor
dim (int): the axis along which to index
index (LongTensor): the indices of elements to gather
Example::
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1, 1],
[ 4, 3]])
来源:https://blog.csdn.net/xiaoliujun1999/article/details/121292061
0
投稿
猜你喜欢
- 在Bootstrap fileinput中移除预览文件时可以通过配置initialPreviewConfig: [ { url:'d
- 本文实例讲述了Python实现的生产者、消费者问题。分享给大家供大家参考,具体如下:生产者、消费者问题,经典的线程同步问题:假设有一个缓冲池
- 1. 首先确认服务器的Federated引擎是否开启show engines;2. 如果Federated 未开启,则需要开启到MySQL的
- 本文实例讲述了Python中itertools模块用法,分享给大家供大家参考。具体分析如下:一般来说,itertools模块包含创建有效迭代
- 1、使用SHOW语句找出在服务器上当前存在什么数据库: mysql> SHOW DATABASES; +----------+ | D
- 一、本文使用的第三方包和工具python 3.8 谷歌浏览器selenium(3.141.0)(pip install
- chr(13) 是一个回车Chr(10) 是一个换行符chr
- 本文实例讲述了Python简单实现两个任意字符串乘积的方法。分享给大家供大家参考,具体如下:题目:给定两个任意数字组成的字符串,求乘积,字符
- 常用快捷键1、Ctrl + Enter:在下方新建行但不移动光标;2、Shift + Enter:在下方新建行并移到新行行首;3、Ctrl
- 有时候我们可能需要import另一个路径下的python文件,例如下面这个目录结构,我们想要在_train.py里import在networ
- I. Strict Mode阐述根据 mysql5.0以上版本 strict mode (STRICT_TRANS_TABLES) 的限制:
- 这篇文章主要介绍了Python JSON编解码方式原理详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要
- 项目需要把部分代码移植到 Golang , 之前用 Laravel 封装的写起来很舒服,在 Golang 里只能自动动手实现.一开始想的是使
- 通过pyshp库,可以读写Shapefile文件,查询相关信息,github地址为https://github.com/Geospatial
- 1、注释单行注释,使用#,#号后面的都是注射,例如#我是单行注释print("Hello Python world")多
- 1 非贪婪flag>>> re.findall(r"a(\d+?)", "a23b"
- 在实际开发中经常需要对前端传递的多个参数进行不为空校验,可以使用python提供的all()函数if not all([arg1, arg2
- 数据API数据集方法不会修改数据集,而是创建新数据集。可通过调用 map() 方法将转换应用于每个元素:dataset = dataset.
- 如何用ASP获知机器的网络配置?看看我们的例子:Option Explicit Dim WSHShell&nb
- 前言前面几篇简单介绍了一下前端与PHP的一些知识点,前端中表单提交是一个非常重要的模块,在本篇中我会介绍一些关于表单的知识,如果前面内容你掌