PyTorch中torch.matmul()函数常见用法总结
作者:wendy_ya 发布时间:2023-03-28 16:01:31
一、函数介绍
pytorch中两个张量的乘法可以分为两种:
两个张量对应元素相乘,在PyTorch中可以通过torch.mul函数(或*运算符)实现;
两个张量矩阵相乘,在PyTorch中可以通过torch.matmul函数实现;
torch.matmul(input, other) → Tensor
计算两个张量input和other的矩阵乘积
【注意】:matmul函数没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作。
二、常见用法
torch.matmul()也是一种类似于矩阵相乘操作的tensor连乘操作。但是它可以利用python中的广播机制,处理一些维度不同的tensor结构进行相乘操作。这也是该函数与torch.bmm()区别所在。
2.1 两个一维向量的乘积运算
若两个tensor都是一维的,则返回两个向量的点积运算结果:
import torch
x = torch.tensor([1,2])
y = torch.tensor([3,4])
print(x,y)
print(torch.matmul(x,y),torch.matmul(x,y).size())
运行结果:
tensor([1, 2]) tensor([3, 4])
tensor(11) torch.Size([])
2.2 两个二维矩阵的乘积运算
若两个tensor都是二维的,则返回两个矩阵的矩阵相乘结果:
import torch
x = torch.tensor([[1,2],[3,4]])
y = torch.tensor([[5,6,7],[8,9,10]])
print(torch.matmul(x,y),torch.matmul(x,y).size())
运行结果:
tensor([[21, 24, 27],[47, 54, 61]]) torch.Size([2, 3])
2.3 一个一维向量和一个二维矩阵的乘积运算
若input为一维,other为二维,则先将input的一维向量扩充到二维(维数前面插入长度为1的新维度),然后进行矩阵乘积,得到结果后再将此维度去掉,得到的与input的维度相同。
import torch
x = torch.tensor([1,2])
y = torch.tensor([[5,6,7],[8,9,10]])
print(torch.matmul(x,y),torch.matmul(x,y).size())
运行结果:
tensor([21, 24, 27]) torch.Size([3])
【分析】:首先将x维度从(2)扩充为(,2),然后将x(,2) 与y(2,3)进行相乘,得到(,3),最后去掉一维部分,得到(3)
2.4 一个二维矩阵和一个一维向量的乘积运算
若input为二维,other为一维,则先将other的一维向量扩充到二维(维数后面插入长度为1的新维度),然后进行矩阵乘积,得到结果后再将此维度去掉,得到的与other的维度相同。
import torch
x = torch.tensor([[1,2,3],[4,5,6]])
y = torch.tensor([7,8,9])
print(torch.matmul(x,y),'\n',torch.matmul(x,y).size())
运行结果:
tensor([ 50, 122])
torch.Size([2])
【分析】:首先y维度从(3)扩充为(3,),然后将x(2,3)与x(2,)进行相乘,得到(2,),最后去掉一维部分,得到(2)
【总结】:2.3和2.4基本类似,唯一不同的是2.3中一维向量和二维矩阵的乘积运算需要在一维向量前面插入长度为1的新维度(x为一维向量,y为二维矩阵);2.4中二维矩阵和一维向量的乘积运算需要在一维向量后面插入长度为1的新维度(x为二维矩阵,y为一维向量)。
2.5 其他
其他的暂时用不上,有需要的可以自行查阅相关资料~
参考:https://cloud.tencent.com/developer/article/1802317
来源:https://blog.csdn.net/didi_ya/article/details/121158666


猜你喜欢
- 一、存在问题在v-model想绑定表达式 || 函数方法,发现控制台报错了,不允许这波操作。下面我们分析存在该问题的原因和解决方法。实战经验
- 前言matplotlib 是Python最著名的绘图库,它提供了一整套和matlab相似的命令API,十分适合交互式地进行制图。本文将以例子
- 我就废话不多说了,大家还是直接看代码吧~# 用一行代码实现for循环初始化数组o = 10b = [ o + u for u in rang
- 新建label与button,并设置位置(grid)import tkinter as tkroot = tk.Tk()label = tk
- 前言之前搭建了一个ExtJS + spring + Oracle 的这样一个报表系统的框架。 因为其他部门的要求, 也需要这个Framewo
- 0. 前言无论在工作中,还是学习中,都会出现这样子的需求,对某张表进行了排序(按时间排序也好,其他字段排序也罢),然后获取前x行的数据,由于
- 一、需求描述手上有大量外文文档(本案例以5份为例,分别命名为 test1.docx test2.docx 以此
- 本文基本使用谷歌翻译加上自己的理解,权当加深记忆。npm简介qs 是一个增加了一些安全性的查询字符串解析和序列化字符串的库。主要维护者:Jo
- PHP get_html_translation_table() 函数实例输出 htmlspecialchars 函数使用的翻译表:<
- 我就废话不多说了,大家还是直接看代码吧~func ReadLine(fileName string) ([]string,error){f,
- 1 简介DataFrame是Python中Pandas库中的一种数据结构,它类似excel,是一种二维表。或许说它可能有点像matlab的矩
- ——nodejs安装及环境配置1.nodejs官网,下载windows平台nodejs环境安装包(.msi格式),安装2.测试安装是否成功:
- 引言在做接口测试的时候,我们不仅需要将测试结果以报告的形式展示,还需要将测试结果以邮件的形式发送到需要知道的人手中。那么如何发送邮件呢?邮件
- 字典排序在程序中使用字典进行数据信息统计时,由于字典是无序的所以打印字典时内容也是无序的。因此,为了使统计得到的结果更方便查看需要进行排序。
- 目录前言1-下载python3.8压缩包2-解压缩安装包3-安装依赖工具4-安装python3.85-修改python2软链接6-修改yum
- Python自带一个轻量级的关系型数据库SQLite。这一数据库使用SQL语言。SQLite作为后端数据库,可以搭配Python建网站,或者
- <base href="http://digi.tech.qq.com/images/ld/2007/1022/
- 代码如下: var params = new Enumerator(Request.QueryString); while (!params
- 百度的资料,保存下来:在写按时间段查询的sql语句的时候 一般我们会这么写查询条件:where date>='2010-01-
- nav导航栏<nav role="navigation" class="navbar navbar-de