pytorch下的unsqueeze和squeeze的用法说明
作者:York1996 发布时间:2023-07-16 14:01:41
#squeeze 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉
#unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加
import torch
a=torch.rand(2,3,1)
print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1])
print(a.size()) #torch.Size([2, 3, 1])
print(a.squeeze().size()) #torch.Size([2, 3])
print(a.squeeze(0).size()) #torch.Size([2, 3, 1])
print(a.squeeze(-1).size()) #torch.Size([2, 3])
print(a.size()) #torch.Size([2, 3, 1])
print(a.squeeze(-2).size()) #torch.Size([2, 3, 1])
print(a.squeeze(-3).size()) #torch.Size([2, 3, 1])
print(a.squeeze(1).size()) #torch.Size([2, 3, 1])
print(a.squeeze(2).size()) #torch.Size([2, 3])
print(a.squeeze(3).size()) #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
print(a.unsqueeze().size()) #TypeError: unsqueeze() missing 1 required positional arguments: "dim"
print(a.unsqueeze(-3).size()) #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(-2).size()) #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(-1).size()) #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(0).size()) #torch.Size([1, 2, 3, 1])
print(a.unsqueeze(1).size()) #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(2).size()) #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(3).size()) #torch.Size([2, 3, 1, 1])
print(torch.unsqueeze(a,3))
b=torch.rand(2,1,3,1)
print(b.squeeze().size()) #torch.Size([2, 3])
补充:pytorch中unsqueeze()、squeeze()、expand()、repeat()、view()、和cat()函数的总结
学习Bert模型的时候,需要使用到pytorch来进行tensor的操作,由于对pytorch和tensor不熟悉,就把pytorch中常用的、有关tensor操作的unsqueeze()、squeeze()、expand()、view()、cat()和repeat()等函数做一个总结,加深记忆。
1、unsqueeze()和squeeze()
torch.unsqueeze(input, dim,out=None) → Tensor
unsqueeze()的作用是用来增加给定tensor的维度的,unsqueeze(dim)就是在维度序号为dim的地方给tensor增加一维。例如:维度为torch.Size([768])的tensor要怎样才能变为torch.Size([1, 768, 1])呢?就可以用到unsqueeze(),直接上代码:
a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(0)
print(a.shape) #torch.Size([1, 768])
a = a.unsqueeze(2)
print(a.shape) #torch.Size([1, 768, 1])
也可以直接使用链式编程:
a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(1).unsqueeze(0)
print(a.shape) #torch.Size([1, 768, 1])
tensor经过unsqueeze()处理之后,总数据量不变;维度的扩展类似于list不变直接在外面加几层[]括号。
torch.squeeze(input, dim=None, out=None) → Tensor
squeeze()的作用就是压缩维度,直接把维度为1的维给去掉。形式上表现为,去掉一层[]括号。
同时,输出的张量与原张量共享内存,如果改变其中的一个,另一个也会改变。
a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
a=a.squeeze()
print(a)
print(a.shape) #torch.Size([2, 768])
图片中的维度信息就不一样,红框中的括号层数不同。
注意的是:squeeze()只能压缩维度为1的维;其他大小的维不起作用。
a=torch.randn(2,768)
print(a.shape) #torch.Size([2, 768])
a=a.squeeze()
print(a.shape) #torch.Size([2, 768])
2、expand()
这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。
torch.Tensor.expand(*sizes) → Tensor
返回张量的一个新视图,可以将张量的单个维度扩大为更大的尺寸。
a=torch.randn(1,1,3,768)
print(a)
print(a.shape) #torch.Size([1, 1, 3, 768])
b=a.expand(2,-1,-1,-1)
print(b)
print(b.shape) #torch.Size([2, 1, 3, 768])
c=a.expand(2,1,3,768)
print(c.shape) #torch.Size([2, 1, 3, 768])
可以看到b和c的维度是一样的
第0维由1变为2,可以看到就直接把原来的tensor在该维度上复制了一下。
3、repeat()
repeat(*sizes)
沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。
a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
b=a.repeat(1,2,1)
print(b)
print(b.shape) #torch.Size([2, 2, 768])
c=a.repeat(3,3,3)
print(c)
print(c.shape) #torch.Size([6, 3, 2304])
b表示对a的对应维度进行乘以1,乘以2,乘以1的操作,所以b:torch.Size([2, 1, 768])
c表示对a的对应维度进行乘以3,乘以3,乘以3的操作,所以c:torch.Size([6, 3, 2304])
a:
b
c
4、view()
tensor.view()这个函数有点类似reshape的功能,简单的理解就是:先把一个tensor转换成一个一维的tensor,然后再组合成指定维度的tensor。例如:
word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(8,6,768)
print(new_word_embedding.shape)
当然这里指定的维度的乘积一定要和原来的tensor的维度乘积相等,不然会报错的。16*3*768=8*6*768
另外当我们需要改变一个tensor的维度的时候,知道关键的维度,有不想手动的去计算其他的维度值,就可以使用view(-1),pytorch就会自动帮你计算出来。
word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(1,-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(-1,768)
print(new_word_embedding.shape)
结果如下:使用-1以后,就会自动得到其他维度维。
需要特别注意的是:view(-1,-1)这样的用法就会出错。也就是说view()函数中只能出现单个-1。
5、cat()
cat(seq,dim,out=None),表示把两个或者多个tensor拼接起来。
其中 seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b), a,b 为两个可以连接的序列
dim 表示以哪个维度连接,dim=0, 横向连接 dim=1,纵向连接
a=torch.randn(4,3)
b=torch.randn(4,3)
c=torch.cat((a,b),dim=0)#横向拼接,增加行 torch.Size([8, 3])
print(c.shape)
d=torch.cat((a,b),dim=1)#纵向拼接,增加列 torch.Size([4, 6])
print(d.shape)
还有一种写法:cat(list,dim,out=None),其中list中的元素为tensor。
tensors=[]
for i in range(10):
tensors.append(torch.randn(4,3))
a=torch.cat(tensors,dim=0) #torch.Size([40, 3])
print(a.shape)
b=torch.cat(tensors,dim=1) #torch.Size([4, 30])
print(b.shape)
结果:
torch.Size([40, 3])
torch.Size([4, 30])
以上为个人经验,希望能给大家一个参考,也希望大家多多支持
来源:https://blog.csdn.net/york1996/article/details/81875508


猜你喜欢
- 非阻塞IO(non-blocking IO)Linux下,可以通过设置socket使其变为non-blocking。当对一个non-bloc
- 在Visual Studio 中使用git——什么是Git(一)如果要使用git进行版本管理,其实使用git命令行工具就完全足够了,图形化工
- 本教程为大家分享了win10下Python环境安装配置教程,供大家参考,具体内容如下1.在https://www.python.org/do
- 最近研究验证码识别,需要生成大量验证码,最方便的是使用captcha库来生成验证码,网上代码仅仅使用默认设置,但是它还有很多参数可以设定,于
- 在MySQL经历了2008年Sun的收购和2009年Oracle收购Sun的过程中,基本处于停滞发展的情况,在可以预见的未来,MySQL是肯
- 首先你要确定错误的原因: 让IE显示详细的出错信息: 菜单--工具--Internet选项--高级--显示友好的HTTP错误信息,去掉这个选
- 一、说明前面我们说了mysql的安装配置,mysql语句使用以及备份恢复mysql数据;本次要介绍的是mysql的主从复制,读写分离;及高可
- 一、数据类型在tf中,数据类型有整型(默认是int32),浮点型(默认是float32),以及布尔型,字符串。二、数据类型信息①.devic
- 一、如何实现可迭代对象和迭代器对象?实际案例某软件要求从网络抓取各个城市气味信息,并其次显示:北京: 15 ~ 20 天津: 17 ~ 22
- numpy模块下的median作用为: 计算沿指定轴的中位数返回数组元素的中位数其函数接口为:median(a, axis=None, ou
- 前言:《flappy bird》是一款由来自越南的独立游戏开发者Dong Nguyen所开发的作品,游戏于2013年5月24日上线,并在20
- <%@ Page Language="VB" %> <!DOCTYPE html PUBLIC &qu
- 本文实例讲述了django框架自定义模板标签(template tag)操作。分享给大家供大家参考,具体如下:django 提供了丰富的模板
- 1、善用拖放技术 我们在使用Dreamweaver编辑网页的时候,经常需要插入一些图象什么的,假设要插入的图象很多,按照常规方法来操作就显得
- 本文记录了mysql 5.7.23安装教程,供大家参考。1、首先进入官网下载mysql安装包,官网地址可以选择自己想要的版本,默认是8.0,
- <!-- #include file="conn.asp" -->
- python是支持多线程的,并且是native的线程。主要是通过thread和threading这两个模块来实现的。thread是比较底层的
- 一、find_element_by_id()find_element_by_id()1.从上面定位到的元素属性中,可以看到有个id属性:id
- 网上资料结合自己的操作整理出的一套靠谱的彻底卸载Oracle 11g的步骤!(Win7),具体内容详情如下所示:1:停掉所有Oracle相关
- 数据可视化的时候,常常需要将多个子图放在同一个画板上进行比较,python 的matplotlib包下的subplot可以帮助完成子功能。p