Pytorch中torch.repeat_interleave()函数使用及说明
作者:cv_lhp 发布时间:2022-06-17 04:18:53
torch.repeat_interleave()函数解析
1.函数说明
官网:torch.repeat_interleave(),函数说明如下图所示:
2. 函数原型
torch.repeat_interleave(input, repeats, dim=None) → Tensor
3. 函数功能
沿着指定的维度重复张量的元素
4. 输入参数
1)input (类型:torch.Tensor):输入张量
2)repeats(类型:int或torch.Tensor):每个元素的重复次数
3)dim(类型:int)需要重复的维度。默认情况下dim=None,表示将把给定的输入张量展平(flatten)为向量,然后将每个元素重复repeats次,并返回重复后的张量。
5. 注意
1) 如果不指定dim,则默认将输入张量扁平化(维数是1,因此这时repeats必须是一个数,不能是数组),并且返回一个扁平化的输出数组。
2) 返回的数组与输入数组维数相同,并且除了给定的维度dim,其他维度大小与输入数组相应维度大小相同
3) repeats:如果传入数组,则必须是tensor格式。并且只能是一维数组,数组长度与输入数组input的dim维度大小相同
6. 代码例子
6.1 输入一维张量,不指定dim,重复次数为2次,表示将把给定的输入张量展平(flatten)为向量,然后将每个元素重复2次,并返回重复后的张量。
a = torch.randn(5)
a,torch.repeat_interleave(a,2)
输出结果如下所示:
(tensor([ 0.4030, -1.1536, -2.4513, 1.1454, -0.8818]),
tensor([ 0.4030, 0.4030, -1.1536, -1.1536, -2.4513, -2.4513, 1.1454, 1.1454,
-0.8818, -0.8818]))
6.2 输入二维张量,不指定dim,重复次数为2次,表示将把给定的输入张量展平(flatten)为向量,然后将每个元素重复2次,并返回重复后的张量。
a = torch.randn(3,2)
a,a.repeat_interleave(2)
输出结果如下:
(tensor([[-1.03, -0.32],
[ 0.43, 0.78],
[ 0.91, -0.11]]),
tensor([-1.03, -1.03, -0.32, -0.32, 0.43, 0.43, 0.78, 0.78, 0.91, 0.91,
-0.11, -0.11]))
6.3 输入二维张量,指定dim=0,重复次数为3次,表示把输入张量每行元素重复3次
a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=0)
输出结果如下:
(tensor([[ 0.14, 1.47],
[-1.52, -0.62],
[-0.24, -0.27]]),
tensor([[ 0.14, 1.47],
[ 0.14, 1.47],
[ 0.14, 1.47],
[-1.52, -0.62],
[-1.52, -0.62],
[-1.52, -0.62],
[-0.24, -0.27],
[-0.24, -0.27],
[-0.24, -0.27]]))
6.4 输入二维张量,指定dim=1,重复次数为3次,表示把输入张量每列元素重复3次
a = torch.randn(3,2)
a,torch.repeat_interleave(a,3,dim=1)
输出结果如下:
(tensor([[-0.81, 0.56],
[-2.41, -0.56],
[ 0.38, -0.90]]),
tensor([[-0.81, -0.81, -0.81, 0.56, 0.56, 0.56],
[-2.41, -2.41, -2.41, -0.56, -0.56, -0.56],
[ 0.38, 0.38, 0.38, -0.90, -0.90, -0.90]]))
6.5 输入二维张量,指定dim=0,重复次数为一个张量列表[n1,n2,n3],表示在(dim=0)对应行上面重复n1,n2,n3遍,张量列表的长度必须与dim=0的维度的长度一样,否则会报错
a = torch.randn(3,2)
a,torch.repeat_interleave(a,torch.tensor([2,3,4]),dim=0)#表示第一行重复2遍,第二行重复3遍,第三行重复4遍
输出结果如下:
(tensor([[-0.79, 0.54],
[-0.47, -0.25],
[-0.13, 1.03]]),
tensor([[-0.79, 0.54],
[-0.79, 0.54],
[-0.47, -0.25],
[-0.47, -0.25],
[-0.47, -0.25],
[-0.13, 1.03],
[-0.13, 1.03],
[-0.13, 1.03],
[-0.13, 1.03]]))
7. 与torch.repeat()函数区别
两个函数方法最大的区别就是repeat_interleave是一个元素一个元素地重复,而repeat是一组元素一组元素地重复.
来源:https://blog.csdn.net/flyingluohaipeng/article/details/125039411


猜你喜欢
- 一、基础知识1、MySQL-python的安装下载,然后 pip install 安装包2、python编写通用数据库程序的API规范(1)
- Session作用Session的根本作用就是在服务端存储用户和服务器会话的一些信息。典型的应用有:1、判断用户是否登录。2、购物车功能。s
- 1、模拟退火算法退火是金属从熔融状态缓慢冷却、最终达到能量最低的平衡态的过程。模拟退火算法基于优化问题求解过程与金属退火过程的相似性,以优化
- 本文实例讲述了python实现同时给多个变量赋值的方法。分享给大家供大家参考。具体分析如下:python中可以同时给多个变量赋值,下面列举了
- 一、简述MySQL版本从5直接 * 到8,相信MySQL8一定会有很多令人意想不到的改进,如果不想只会CRUD可以看看。比如系统表引擎的变化
- 背景:我们有一个用go做的项目,其中用到了zmq4进行通信,一个简单的rpc过程,早期远端是使用一个map去做ip和具体socket的映射。
- 一、前言普通机器学习:从训练数据中学习一个假设。集成方法:试图构建一组假设并将它们组合起来,集成学习是一种机器学习范式,多个学习器被训练来解
- 原问题是这样的:如何用SQL语句(不是Oracle),求出下表每一行的5个字段中的最大值,最后生成一个新字段。例如:第一行最大值 -5.0
- 引言这算是一个高级用法了,前面我们只说到对类型、变量的几种反射的用法,包括如何获取其值、其类型、以及如何重新设置新值。但是在项目应用中,另外
- 1:定义存储过程,用于分隔字符串DELIMITER $$USE `mess`$$DROP PROCEDURE IF EXISTS `spli
- 文档概览本文基于express、express-session实现了简易的登录/登出功能,完整的代码示例可以在这里找到。环境初始化首先,初始
- 在python中json分别由列表和字典组成,本文主要介绍python中字典与json相互转换的方法。使用json.dumps可以把字典转成
- function checkPhoto(fnUpload) { var filename = fnUpload.value; alert(f
- 一、技术路线requests:网页请求BeautifulSoup:解析html网页re:正则表达式,提取html网页信息os:保存文件imp
- 特么的,上次写了一堆,发现,原来下载网易云的歌曲根本不用这么费劲,直接用!http://music.163.com/song/media/o
- Worksheet 对象的 rows 属性和 columns 属性得到的是一 Generator 对象,不能用中括号取索引。可先用列表推导式
- 数据库自增 ID搞一个数据库,什么也不干,就用于生成主键。你的系统里每次得到一个 id,都需要往那个专门生成主键的数据库中通过插入
- 我们在使用 requests 这类网络请求第三方库时,可以看到它有一个参数叫做 timeout ,就是指在网络请求发出开始计算,如果超过 t
- 1.imutils功能简介imutils是在OPenCV基础上的一个封装,达到更为简结的调用OPenCV接口的目的,它可以轻松的实现图像的平
- 前言对MySQL有研究的读者,可能会发现MySQL更新很快,在安装方式上,MySQL提供了两种经典安装方式:解压式和一键式,虽然是两种安装方