PyTorch中的squeeze()和unsqueeze()解析与应用案例
作者:易烊千蝈 发布时间:2022-09-22 23:04:49
附上官网地址:
https://pytorch.org/docs/stable/index.html
1.torch.squeeze
squeeze的用法主要就是对数据的维度进行压缩或者解压。
先看torch.squeeze()
这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的数去掉第一个维数为一的维度之后就变成(3)行。squeeze(a)就是将a中所有为1的维度删掉。不为1的维度没有影响。a.squeeze(N) 就是去掉a中指定的维数为一的维度。还有一种形式就是b=torch.squeeze(a,N) a中去掉指定的定的维数为一的维度。
换言之:
表示若第arg维的维度值为1,则去掉该维度,否则tensor不变。(即若tensor.shape()[arg] == 1,则去掉该维度)
例如:
一个维度为2x1x2x1x2的tensor,不用去想它长什么样儿,squeeze(0)就是不变,squeeze(1)就是变成2x2x1x2。(0是从最左边的维度算起的)
>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])
2.torch.unsqueeze
torch.unsqueeze()
这个函数主要是对数据维度进行扩充。给指定位置加上维数为一的维度,比如原本有个三行的数据(3),在0的位置加了一维就变成一行三列(1,3)。a.squeeze(N) 就是在a中指定位置N加上一个维数为1的维度。还有一种形式就是b=torch.squeeze(a,N) a就是在a中指定位置N加上一个维数为1的维度。
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1, 2, 3, 4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],
[ 2],
[ 3],
[ 4]])
3.例子
给一个使用上述两个函数,并进行一次卷积的例子:
from torchvision.transforms import ToTensor
import torch as t
from torch import nnimport cv2
import numpy as np
import cv2
to_tensor = ToTensor()
# 加载图像
lena = cv2.imread('lena.jpg', cv2.IMREAD_GRAYSCALE)
cv2.imshow('lena', lena)
# input = to_tensor(lena) 将ndarray转换为tensor,自动将[0,255]归一化至[0,1]。
input = to_tensor(lena).unsqueeze(0)
# 初始化卷积参数
kernel = t.ones(1, 1, 3, 3)/-9
kernel[:, :, 1, 1] = 1
conv = nn.Conv2d(1, 1, 3, 1, padding=1, bias=False)
conv.weight.data = kernel.view(1, 1, 3, 3)
# 输出
out = conv(input)
out = out.squeeze(0)
print(out.shape)
out = out.unsqueeze(3)
print(out.shape)
out = out.squeeze(0)
print(out.shape)
out = out.detach().numpy()# 缩放到0~最大值
cv2.normalize(out, out, 1.0, 0, cv2.NORM_INF)
cv2.imshow("lena-result", out)
cv2.waitKey()
结果图如下:
references:
[1] 陈云.深度学习框架之PyTorch入门与实践.北京:电子工业出版社,2018.
来源:https://blog.csdn.net/weixin_39490300/article/details/123464027


猜你喜欢
- CUDA的线程与块GPU从计算逻辑来讲,可以认为是一个高并行度的计算阵列,我们可以想象成一个二维的像围棋棋盘一样的网格,每一个格子都可以执行
- 在良好的数据库设计基础上,能有效地使用索引是SQL Server取得高性能的基础,SQL Server采用基于代价的优化模型,它对每一个提交
- 方法一使用以下流式代码,无论下载文件的大小如何,Python 内存占用都不会增加:def download_file(url):  
- 原来sql还有个stuff的函数,很强悍。 一个列的格式是单引号后面跟着4位的数字,比如'0003,'0120,'4
- 一、变量1.变量•指在程序执行过程中,可变的量;•定义一个变量,就会伴随有3个特征,分别是内存ID、数据类型和变量值。•其他语言运行完之前,
- <?php/* Function Written by Nelson Neoh @3/2004. For th
- 本文实例为大家分享了python实现多张图片拼接成大图的具体代码,供大家参考,具体内容如下上次爬取了马蜂窝的游记图片,并解决了PIL模块的导
- fromkeys()方法类似于列表的浅拷贝首先用该方法创建一个字典dict_ = dict.fromkeys(('a',
- 大家都知道,IE中的现代事件绑定(attachEvent)与W3C标准的(addEventListener)相比存在很多问题,例如:内存泄漏
- 对开区间和闭区间的理解对于开区间,本身已经不包含两端点值,所以根本满足不了连续的第一个要求,所以要说某一开区间连续,我们说是函数在这一开区间
- 在ORACLE中,我们可以通过file_id(file#)与block_id(block#)去定位一个数据库对象(object)。例如,我们
- 初入Python,一开始就被她简介的语法所吸引,代码简洁优雅,之前在C#里面打开文件写入文件等操作相比Python复杂多了,而Python打
- ------谁正在访问数据库?Select c.sid, c.serial#,c.username,a.object_id,b.
- 1、问题在使用Python中pandas读取csv文件时,由于文件编码格式出现以下问题:Traceback (most recent cal
- BigPipe 是 Facebook 开发的优化网页加载速度的技术。网上几乎没有用 node.js 实现的文章,实际上,不止于 node.j
- 一、Beautiful Soup库简介BeautifulSoup4 是一个 HTML/XML 的解析器,主要的功能是解析和提取 HTML/X
- 运算符的优先级和关联性运算符的优先级和关联性: 运算符的优先级和关联性决定了运算符的优先级。运算符优先级这用于具有多个具有不同优先级的运算符
- 所需库的安装很多人问Pytorch要怎么可视化,于是决定搞一篇。tensorboardX==2.0tensorflow==1.13.2由于t
- Pytorch转ONNX的意义一般来说转ONNX只是一个手段,在之后得到ONNX模型后还需要再将它做转换,比如转换到TensorRT上完成部
- 目录OpenCV先决条件我们会在本文中涵盖7个主题读,写和显示图像imread():imshow():imwrite():读取视频并与网络摄