Pytorch深度学习addmm()和addmm_()函数用法解析
作者:悲恋花丶无心之人 发布时间:2021-01-02 04:04:25
一、函数解释
在torch/_C/_VariableFunctions.py的有该定义,意义就是实现一下公式:
换句话说,就是需要传入5个参数,mat里的每个元素乘以beta,mat1和mat2进行矩阵乘法(左行乘右列)后再乘以alpha,最后将这2个结果加在一起。但是这样说可能没啥概念,接下来博主为大家写上一段代码,大家就明白了~
def addmm(self, beta=1, mat, alpha=1, mat1, mat2, out=None): # real signature unknown; restored from __doc__
"""
addmm(beta=1, mat, alpha=1, mat1, mat2, out=None) -> Tensor
Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`.
The matrix :attr:`mat` is added to the final result.
If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a
:math:`(m \times p)` tensor, then :attr:`mat` must be
:ref:`broadcastable <broadcasting-semantics>` with a :math:`(n \times p)` tensor
and :attr:`out` will be a :math:`(n \times p)` tensor.
:attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between
:attr:`mat1` and :attr`mat2` and the added matrix :attr:`mat` respectively.
.. math::
out = \beta\ mat + \alpha\ (mat1_i \mathbin{@} mat2_i)
For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and
:attr:`alpha` must be real numbers, otherwise they should be integers.
Args:
beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
mat (Tensor): matrix to be added
alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
mat1 (Tensor): the first matrix to be multiplied
mat2 (Tensor): the second matrix to be multiplied
out (Tensor, optional): the output tensor
Example::
>>> M = torch.randn(2, 3)
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.addmm(M, mat1, mat2)
tensor([[-4.8716, 1.4671, -1.3746],
[ 0.7573, -3.9555, -2.8681]])
"""
pass
二、代码范例
1.先摆出代码,大家可以先复制粘贴运行一下,在之后博主会一一讲解
"""
@author:nickhuang1996
"""
import torch
rectangle_height = 3
rectangle_width = 3
inputs = torch.randn(rectangle_height, rectangle_width)
for i in range(rectangle_height):
for j in range(rectangle_width):
inputs[i] = i * torch.ones(rectangle_width)
'''
inputs and its transpose
-->inputs = tensor([[0., 0., 0.],
[1., 1., 1.],
[2., 2., 2.]])
-->inputs_t = tensor([[0., 1., 2.],
[0., 1., 2.],
[0., 1., 2.]])
'''
print("inputs:\n", inputs)
inputs_t = inputs.t()
print("inputs_t:\n", inputs_t)
'''
inputs_t @ inputs_t [[0., 1., 2.], [[0., 1., 2.], [[0., 3., 6.]
= [0., 1., 2.], @ [0., 1., 2.], = [0., 3., 6.]
[0., 1., 2.]] [0., 1., 2.]] [0., 3., 6.]]
'''
'''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
'''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
e = torch.addmm(inputs, inputs_t, inputs_t)
f = inputs.addmm(inputs_t, inputs_t)
'''1 * inputs + 1 * (inputs_t @ inputs_t)'''
g = inputs.addmm(1, inputs_t, inputs_t)
'''2 * inputs + 1 * (inputs_t @ inputs_t)'''
g2 = inputs.addmm(2, inputs_t, inputs_t)
'''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
h = inputs.addmm(1, 1, inputs_t, inputs_t)
'''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
h12 = inputs.addmm(1, 2, inputs_t, inputs_t)
'''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
h21 = inputs.addmm(2, 1, inputs_t, inputs_t)
print("a:\n", a)
print("b:\n", b)
print("c:\n", c)
print("d:\n", d)
print("e:\n", e)
print("f:\n", f)
print("g:\n", g)
print("g2:\n", g2)
print("h:\n", h)
print("h12:\n", h12)
print("h21:\n", h21)
print("inputs:\n", inputs)
'''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
'''
inputs @ inputs_t [[0., 0., 0.], [[0., 1., 2.], [[0., 0., 0.]
= [1., 1., 1.], @ [0., 1., 2.], = [0., 3., 6.]
[2., 2., 2.]] [0., 1., 2.]] [0., 6., 12.]]
'''
inputs.addmm_(1, -2, inputs, inputs_t) # In-place
print("inputs:\n", inputs)
2.其中
inputs是一个3×3的矩阵,为
tensor([[0., 0., 0.],
[1., 1., 1.],
[2., 2., 2.]])
inputs_t也是一个3×3的矩阵,是inputs的转置矩阵,为
tensor([[0., 1., 2.],
[0., 1., 2.],
[0., 1., 2.]])
* inputs_t @ inputs_t为
'''
inputs_t @ inputs_t [[0., 1., 2.], [[0., 1., 2.], [[0., 3., 6.]
= [0., 1., 2.], @ [0., 1., 2.], = [0., 3., 6.]
[0., 1., 2.]] [0., 1., 2.]] [0., 3., 6.]]
'''
3.代码中a,b,c和d展示的是完全形式,即标明了位置参数和传入参数。可以看到input这个位置参数可以写在函数的前面,即
torch.addmm(input, mat1, mat2) = inputs.addmm(mat1, mat2)
完成的公式为:
1 × inputs + 1 ×(inputs_t @ inputs_t)
'''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
a:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
b:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
c:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
d:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
4.下面的例子更好了说明了input参数的位置可变性,并且beta和alpha都缺省了:
完成的公式为:
1 × inputs + 1 ×(inputs_t @ inputs_t)
'''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
e = torch.addmm(inputs, inputs_t, inputs_t)
f = inputs.addmm(inputs_t, inputs_t)
e:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
f:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
5.加一个参数,实际上是添加了beta这个参数
完成的公式为:
g = 1 × inputs + 1 ×(inputs_t @ inputs_t)
g2 = 2 × inputs + 1 ×(inputs_t @ inputs_t)
'''1 * inputs + 1 * (inputs_t @ inputs_t)'''
g = inputs.addmm(1, inputs_t, inputs_t)
'''2 * inputs + 1 * (inputs_t @ inputs_t)'''
g2 = inputs.addmm(2, inputs_t, inputs_t)
g:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
g2:
tensor([[ 0., 3., 6.],
[ 2., 5., 8.],
[ 4., 7., 10.]])
6.再加一个参数,实际上是添加了alpha这个参数
完成的公式为:
h = 1 × inputs + 1 ×(inputs_t @ inputs_t)
h12 = 1 × inputs + 2 ×(inputs_t @ inputs_t)
h21 = 2 × inputs + 1 ×(inputs_t @ inputs_t)
'''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
h = inputs.addmm(1, 1, inputs_t, inputs_t)
'''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
h12 = inputs.addmm(1, 2, inputs_t, inputs_t)
'''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
h21 = inputs.addmm(2, 1, inputs_t, inputs_t)
h:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
h12:
tensor([[ 0., 6., 12.],
[ 1., 7., 13.],
[ 2., 8., 14.]])
h21:
tensor([[ 0., 3., 6.],
[ 2., 5., 8.],
[ 4., 7., 10.]])
7.当然,以上的步骤inputs没有变化,还是为
inputs:
tensor([[0., 0., 0.],
[1., 1., 1.],
[2., 2., 2.]])
8.addmm_()的操作和addmm()函数功能相同,区别就是addmm_()有inplace的操作,也就是在原对象基础上进行修改,即把改变之后的变量再赋给原来的变量。例如:
inputs的值变成了改变之后的值,不用再去写 某个变量=addmm_() 了,因为inputs就是改变之后的变量!
*inputs@ inputs_t为
'''
inputs @ inputs_t [[0., 0., 0.], [[0., 1., 2.], [[0., 0., 0.]
= [1., 1., 1.], @ [0., 1., 2.], = [0., 3., 6.]
[2., 2., 2.]] [0., 1., 2.]] [0., 6., 12.]]
'''
完成的公式为:
inputs = 1 × inputs - 2 ×(inputs @ inputs_t)
'''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
inputs.addmm_(1, -2, inputs, inputs_t) # In-place
inputs:
tensor([[ 0., 0., 0.],
[ 1., -5., -11.],
[ 2., -10., -22.]])
三、代码运行结果
inputs:
tensor([[0., 0., 0.],
[1., 1., 1.],
[2., 2., 2.]])
inputs_t:
tensor([[0., 1., 2.],
[0., 1., 2.],
[0., 1., 2.]])
a:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
b:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
c:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
d:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
e:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
f:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
g:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
g2:
tensor([[ 0., 3., 6.],
[ 2., 5., 8.],
[ 4., 7., 10.]])
h:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
h12:
tensor([[ 0., 6., 12.],
[ 1., 7., 13.],
[ 2., 8., 14.]])
h21:
tensor([[ 0., 3., 6.],
[ 2., 5., 8.],
[ 4., 7., 10.]])
inputs:
tensor([[0., 0., 0.],
[1., 1., 1.],
[2., 2., 2.]])
inputs:
tensor([[ 0., 0., 0.],
[ 1., -5., -11.],
[ 2., -10., -22.]])
来源:https://nickhuang1996.blog.csdn.net/article/details/90638449


猜你喜欢
- 前言虽然现在文件上传下载工具多如牛毛,比如http、ftp、sftp、scp等方案都可以用于文件传输,但都是需要安装服务器甚至客户端。有一种
- 引言最近python语言大火,除了在科学计算领域python有用武之地之外,在游戏、后台等方面,python也大放异彩,本篇博文将按照正规的
- GIT安装访问: https://git-scm.com/downloads ,进入git'下载页面,根据个人操作系统下载对应软件版
- Python中会遇到很多关于排序的问题,今天小编就带给大家实现插入排序的方法。在Python中插入排序的基本原理类似于摸牌,将摸起来的牌插入
- 一、为表创建自增长自段有两种,一种是不同的表使用各自的Sequence,方法如下: 1、在Oracle sequence首先创建sequen
- Mysql的安装方法 安装mysql的步骤如下:请注意按图中所示,有些选项和默认是不一样的。同时,如果您是重新安装mysql的话,要注意先备
- 最近在实习,boss给布置了一个python的小任务,学习过程中发现copy()和deepcopy()这对好 * 实在是有点过分,搞的博主就有
- 列表的格式:变量A的类型为列表 namesList = ['xiaoWang','xiaoZhang',
- 在javascript中,我们都知道使用var来声明变量。javascript是函数级作用域,函数内可以访问函数外的变量,函数外不能访问函数
- 本文采用OpenCV3和Python3 来实现静态图片的人脸识别,采用的是Haar文件级联。 首先需要将OpenCV3源代码中找到data文
- 一、ref的基本使用ref的使用<!-- `vm.$refs.p`将会是DOM结点 --><p ref="p&q
- PyCharm使用jre,所以设置内存使用的情况和eclipse类似。编辑PyCharm安装目录下PyCharm 4.5.3\bin下的py
- 前言在做小程序列表展示的时候,接到了一个需求。需要在列表展示的时候加上动画效果。设计视频效果如下图:需要在进入列表页的时候,依次展示每一条卡
- 英文文档:staticmethod(function)Return a static method for function.A stati
- 一、简介py2exe是一个将python脚本转换成windows上的可独立执行的可执行程序(*.exe)的工具,这样,你就可以不用装pyth
- 前言有趣的实战项目,用Python+xlwings模块制作天气预报表让我们愉快地开始吧~开发工具Python版本: 3.6.4相关模块:re
- “Lightbox”是一个别致且易用的图片显示效果,它可以使图片直接呈现在当前页面之上而不用转到新的窗口。lightbox效果网络上有很多j
- 测了一下django、flask、bottle、tornado 框架本身最简单的性能。对django的性能完全无语了。django、flas
- 1.首先在Pycharm Tools->Deployment->Configurations打开新建SFTP输入host: ip
- 我使用的 Pandas 版本如下,顺便也导入 Pandas 库。>>> import pandas as pd>&g