网络编程
位置:首页>> 网络编程>> 网络编程>> pytorch下的unsqueeze和squeeze的用法说明

pytorch下的unsqueeze和squeeze的用法说明

作者:York1996  发布时间:2023-07-16 14:01:41 

标签:pytorch,unsqueeze,squeeze

#squeeze 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉

#unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加


import torch
a=torch.rand(2,3,1)      
print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1])
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze().size())    #torch.Size([2, 3])
print(a.squeeze(0).size())   #torch.Size([2, 3, 1])
print(a.squeeze(-1).size())   #torch.Size([2, 3])
print(a.size())         #torch.Size([2, 3, 1])
print(a.squeeze(-2).size())   #torch.Size([2, 3, 1])
print(a.squeeze(-3).size())   #torch.Size([2, 3, 1])
print(a.squeeze(1).size())   #torch.Size([2, 3, 1])
print(a.squeeze(2).size())   #torch.Size([2, 3])
print(a.squeeze(3).size())   #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
print(a.unsqueeze().size())   #TypeError: unsqueeze() missing 1 required positional arguments: "dim"
print(a.unsqueeze(-3).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(-2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(-1).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(0).size())  #torch.Size([1, 2, 3, 1])
print(a.unsqueeze(1).size())  #torch.Size([2, 1, 3, 1])
print(a.unsqueeze(2).size())  #torch.Size([2, 3, 1, 1])
print(a.unsqueeze(3).size())  #torch.Size([2, 3, 1, 1])
print(torch.unsqueeze(a,3))
b=torch.rand(2,1,3,1)
print(b.squeeze().size())    #torch.Size([2, 3])

补充:pytorch中unsqueeze()、squeeze()、expand()、repeat()、view()、和cat()函数的总结

学习Bert模型的时候,需要使用到pytorch来进行tensor的操作,由于对pytorch和tensor不熟悉,就把pytorch中常用的、有关tensor操作的unsqueeze()、squeeze()、expand()、view()、cat()和repeat()等函数做一个总结,加深记忆。

1、unsqueeze()和squeeze()


torch.unsqueeze(input, dim,out=None) → Tensor

unsqueeze()的作用是用来增加给定tensor的维度的,unsqueeze(dim)就是在维度序号为dim的地方给tensor增加一维。例如:维度为torch.Size([768])的tensor要怎样才能变为torch.Size([1, 768, 1])呢?就可以用到unsqueeze(),直接上代码:


a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(0)
print(a.shape) #torch.Size([1, 768])
a = a.unsqueeze(2)
print(a.shape) #torch.Size([1, 768, 1])

也可以直接使用链式编程:


a=torch.randn(768)
print(a.shape) # torch.Size([768])
a=a.unsqueeze(1).unsqueeze(0)
print(a.shape) #torch.Size([1, 768, 1])

tensor经过unsqueeze()处理之后,总数据量不变;维度的扩展类似于list不变直接在外面加几层[]括号。


torch.squeeze(input, dim=None, out=None) → Tensor

squeeze()的作用就是压缩维度,直接把维度为1的维给去掉。形式上表现为,去掉一层[]括号。

同时,输出的张量与原张量共享内存,如果改变其中的一个,另一个也会改变。


a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
a=a.squeeze()
print(a)
print(a.shape) #torch.Size([2, 768])

pytorch下的unsqueeze和squeeze的用法说明

图片中的维度信息就不一样,红框中的括号层数不同。

注意的是:squeeze()只能压缩维度为1的维;其他大小的维不起作用。


a=torch.randn(2,768)
print(a.shape) #torch.Size([2, 768])
a=a.squeeze()
print(a.shape) #torch.Size([2, 768])

2、expand()

这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。


torch.Tensor.expand(*sizes) → Tensor

返回张量的一个新视图,可以将张量的单个维度扩大为更大的尺寸。


a=torch.randn(1,1,3,768)
print(a)
print(a.shape) #torch.Size([1, 1, 3, 768])
b=a.expand(2,-1,-1,-1)
print(b)
print(b.shape) #torch.Size([2, 1, 3, 768])
c=a.expand(2,1,3,768)
print(c.shape) #torch.Size([2, 1, 3, 768])

可以看到b和c的维度是一样的

pytorch下的unsqueeze和squeeze的用法说明

第0维由1变为2,可以看到就直接把原来的tensor在该维度上复制了一下。

3、repeat()

repeat(*sizes)

沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。


a=torch.randn(2,1,768)
print(a)
print(a.shape) #torch.Size([2, 1, 768])
b=a.repeat(1,2,1)
print(b)
print(b.shape) #torch.Size([2, 2, 768])
c=a.repeat(3,3,3)
print(c)
print(c.shape) #torch.Size([6, 3, 2304])

b表示对a的对应维度进行乘以1,乘以2,乘以1的操作,所以b:torch.Size([2, 1, 768])

c表示对a的对应维度进行乘以3,乘以3,乘以3的操作,所以c:torch.Size([6, 3, 2304])

a:

pytorch下的unsqueeze和squeeze的用法说明

b

pytorch下的unsqueeze和squeeze的用法说明

c

pytorch下的unsqueeze和squeeze的用法说明

4、view()

tensor.view()这个函数有点类似reshape的功能,简单的理解就是:先把一个tensor转换成一个一维的tensor,然后再组合成指定维度的tensor。例如:


word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(8,6,768)
print(new_word_embedding.shape)

当然这里指定的维度的乘积一定要和原来的tensor的维度乘积相等,不然会报错的。16*3*768=8*6*768

另外当我们需要改变一个tensor的维度的时候,知道关键的维度,有不想手动的去计算其他的维度值,就可以使用view(-1),pytorch就会自动帮你计算出来。


word_embedding=torch.randn(16,3,768)
print(word_embedding.shape)
new_word_embedding=word_embedding.view(-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(1,-1)
print(new_word_embedding.shape)
new_word_embedding=word_embedding.view(-1,768)
print(new_word_embedding.shape)

结果如下:使用-1以后,就会自动得到其他维度维。

pytorch下的unsqueeze和squeeze的用法说明

需要特别注意的是:view(-1,-1)这样的用法就会出错。也就是说view()函数中只能出现单个-1。

5、cat()

cat(seq,dim,out=None),表示把两个或者多个tensor拼接起来。

其中 seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b), a,b 为两个可以连接的序列

dim 表示以哪个维度连接,dim=0, 横向连接 dim=1,纵向连接


a=torch.randn(4,3)
b=torch.randn(4,3)
c=torch.cat((a,b),dim=0)#横向拼接,增加行 torch.Size([8, 3])
print(c.shape)
d=torch.cat((a,b),dim=1)#纵向拼接,增加列 torch.Size([4, 6])
print(d.shape)

还有一种写法:cat(list,dim,out=None),其中list中的元素为tensor。


tensors=[]
for i in range(10):
 tensors.append(torch.randn(4,3))
a=torch.cat(tensors,dim=0) #torch.Size([40, 3])
print(a.shape)
b=torch.cat(tensors,dim=1) #torch.Size([4, 30])
print(b.shape)

结果:


torch.Size([40, 3])
torch.Size([4, 30])

以上为个人经验,希望能给大家一个参考,也希望大家多多支持

来源:https://blog.csdn.net/york1996/article/details/81875508

0
投稿

猜你喜欢

  • 关于中大型开发b/s开发中的缓存(cache),我的一些看法,有不正确的或者是有笔误的地方,请指正。thanks首先,应该了解基本的,对于缓
  • 说到这个话题,我们有个产品叫群组,为什么人们需要群组?简单说,群组就是个圈子,是有共同爱好和话题的人群聚在一起讨论、分享的地方。这个产品的诞
  • 对url进行编码在服务器端我们可以使用asp中的server.urlencode,很方便实现。如:<% ss="asp之家欢
  • 文章背景:某天,我的一个同事给我看了CSDN上面的一篇关于编程语言排行榜的文章,里面我看到VB还是排名很不错的,我就说,asp(vbscri
  • 怎样才能将在表A取得的数据插入另一个表B中?(1)对于表A和表B两个表结构完全相同的话〔字段个数,相应字段的类型等等〕,可以使用 inser
  • 试了一下,xmlDoc.save()行不同,就试着用fso做了出来。整理一下,供大家discuss。由于用js操作本地xml文件之后save
  • 在进行WEB标准网页设计时,必不可少的是写入大量的CSS语法,一般情况下我们可以通过Dreamweaver软件的“CSS样式”面板自动生成相
  • 从今天起,我将陆续将 ppk on JavaScript 的读书心得发布到这个blog上。ppk是我所景仰的一位web开发者,原因无它,只是
  • 一、输出指令ASP的输出指令<% =expression %>显示表达式的值。这个输出指令等同于使用Resp
  • 在Internet上我们每天都会遇到数不清的表单,也看到其中大部分并没有限制用户多次提交同一个表单。缺乏这种限制有时候会产生某些预料不到的结
  • 在js中调用asp文件的方法很简单,我们可以用在静态页面的点击数统计,虽然静态页面不支持asp程序,但是我们可以使用js调用,来变相的达到这
  • 非常好的边框样式设置工具,使用该工具您可以很方便的为DIV设置简单的边框样式,如果放在DW中会更好。会制作DW插件的高手,请帮忙制作成DW插
  • 假设你想找到本书中的某一个句子。你可以一页一页地逐页搜索,但这会花很多时间。而通过使用本书的索引,你可以很快地找到你要搜索的主题。表的索引与
  • #mkdir /mysqldata2、创建/usr/sbin/bakmysql文件#nano /usr/sbin/bakmysql输入:#!
  • 在使用SQL Server 的过程,中由于经常需要从多个不同地点将数据集中起来或向多个地点复制数据,所以数据的导出,导入是极为常见的操作.我
  • 嗯,你可以说我很无聊。最近疯狂加班,今天才得以有时间搞一个CSS的像素图来消遣休息下。先看效果:运行代码框<!DOCTYPE html
  • 题目:用 JavaScript 代码实现空位补零,比如 pad(12, 3) => 012实现一:/* 平淡无奇法 */functio
  • 指定结果集的列名AS 子句可用来更改结果集列名或为导出列指定名称。当结果集列由对表或视图中的列的引用进行定义时,结果集列的名称与所引用列的名
  • PHP Date/Time 简介Date/Time 函数允许您从 PHP 脚本运行的服务器上获取日期和时间。您可以使用 Date/Time
  • Example.asp<%@LANGUAGE="VBSCRIPT" CODEPAGE="65001&qu
手机版 网络编程 asp之家 www.aspxhome.com