pytorch中的squeeze函数、cat函数使用
作者:zhuanse 发布时间:2022-03-27 14:32:24
1 squeeze(): 去除size为1的维度,包括行和列。
至于维度大于等于2时,squeeze()不起作用。
行、例:
>>> torch.rand(4, 1, 3)
(0 ,.,.) =
0.5391 0.8523 0.9260
(1 ,.,.) =
0.2507 0.9512 0.6578
(2 ,.,.) =
0.7302 0.3531 0.9442
(3 ,.,.) =
0.2689 0.4367 0.6610
[torch.FloatTensor of size 4x1x3]
>>> torch.rand(4, 1, 3).squeeze()
0.0801 0.4600 0.1799
0.0236 0.7137 0.6128
0.0242 0.3847 0.4546
0.9004 0.5018 0.4021
[torch.FloatTensor of size 4x3]
列、例:
>>> torch.rand(4, 3, 1)
(0 ,.,.) =
0.7013
0.9818
0.9723
(1 ,.,.) =
0.9902
0.8354
0.3864
(2 ,.,.) =
0.4620
0.0844
0.5707
(3 ,.,.) =
0.5722
0.2494
0.5815
[torch.FloatTensor of size 4x3x1]
>>> torch.rand(4, 3, 1).squeeze()
0.8784 0.6203 0.8213
0.7238 0.5447 0.8253
0.1719 0.7830 0.1046
0.0233 0.9771 0.2278
[torch.FloatTensor of size 4x3]
不变、例:
>>> torch.rand(4, 3, 2)
(0 ,.,.) =
0.6618 0.1678
0.3476 0.0329
0.1865 0.4349
(1 ,.,.) =
0.7588 0.8972
0.3339 0.8376
0.6289 0.9456
(2 ,.,.) =
0.1392 0.0320
0.0033 0.0187
0.8229 0.0005
(3 ,.,.) =
0.2327 0.6264
0.4810 0.6642
0.8625 0.6334
[torch.FloatTensor of size 4x3x2]
>>> torch.rand(4, 3, 2).squeeze()
(0 ,.,.) =
0.0593 0.8910
0.9779 0.1530
0.9210 0.2248
(1 ,.,.) =
0.7938 0.9362
0.1064 0.6630
0.9321 0.0453
(2 ,.,.) =
0.0189 0.9187
0.4458 0.9925
0.9928 0.7895
(3 ,.,.) =
0.5116 0.7253
0.0132 0.6673
0.9410 0.8159
[torch.FloatTensor of size 4x3x2]
2 cat函数
>>> t1=torch.FloatTensor(torch.randn(2,3))
>>> t1
-1.9405 1.2009 0.0018
0.9463 0.4409 -1.9017
[torch.FloatTensor of size 2x3]
>>> t2=torch.FloatTensor(torch.randn(2,2))
>>> t2
0.0942 0.1581
1.1621 1.2617
[torch.FloatTensor of size 2x2]
>>> torch.cat((t1, t2), 1)
-1.9405 1.2009 0.0018 0.0942 0.1581
0.9463 0.4409 -1.9017 1.1621 1.2617
[torch.FloatTensor of size 2x5]
补充:pytorch中 max()、view()、 squeeze()、 unsqueeze()
查了好多博客都似懂非懂,后来写了几个小例子,瞬间一目了然。
一、torch.max()
import torch
a=torch.randn(3)
print("a:\n",a)
print('max(a):',torch.max(a))
b=torch.randn(3,4)
print("b:\n",b)
print('max(b,0):',torch.max(b,0))
print('max(b,1):',torch.max(b,1))
输出:
a:
tensor([ 0.9558, 1.1242, 1.9503])
max(a): tensor(1.9503)
b:
tensor([[ 0.2765, 0.0726, -0.7753, 1.5334],
[ 0.0201, -0.0005, 0.2616, -1.1912],
[-0.6225, 0.6477, 0.8259, 0.3526]])
max(b,0): (tensor([ 0.2765, 0.6477, 0.8259, 1.5334]), tensor([ 0, 2, 2, 0]))
max(b,1): (tensor([ 1.5334, 0.2616, 0.8259]), tensor([ 3, 2, 2]))
max(a),用于一维数据,求出最大值。
max(a,0),计算出数据中一列的最大值,并输出最大值所在的行号。
max(a,1),计算出数据中一行的最大值,并输出最大值所在的列号。
print('max(b,1):',torch.max(b,1)[1])
输出:只输出行最大值所在的列号
max(b,1): tensor([ 3, 2, 2])
torch.max(b,1)[0], 只返回最大值的每个数
二、view()
a.view(i,j)表示将原矩阵转化为i行j列的形式
i为-1表示不限制行数,输出1列
a=torch.randn(3,4)
print(a)
输出:
tensor([[-0.8146, -0.6592, 1.5100, 0.7615],
[ 1.3021, 1.8362, -0.3590, 0.3028],
[ 0.0848, 0.7700, 1.0572, 0.6383]])
b=a.view(-1,1)
print(b)
输出:
tensor([[-0.8146],
[-0.6592],
[ 1.5100],
[ 0.7615],
[ 1.3021],
[ 1.8362],
[-0.3590],
[ 0.3028],
[ 0.0848],
[ 0.7700],
[ 1.0572],
[ 0.6383]])
i为1,j为-1表示不限制列数,输出1行
b=a.view(1,-1)
print(b)
输出:
tensor([[-0.8146, -0.6592, 1.5100, 0.7615, 1.3021, 1.8362, -0.3590,
0.3028, 0.0848, 0.7700, 1.0572, 0.6383]])
i为-1,j为2表示不限制行数,输出2列
b=a.view(-1,2)
print(b)
输出:
tensor([[-0.8146, -0.6592],
[ 1.5100, 0.7615],
[ 1.3021, 1.8362],
[-0.3590, 0.3028],
[ 0.0848, 0.7700],
[ 1.0572, 0.6383]])
i为-1,j为3表示不限制行数,输出3列
i为4,j为3表示输出4行3列
b=a.view(-1,3)
print(b)
b=a.view(4,3)
print(b)
输出:
tensor([[-0.8146, -0.6592, 1.5100],
[ 0.7615, 1.3021, 1.8362],
[-0.3590, 0.3028, 0.0848],
[ 0.7700, 1.0572, 0.6383]])
tensor([[-0.8146, -0.6592, 1.5100],
[ 0.7615, 1.3021, 1.8362],
[-0.3590, 0.3028, 0.0848],
[ 0.7700, 1.0572, 0.6383]])
三、
1.torch.squeeze()
压缩矩阵,我理解为降维
a.squeeze(i) 压缩第i维,如果这一维维数是1,则这一维可有可无,便可以压缩
import torch
a=torch.randn(1,3,4)
print(a)
b=a.squeeze(0)
print(b)
c=a.squeeze(1)
print(c
输出:
tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]]])
一页三行4列的矩阵
第0维为1,则可以通过squeeze(0)删掉,转化为三行4列的矩阵
tensor([[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]])
第1维不为1,则不可以压缩
tensor([[[ 0.4627, 1.6447, 0.1320, 2.0946],
[-0.0080, 0.1794, 1.1898, -1.2525],
[ 0.8281, -0.8166, 1.8846, 0.9008]]])
2.torch.unsqueeze()
unsqueeze(i) 表示将第i维设置为1
对压缩为3行4列后的矩阵b进行操作,将第0维设置为1
c=b.unsqueeze(0)
print(c)
输出一个一页三行四列的矩阵
tensor([[[ 0.0661, -0.2386, -0.6610, 1.5774],
[ 1.2210, -0.1084, -0.1166, -0.2379],
[-1.0012, -0.4363, 1.0057, -1.5180]]])
将第一维设置为1
c=b.unsqueeze(1)
print(c)
输出一个3页,一行,4列的矩阵
tensor([[[-1.0067, -1.1477, -0.3213, -1.0633]],
[[-2.3976, 0.9857, -0.3462, -0.3648]],
[[ 1.1012, -0.4659, -0.0858, 1.6631]]])
另外,squeeze、unsqueeze操作不改变原矩阵
来源:https://blog.csdn.net/abc781cba/article/details/79663190


猜你喜欢
- 前言:很多人都在使用mysql数据库,但是很少有人能够说出来整个sql语句的执行过程是怎样的,如果不了解执行过程的话,就很难进行sql语句的
- 今天在做项目时,遇到了需要创建JavaScript对象的情况。所以Bing了一篇老外写的关于3种创建JavaScript对象的文章,看后跟着
- 最近使用Python调用百度的REST API实现语音识别,但是百度要求音频文件的压缩方式只能是pcm(不压缩)、wav、opus、spee
- 具体代码如下所示:<SCRIPT LANGUAGE="JavaScript"> <!-- var dn
- 创建表时创建外键创建两个表格,一个名为class,create table classes(id int not null primary
- [编者注:]提起数据库,第一个想到的公司,一般都会是Oracle(即甲骨文公司)。Oracle在数据库领域一直处于领先地位。Oracle关系
- 异常,不应该存在,但是我们有时候会遇到这样的情况,比如我们监控服务器的时候,每一秒去采集一次信息,那么有一秒没有采集到我们想要的信息,但是下
- 我就废话不多说了,直接上代码吧!import torchimport torch.nn.functional as Fimport nump
- 1. rangerange是python内置的一个类,该类型表示一个不可改变(immutable)的数字序列,常常用于在for循环中迭代一组
- <?php $monthoneday=date("Ym")."01"; $oneweekday
- 如果我们希望把一个网站的更新实时发布到另一个网站上,最好的方法是通过 RSS 进行转载。如果只是需要简单的对更新的条目做个提示的话,使用 J
- 不说什么,先上代码这里先求解形如的微分方程1.欧拉法def eluer(rangee,h,fun,x0,y0):
- 本文实例为大家分享了js贪吃蛇游戏的相关代码,供大家参考,具体内容如下<!DOCTYPE html><html lang=
- 问题:这里只解决一个问题,到底什么是Access?设计一个数据库管理系统,用access在access里面设计好表,查询,然后再用vb做窗体
- 目录01 安装02 剪辑01 安装对视频进行批量剪辑,需要三个库,分别是Moviepy库和Pathlib库,还有Tkinter库。首先我们对
- 注:此处“重复”非完全重复,意为某字段数据重复HZT表结构IDintTitlenvarchar(50)AddDatedatetime数据一.
- 1、表示乘号2、表示倍数,例如:def T(msg,time=1): print((msg+' ')*time)
- 1.plt.pie()饼图 常常用来显示 整体中各部分所占的比例,在python-matplotlib库中通过plt.pie()方法来实现。
- Case:需要给一个现有的shp数据创建一个字段,并将属性表中原有的一个文本类型的属性转换为整型后填入新创建的字段。Problem:新字段创
- 在Oracle数据库中,如何查找,定位一张表最后一次的DML操作的时间呢? 方式有三种,不过都有一些局限性,下面简单的解析、总结一下。1:使