深入理解Pytorch中的torch. matmul()
作者:海轰Pro 发布时间:2023-06-03 05:29:18
torch.matmul()
语法
torch.matmul(input, other, *, out=None) → Tensor
作用
两个张量的矩阵乘积
行为取决于张量的维度,如下所示:
如果两个张量都是一维的,则返回点积(标量)。
如果两个参数都是二维的,则返回矩阵-矩阵乘积。
如果第一个参数是一维的,第二个参数是二维的,为了矩阵乘法的目的,在它的维数前面加上一个 1。在矩阵相乘之后,前置维度被移除。
如果第一个参数是二维的,第二个参数是一维的,则返回矩阵向量积。
如果两个参数至少为一维且至少一个参数为 N 维(其中 N > 2),则返回批处理矩阵乘法
如果第一个参数是一维的,则将 1 添加到其维度,以便批量矩阵相乘并在之后删除。如果第二个参数是一维的,则将 1 附加到其维度以用于批量矩阵倍数并在之后删除
非矩阵(即批次)维度是广播的(因此必须是可广播的)
例如,如果输入是( j × 1 × n × n ) (j \times 1 \times n \times n)(j×1×n×n) 张量
另一个是 ( k × n × n ) (k \times n \times n)(k×n×n)张量,
out 将是一个 ( j × k × n × n ) (j \times k \times n \times n)(j×k×n×n) 张量
请注意,广播逻辑在确定输入是否可广播时仅查看批处理维度,而不是矩阵维度
例如
如果输入是 ( j × 1 × n × m ) (j \times 1 \times n \times m)(j×1×n×m) 张量
另一个是 ( k × m × p ) (k \times m \times p)(k×m×p) 张量
即使最后两个维度(即矩阵维度)不同,这些输入对于广播也是有效的
out 将是一个 ( j × k × n × p ) (j \times k \times n \times p)(j×k×n×p) 张量
该运算符支持 TensorFloat32。
在某些 ROCm 设备上,当使用 float16 输入时,此模块将使用不同的向后精度
举例
情形1: 一维 * 一维
如果两个张量都是一维的,则返回点积(标量)
tensor1 = torch.Tensor([1,2,3])
tensor2 =torch.Tensor([4,5,6])
ans = torch.matmul(tensor1, tensor2)
print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())
ans = 1 * 4 + 2 * 5 + 3 * 6 = 32
情形2: 二维 * 二维
如果两个参数都是二维的,则返回矩阵-矩阵乘积
也就是 正常的矩阵乘法 (m * n) * (n * k) = (m * k)
tensor1 = torch.Tensor([[1,2,3],[1,2,3]])
tensor2 =torch.Tensor([[4,5],[4,5],[4,5]])
ans = torch.matmul(tensor1, tensor2)
print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())
情形3: 一维 * 二维
如果第一个参数是一维的,第二个参数是二维的,为了矩阵乘法的目的,在它的维数前面加上一个 1
在矩阵相乘之后,前置维度被移除
tensor1 = torch.Tensor([1,2,3]) # 注意这里是一维
tensor2 =torch.Tensor([[4,5],[4,5],[4,5]])
ans = torch.matmul(tensor1, tensor2)
print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())
tensor1 = torch.Tensor([1,2,3])
修改为 tensor1 = torch.Tensor([[1,2,3]])
发现一个结果是[24., 30.]
一个是[[24., 30.]]
所以,当一维 * 二维时, 开始变成 1 * m(一维的维度),也就是一个二维, 再进行正常的矩阵运算,得到[[24., 30.]]
, 然后再去掉开始增加的一个维度,得到[24., 30.]
想象为二维 * 二维(前置维度为1),最后结果去掉一个维度即可
情形4: 二维 * 一维
如果第一个参数是二维的,第二个参数是一维的,则返回矩阵向量积
tensor1 =torch.Tensor([[4,5,6],[7,8,9]])
tensor2 = torch.Tensor([1,2,3])
ans = torch.matmul(tensor1, tensor2)
print('tensor1 : ', tensor1)
print('tensor2 : ', tensor2)
print('ans :', ans)
print('ans.size :', ans.size())
理解为:
把第一个二维中,想象为多个行向量
第二个一维想象为一个列向量
行向量与列向量进行矩阵乘法,得到一个标量
再按照行堆叠起来即可
情形5:两个参数至少为一维且至少一个参数为 N 维(其中 N > 2),则返回批处理矩阵乘法
第一个参数为N维,第二个参数为一维时
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
print(torch.matmul(tensor1, tensor2).size())
(4) 先添加一个维度 (4 * 1)
得到(10 * 3 * 4) *( 4 * 1) = (10 * 3 * 1)
再删除最后一个维度(添加的那个)
得到结果(10 * 3)
tensor1 = torch.randn(10,2, 3, 4) #
tensor2 = torch.randn(4)
print(torch.matmul(tensor1, tensor2).size())
(10 * 2 * 3 * 4) * (4 * 1) = (10 * 2 * 3) 【抵消4,删1】
第一个参数为一维,第二个参数为二维时
tensor1 = torch.randn(4)
tensor2 = torch.randn(10, 4, 3)
print(torch.matmul(tensor1, tensor2).size())
tensor2 中第一个10理解为批次, 10个(4 * 3)
(1 * 4)与每个(4 * 3) 相乘得到(1,3),去除1,得到(3)
批次为10,得到(10,3)
tensor1 = torch.randn(4)
tensor2 = torch.randn(10,2, 4, 3)
print(torch.matmul(tensor1, tensor2).size())
这里批次理解为[10, 2]即可
tensor1 = torch.randn(4)
tensor2 = torch.randn(10,4, 2,4,1)
print(torch.matmul(tensor1, tensor2).size())
个人理解:当一个参数为一维时,它要去匹配另一个参数的最后两个维度(二维 * 二维)
比如上面的例子就是(1 * 4) 匹配 (4,1), 批次为(10,4,2)
高维 * 高维时
注:这不太好理解 … 感觉就是要找准批次,再进行乘法(靠感觉了 哈哈 离谱)
参考 https://pytorch.org/docs/stable/generated/torch.matmul.html#torch.matmul
来源:https://blog.csdn.net/weixin_44225182/article/details/126655303
猜你喜欢
- 此篇文章整理新手编写代码常见的一些错误,有些错误是粗心的错误,但对于新手而已,会折腾很长时间才搞定,所以在此总结下我遇到的一些问题。希望帮助
- import pyperclipimport pyautogui# PyAutoGUI中文输入需要用粘贴实现# Py
- 1 什么是曝光融合曝光融合是一种将使用不同曝光设置拍摄的图像合成为一张看起来像色调映射的高动态范围(HDR)图像的图像的方法。当我们使用相机
- 一、“无”的哲学佛家讲究“因果报应”,有果必有应。此段看似与主题没有血缘关系,实际讲的是“因”。我个人比较喜欢老子的道家思想,并喜欢以其思想
- 1.Django的简介Django是一个基于MVC构造的框架。但是在Django中,控制器接受用户输入的部分由框架自行处理,所以 Djang
- 方法一:def dict_to_numpy_method1(dict): dict_sorted=sorted(dict.iteritems
- 题目:反转一个单链表。示例:输入: 1->2->3->4->5->NULL输出: 5->4->3-
- 如下所示:# 选取等于某些值的行记录 用 == df.loc[df['column_name'] == some_value
- 阅读上一篇教程:WEB2.0网页制作标准教程(10)自适应高度布局初步搭建起来,我开始填充里面的内容。首先是定义logo图片:样式表:#lo
- 另外他们列出的这些区别有些是蛮有意义的,有些可能由于他们本人的MySQL DBA的身份,对Oracle的理解有些偏差,有些则有凑数的嫌疑.
- 1、元旦之前受赵晨之邀作为讨论嘉宾参加了ACM组织的“人与信息社会巡讲”。2、去之前赵晨发给了我大致的讨论提纲。咣当了好几下~说实话,我是硬
- import sysfrom PyQt5 import QtWidgetsfrom PyQt5.QtWidgets import QMain
- 前言老早就看到新闻员工通过人脸识别监控老板来摸鱼。有时候摸鱼太入迷了,经常在上班时间玩其他的东西被老板看到。自从在咸鱼上淘了一个树莓派3b,
- type指示type要使用的验证器。可识别的类型值为:string:类型必须为string。type 默认是 string// 校验stri
- 在python中。布尔值有 Ture False 两种。Ture等于对,False等于错。要注意在python中对字母的大小写要求非常严格。
- 如果你经常与Excel或Word打交道,那么从两份表格/文档中找到不一样的元素是一件让人很头疼的工作,当然网上有很多方法、第三方软件教你如何
- 本文实例讲述了Python中pandas模块DataFrame创建方法。分享给大家供大家参考,具体如下:DataFrame创建1. 通过列表
- 首先说明一下SQL Server内存占用由哪几部分组成。SQL Server占用的内存主要由三部分组成:数据缓存(Data Buffer)、
- Django结合ajax进行页面实时更新踩过的坑简单记录一下在使用Django、echarts和ajax实现数据动态更新时遇到的一些坑: 1
- python内存管理机制:引用计数垃圾回收内存池1. 引用计数当一个python对象被引用时 其引用计数增加 1 ; 当其不再被变量引用时