聊聊Pytorch torch.cat与torch.stack的区别
作者:Winner3 发布时间:2021-05-07 02:07:39
torch.cat()函数可以将多个张量拼接成一个张量。torch.cat()有两个参数,第一个是要拼接的张量的列表或是元组;第二个参数是拼接的维度。
torch.cat()的示例如下图1所示
图1 torch.cat()
torch.stack()函数同样有张量列表和维度两个参数。stack与cat的区别在于,torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。
torch.stack()的示例如下图2所示:
图2 torch.stack()
补充:torch.stack()的官方解释,详解以及例子
可以直接看最下面的【3.例子】,再回头看前面的解释
在pytorch中,常见的拼接函数主要是两个,分别是:
1、stack()
2、cat()
实际使用中,这两个函数互相辅助:关于cat()参考torch.cat(),但是本文主要说stack()。
函数的意义:使用stack可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数。
形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面(矩阵)按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。
该函数常出现在自然语言处理(NLP)和图像卷积神经网络(CV)中。
1. stack()
官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。
浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。
outputs = torch.stack(inputs, dim=?) → Tensor
参数
inputs : 待连接的张量序列。
注:python的序列数据只有list和tuple。
dim : 新的维度, 必须在0到len(outputs)之间。
注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。
2. 重点
函数中的输入inputs只允许是序列;且序列内部的张量元素,必须shape相等
----举例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape
dim是选择生成的维度,必须满足0<=dim<len(outputs);len(outputs)是输出后的tensor的维度大小
不懂的看例子,再回过头看就懂了。
3. 例子
1.准备2个tensor数据,每个的shape都是[3,3]
# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
2.测试stack函数
print(torch.stack((T1,T2),dim=0).shape)
print(torch.stack((T1,T2),dim=1).shape)
print(torch.stack((T1,T2),dim=2).shape)
print(torch.stack((T1,T2),dim=3).shape)
# outputs:
torch.Size([2, 3, 3])
torch.Size([3, 2, 3])
torch.Size([3, 3, 2])
'选择的dim>len(outputs),所以报错'
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
可以复制代码运行试试:拼接后的tensor形状,会根据不同的dim发生变化。
dim | shape |
---|---|
0 | [2, 3, 3] |
1 | [3,2, 3] |
2 | [3, 3,2] |
3 | 溢出报错 |
4. 总结
1、函数作用:
函数stack()对序列数据内部的张量进行扩维拼接,指定维度由程序员选择、大小是生成后数据的维度区间。
2、存在意义:
在自然语言处理和卷及神经网络中, 通常为了保留–[序列(先后)信息] 和 [张量的矩阵信息] 才会使用stack。
函数存在意义?》》》
手写过RNN的同学,知道在循环神经网络中输出数据是:一个list,该列表插入了seq_len个形状是[batch_size, output_size]的tensor,不利于计算,需要使用stack进行拼接,保留–[1.seq_len这个时间步]和–[2.张量属性[batch_size, output_size]]。
来源:https://blog.csdn.net/winner3/article/details/102720816


猜你喜欢
- 思路1.将姓名和单号填入excel表格里面2.读取excel表格,将所有姓名存到ExeclName这个list中,单号存到ExeclId3.
- 系统:ubuntu18.04 x64GitHub:https://github.com/xingjidemimi/DjangoAPI.git
- 今天看到everything搜索速度秒杀windows自带的文件管理器,所以特地模仿everything实现了文件搜索以及打开对应文件的功能
- 正则表达式的定义在编写处理字符串的程时,经常会遇到在一段文本中查找符合某些规则的字符串的需求,正则表达式就是用于描述这些规则的工具,换句话说
- 特别是linux系统,装了多个python,有时候找不到python的绝对路径,有时候装了个django,又找不到django安装到哪里了。
- 在上篇文章《MySQL表结构变更,不可不知的Metadata Lock》中,我们介绍了MDL引入的背景,及基本概念,从“道”的层面知道了什么
- 本文实例为大家分享了python实现五子棋双人对弈的具体代码,供大家参考,具体内容如下我用的是pygame模块来制作窗口代码如下:# 1、引
- 对于使用Django框架开发的系统,当部署时设置settings.py文件中Debug=False时xadmin后台管理系统样式会丢失。【问
- 多条ROC曲线绘制函数def multi_models_roc(names, sampling_methods, colors, X_tes
- 脉冲星假信号频率的相对路径论证。首先看一下演示结果:实例代码:import numpy as npimport matplotlib.pyp
- 本文实例为大家分享了Python使用Pygame绘制时钟的具体代码,供大家参考,具体内容如下前提条件:需要安装pygame功能:1.初始化界
- 在上篇文章给大家介绍了:MySQL8.0.20安装教程及其安装问题详细教程 https://www.jb51.net/artic
- 本文实例讲述了PHP函数按引用传递参数及函数可选参数用法。分享给大家供大家参考,具体如下:一、函数按引用传递参数1. 代码<!DOCT
- 目录1、原始需求2、解决方案3、canal介绍、安装canal的工作原理架构安装4、验证1、原始需求既要同步原始全量数据,也要实时同步MyS
- 本文为大家分享了MySQL5.6安装教程,具体内容如下1. 下载MySQL2. 解压MySQL压缩包将以下载的MySQL压缩包解压到自定义目
- element-ui中el-form自定义验证需求在输入项目名称后,调用后端接口isNameOnly,若已存在,则效果如下图:1.先设置校验
- 之前一直在写有关scrapy爬虫的事情,今天我们看看使用scrapy如何把爬到的数据放在MySQL数据库中保存。有关python操作MySQ
- 我听说 Hooks 最近很火。讽刺的是,我想用一些关于 class 组件的有趣故事来开始这篇文章。你觉得如何?本文中这些坑对于你正常使用 R
- 从 Python 3 开始,str 类型代表着 Unicode 字符串。取决于编码的类型,一个 Unicode 字符可能会占 4 个字节,这
- 本文实例讲述了Python数据分析之双色球统计两个红和蓝球哪组合比例高的方法。分享给大家供大家参考,具体如下:统计两个红球和蓝球,哪个组合最