Tensorflow矩阵运算实例(矩阵相乘,点乘,行/列累加)
作者:Kenn7 发布时间:2023-12-19 04:07:42
标签:Tensorflow,矩阵,相乘,点乘
Tensorflow二维、三维、四维矩阵运算(矩阵相乘,点乘,行/列累加)
1. 矩阵相乘
根据矩阵相乘的匹配原则,左乘矩阵的列数要等于右乘矩阵的行数。
在多维(三维、四维)矩阵的相乘中,需要最后两维满足匹配原则。
可以将多维矩阵理解成:(矩阵排列,矩阵),即后两维为矩阵,前面的维度为矩阵的排列。
比如对于(2,2,4)来说,视为2个(2,4)矩阵。
对于(2,2,2,4)来说,视为2*2个(2,4)矩阵。
import tensorflow as tf
a_2d = tf.constant([1]*6, shape=[2, 3])
b_2d = tf.constant([2]*12, shape=[3, 4])
c_2d = tf.matmul(a_2d, b_2d)
a_3d = tf.constant([1]*12, shape=[2, 2, 3])
b_3d = tf.constant([2]*24, shape=[2, 3, 4])
c_3d = tf.matmul(a_3d, b_3d)
a_4d = tf.constant([1]*24, shape=[2, 2, 2, 3])
b_4d = tf.constant([2]*48, shape=[2, 2, 3, 4])
c_4d = tf.matmul(a_4d, b_4d)
with tf.Session() as sess:
tf.global_variables_initializer().run()
print("# {}*{}={} \n{}".
format(a_2d.eval().shape, b_2d.eval().shape, c_2d.eval().shape, c_2d.eval()))
print("# {}*{}={} \n{}".
format(a_3d.eval().shape, b_3d.eval().shape, c_3d.eval().shape, c_3d.eval()))
print("# {}*{}={} \n{}".
format(a_4d.eval().shape, b_4d.eval().shape, c_4d.eval().shape, c_4d.eval()))
2. 点乘
点乘指的是shape相同的两个矩阵,对应位置元素相乘,得到一个新的shape相同的矩阵。
a_2d = tf.constant([1]*6, shape=[2, 3])
b_2d = tf.constant([2]*6, shape=[2, 3])
c_2d = tf.multiply(a_2d, b_2d)
a_3d = tf.constant([1]*12, shape=[2, 2, 3])
b_3d = tf.constant([2]*12, shape=[2, 2, 3])
c_3d = tf.multiply(a_3d, b_3d)
a_4d = tf.constant([1]*24, shape=[2, 2, 2, 3])
b_4d = tf.constant([2]*24, shape=[2, 2, 2, 3])
c_4d = tf.multiply(a_4d, b_4d)
with tf.Session() as sess:
tf.global_variables_initializer().run()
print("# {}*{}={} \n{}".
format(a_2d.eval().shape, b_2d.eval().shape, c_2d.eval().shape, c_2d.eval()))
print("# {}*{}={} \n{}".
format(a_3d.eval().shape, b_3d.eval().shape, c_3d.eval().shape, c_3d.eval()))
print("# {}*{}={} \n{}".
format(a_4d.eval().shape, b_4d.eval().shape, c_4d.eval().shape, c_4d.eval()))
另外,点乘的其中一方可以是一个常数,也可以是一个和矩阵行向量等长(即列数)的向量。
即
因为在点乘过程中,会自动将常数或者向量进行扩维。
a_2d = tf.constant([1]*6, shape=[2, 3])
k = tf.constant(2)
l = tf.constant([2, 3, 4])
b_2d_1 = tf.multiply(k, a_2d) # tf.multiply(a_2d, k) is also ok
b_2d_2 = tf.multiply(l, a_2d) # tf.multiply(a_2d, l) is also ok
a_3d = tf.constant([1]*12, shape=[2, 2, 3])
b_3d_1 = tf.multiply(k, a_3d) # tf.multiply(a_3d, k) is also ok
b_3d_2 = tf.multiply(l, a_3d) # tf.multiply(a_3d, l) is also ok
a_4d = tf.constant([1]*24, shape=[2, 2, 2, 3])
b_4d_1 = tf.multiply(k, a_4d) # tf.multiply(a_4d, k) is also ok
b_4d_2 = tf.multiply(l, a_4d) # tf.multiply(a_4d, l) is also ok
with tf.Session() as sess:
tf.global_variables_initializer().run()
print("# {}*{}={} \n{}".
format(k.eval().shape, a_2d.eval().shape, b_2d_1.eval().shape, b_2d_1.eval()))
print("# {}*{}={} \n{}".
format(l.eval().shape, a_2d.eval().shape, b_2d_2.eval().shape, b_2d_2.eval()))
print("# {}*{}={} \n{}".
format(k.eval().shape, a_3d.eval().shape, b_3d_1.eval().shape, b_3d_1.eval()))
print("# {}*{}={} \n{}".
format(l.eval().shape, a_3d.eval().shape, b_3d_2.eval().shape, b_3d_2.eval()))
print("# {}*{}={} \n{}".
format(k.eval().shape, a_4d.eval().shape, b_4d_1.eval().shape, b_4d_1.eval()))
print("# {}*{}={} \n{}".
format(l.eval().shape, a_4d.eval().shape, b_4d_2.eval().shape, b_4d_2.eval()))
4. 行/列累加
a_2d = tf.constant([1]*6, shape=[2, 3])
d_2d_1 = tf.reduce_sum(a_2d, axis=0)
d_2d_2 = tf.reduce_sum(a_2d, axis=1)
a_3d = tf.constant([1]*12, shape=[2, 2, 3])
d_3d_1 = tf.reduce_sum(a_3d, axis=1)
d_3d_2 = tf.reduce_sum(a_3d, axis=2)
a_4d = tf.constant([1]*24, shape=[2, 2, 2, 3])
d_4d_1 = tf.reduce_sum(a_4d, axis=2)
d_4d_2 = tf.reduce_sum(a_4d, axis=3)
with tf.Session() as sess:
tf.global_variables_initializer().run()
print("# a_2d 行累加得到shape:{}\n{}".format(d_2d_1.eval().shape, d_2d_1.eval()))
print("# a_2d 列累加得到shape:{}\n{}".format(d_2d_2.eval().shape, d_2d_2.eval()))
print("# a_3d 行累加得到shape:{}\n{}".format(d_3d_1.eval().shape, d_3d_1.eval()))
print("# a_3d 列累加得到shape:{}\n{}".format(d_3d_2.eval().shape, d_3d_2.eval()))
print("# a_4d 行累加得到shape:{}\n{}".format(d_4d_1.eval().shape, d_4d_1.eval()))
print("# a_4d 列累加得到shape:{}\n{}".format(d_4d_2.eval().shape, d_4d_2.eval()))
来源:https://blog.csdn.net/kane7csdn/article/details/84843154


猜你喜欢
- 有时表或结果集包含重复的记录。有时它是允许的,但有时它需要停止重复的记录。有时它需要识别重复的记录从表中删除。本章将介绍如何防止发生在一个表
- 昨天在用用Pycharm读取一个200+M的CSV的过程中,竟然出现了Memory Error!简直让我怀疑自己买了个假电脑,毕竟是8G内存
- 小小程序猿SQL Server认知的成长 1.没毕业或工作没多久,只知道有数据库、SQL这么个东东,浑然分不清SQL和Sql Server
- 1、单个关键字加亮代码: <div id="txt"> 用JS让文章内容指定
- 再dos中无法使用pip,命令主要是没有发现这个命令。我们先找到这个命令的位置,一般是在python里面的Scripts文件夹里面。我们可以
- 写在前面之前的文章中已经讲过了遗传算法的基本流程,并且用MATLAB实现过一遍了。这一篇文章主要面对的人群是看过了我之前的文章,因此我就不再
- 在Django的开发过程中,有一些功能是通过JS根据用户的不同选择来加载页面中的某一部分(子页面)的。如果子页面中有我们需要传入的值。可以这
- 最近在研究深度学习视觉相关的东西,经常需要写python代码搭建深度学习模型。比如写CNN模型相关代码时,我们需要借助python图像库来读
- 我就废话不多说了,直接上代码吧!from os import listdirimport osfrom time import timeim
- 1. 原理对于DNA序列,一阶马尔科夫链可以理解为当前碱基的类型仅取决于上一位碱基类型。如图1所示,一条序列的开端(由B开始)可能是A、T、
- 本文实例为大家分享了python版DDOS攻击脚本,供大家参考,具体内容如下于是就找到了我之前收藏的一篇python的文章,是关于ddos攻
- 根据教程实现了读取csv文件前面的几行数据,一下就想到了是不是可以实现前面几列的数据。经过多番尝试总算试出来了一种方法。之所以想实现读取前面
- 如下所示:def getWordPattern(word): pattern = [] usedLetter={} count=0 for
- 对于数据库这一块询问比较多的就是在 MySQL 中怎么去选择一种何时当前业务需求的存储引擎,而 MySQL 中支持的存储引擎又有很多种,那么
- 本文介绍的MySQL数据库的出错代码表,依据MySQL数据库头文件mysql/include/mysqld_error.h整理而成。详细内容
- 1、python多进程编程背景python中的多进程最大的好处就是充分利用多核cpu的资源,不像python中的多线程,受制于GIL的限制,
- set是什么?数学上,把set称做由不同的元素组成的集合,集合(set)的成员通常被称做集合元素(set elements)。Python把
- 高斯模糊的介绍与原理通常,图像处理软件会提供"模糊"(blur)滤镜,使图片产生模糊的效果。"模糊"
- 复习回顾:Python 对于时间日期操作提供了很多方法,我们前面已经学习了2个模块:基于Unix 时间戳范围限制在1970~2038年的时间
- 一、前言程序的性能也是非常关键的指标,很多时候你的代码跑的快,更能够体现你的技术。最近发现很多小伙伴在性能分析的过程中都是手动打印运行时间的