Pytorch中torch.stack()函数的深入解析
作者:cv_lhp 发布时间:2021-06-17 18:39:09
一. torch.stack()函数解析
1. 函数说明:
1.1 官网:torch.stack(),函数定义及参数说明如下图所示:
1.2 函数功能
沿一个新维度对输入一系列张量进行连接,序列中所有张量应为相同形状,stack 函数返回的结果会新增一个维度。也即是把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度上面进行堆叠。
1.3 参数列表
tensors :为一系列输入张量,类型为turple和List
dim :新增维度的(下标)位置,当dim = -1时默认最后一个维度;范围必须介于 0 到输入张量的维数之间,默认是dim=0,在第0维进行连接
返回值:输出新增维度后的张量
2. 代码举例
2.1 dim = 0 : 在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)
import torch
#二维输入张量a,b
a = torch.tensor([1, 2, 3])
b = torch.tensor([11, 22, 33])
c = torch.stack([a, b],dim=0)#在第0维进行连接,相当于在行上进行组合(输入张量为一维,输出张量为两维)
print(a)
print(b)
print(c)
输出结果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1, 2, 3],
[11, 22, 33]])
2.2 dim = 1 :在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)
import torch
#二维输入张量a,b
a = torch.tensor([1, 2, 3])
b = torch.tensor([11, 22, 33])
c = torch.stack([a, b],dim=1)#在第1维进行连接,相当于在对应行上面对列元素进行组合(输入张量为一维,输出张量为两维)
print(a)
print(b)
print(c)
输出结果如下:
tensor([1, 2, 3])
tensor([11, 22, 33])
tensor([[ 1, 11],
[ 2, 22],
[ 3, 33]])
2.3 dim=0:表示在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维),注意:此处输入张量维度为二维,因此dim最大只能为2。
import torch
#二维输入张量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b],dim=0)#在第0维进行连接,相当于在通道维度上进行组合(输入张量为两维,输出张量为三维)
print(a)
print(b)
print(c)
输出结果如下所示:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],[[11, 22, 33],
[44, 55, 66],
[77, 88, 99]]])
2.4 dim=1:表示在第1维进行连接,相当于对相应通道中每个行进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。
import torch
#二维输入张量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 1)#在第1维进行连接,相当于对相应通道中每个行进行组合
print(a)
print(b)
print(c)
输出结果如下所示:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 2, 3],
[11, 22, 33]],[[ 4, 5, 6],
[44, 55, 66]],[[ 7, 8, 9],
[77, 88, 99]]])
2.5 dim=2:表示在第2维进行连接,相当于对相应行中每个列元素进行组合,注意:此处输入张量维度为二维,因此dim最大只能为2。
import torch
#二维输入张量a,b
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[11, 22, 33], [44, 55, 66], [77, 88, 99]])
c = torch.stack([a, b], 2)#在第2维进行连接,相当于对相应行中每个列元素进行组合
print(a)
print(b)
print(c)
输出结果如下所示:
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
tensor([[11, 22, 33],
[44, 55, 66],
[77, 88, 99]])
tensor([[[ 1, 11],
[ 2, 22],
[ 3, 33]],[[ 4, 44],
[ 5, 55],
[ 6, 66]],[[ 7, 77],
[ 8, 88],
[ 9, 99]]])
2.6 dim=3:表示在第3维进行连接,相当于对相应行中每个列元素进行组合(输入维度大小为3维,因此dim=3最后一维始终代表为列),注意:此处输入张量维度为三维,因此dim最大只能为3。
import torch
#三维输入张量a,b
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
c = torch.stack([a, b], 3)#表示在第3维进行连接,相当于对相应行中每个列元素进行组合(最后一维是第三维,始终代表为列)
print(a)
print(b)
print(c)
输出结果如下所示:
tensor([[[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9]],[[10, 20, 30],
[40, 50, 60],
[70, 80, 90]]])
tensor([[[ 11, 22, 33],
[ 44, 55, 66],
[ 77, 88, 99]],[[110, 220, 330],
[440, 550, 660],
[770, 880, 990]]])
tensor([[[[ 1, 11],
[ 2, 22],
[ 3, 33]],[[ 4, 44],
[ 5, 55],
[ 6, 66]],[[ 7, 77],
[ 8, 88],
[ 9, 99]]],
[[[ 10, 110],
[ 20, 220],
[ 30, 330]],[[ 40, 440],
[ 50, 550],
[ 60, 660]],[[ 70, 770],
[ 80, 880],
[ 90, 990]]]])
2.7 dim=4 (错误维度:因为此处输入张量维度为三维,所以dim最大只能为3,此处维度为4,因此会报错)
import torch
#三维输入张量a,b
a = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],[[10, 20, 30], [40, 50, 60], [70, 80, 90]]])
b = torch.tensor([[[11, 22, 33], [44, 55, 66], [77, 88, 99]], [[110, 220, 330], [440, 550, 660], [770, 880, 990]]])
c = torch.stack([a, b], 4)
print(a)
print(b)
print(c)
输出错误:
IndexError: Dimension out of range (expected to be in range of [-4, 3], but got 4)
来源:https://blog.csdn.net/flyingluohaipeng/article/details/125034358


猜你喜欢
- 一、前言基于Mediapipe+Opencv实现手势检测,想实现一下姿态识别的时候,感觉手势识别也蛮重要的就过来顺便实现一下。下面是一些国内
- 一个动态载入asp树源码。把 node.htc, style.css 保存与 css 目录下. index.asp subtree.asp
- 1.Django的简介Django是一个基于MVC构造的框架。但是在Django中,控制器接受用户输入的部分由框架自行处理,所以 Djang
- 你好,我是林骥。斜率图,可以快速展现两组数据之间各维度的变化,特别适合用于对比两个时间点的数据。比如说,为了对比分析某产品不同功能的用户满意
- 恭喜您,您中奖了,你的中奖码是(请牢记,领奖需要):XXXXXXXXXXX然后用户输入XXXXXXXXXXX,简单验证后就可以领奖了。你使用
- 在VBScript中,有一个On Error Resume Next语句,它使脚本解释器忽略运行期错误并继续脚本代码的执行。接着该脚本可以检
- jieba.cut与jieba.lcut的区别jieba.cut生成的是一个生成器,generator,也就是可以通过for循环来取里面的每
- 本文实例讲述了python实现将英文单词表示的数字转换成阿拉伯数字的方法。分享给大家供大家参考。具体实现方法如下:import re_kno
- 主页上的鼠标是不是就只有箭头和小手两种模样呢?如果鼠标移到“帮助”等字样上时,形状就变成求助的问号;鼠标移到可能需要较长时间等待的超链接时,
- 如何让图片自动缩放以适合界面大小,拿出你的Editplus,打开c_function.asp文件,找到UBBCode函数,在第417行有如下
- MJML是一种现代的电子邮件工具,使开发人员可以在所有设备和邮件客户端上创建美观、响应迅速的出色电子邮件。这种标记语言是为了减少编写响应式电
- #mode operand create truncate#read < #write >&nbs
- 在web开发中经常遇到多关键词对对个字段查询,我一般是通过动态数组来实现的。当然多个关键词的一般是用空格或,隔开,我这几假设多个
- 本文教大家调用电脑摄像头进行实时人脸+眼睛识别+微笑识别,供大家参考,具体内容如下一、调用电脑摄像头进行实时人脸+眼睛识别# 调用电脑摄像头
- 一、外键设置方法1、在MySQL中,为了把2个表关联起来,会用到2个重要的功能:外键(FOREIGN KEY)和连接(JOIN)。外键需要在
- #!/usr/bin/perluse strict;use warnings;use re 'debug';sub test
- 本文实例讲述了Flask框架学习笔记之消息提示与异常处理操作。分享给大家供大家参考,具体如下:flask通过flash方法来显示提示消息:f
- MySQL DATE_FORMAT函数简介要将日期值格式化为特定格式,请使用DATE_FORMAT函数。 DATE_FORMAT函数的语法如
- 高级特性切片操作:对list,tuple元素进行截取操作,非常简便。L[0:3],L[:3] 截取前3个元素。L[1:3] 从1开始截取2个
- 如下所示:df = df[df['cityname']==u'北京市']记得,如果用的python2,一定要