TensorFlow中tf.batch_matmul()的用法
作者:yyhhlancelot 发布时间:2022-06-06 10:33:37
标签:TensorFlow,tf.batch,matmul
TensorFlow中tf.batch_matmul()用法
如果有两个三阶张量,size分别为
a.shape = [100, 3, 4]
b.shape = [100, 4, 5]
c = tf.batch_matmul(a, b)
则c.shape = [100, 3, 5] //将每一对 3x4 的矩阵与 4x5 的矩阵分别相乘。batch_size不变
100为张量的batch_size。剩下的两个维度为数据的维度。
不过新版的tensorflow已经移除了上面的函数,使用时换为tf.matmul就可以了。与上面注释的方式是同样的。
附: 如果是更高维度。例如(a, b, m, n) 与(a, b, n, k)之间做matmul运算。则结果的维度为(a, b, m, k)。
TensorFlow如何实现batch_matmul
我们知道,在tensorflow早期版本中有tf.batch_matmul()函数,可以实现多维tensor和低维tensor的直接相乘,这在使用过程中非常便捷。
但是最新版本的tensorflow现在只有tf.matmul()函数可以使用,不过只能实现同维度的tensor相乘, 下面的几种方法可以实现batch matmul的可能。
例如: tensor A(batch_size,m,n), tensor B(n,k),实现batch matmul 使得A * B。
方法1: 利用tf.matmul()
对tensor B 进行增维和扩展
A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
B_exp = tf.tile(tf.expand_dims(B,0),[batch_size, 1, 1]) #先进行增维再扩展
C = tf.matmul(A, B_exp)
方法2: 利用tf.reshape()
对tensor A 进行reshape操作,然后利用tf.matmul()
A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
A = tf.reshape(A, [-1, 3])
C = tf.reshape(tf.matmul(A, B), [-1, 2, 5])
方法3: 利用tf.scan()
利用tf.scan() 对tensor按第0维进行展开的特性
A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
initializer = tf.Variable(tf.random_normal(shape=(2,5)))
C = tf.scan(lambda a,x: tf.matmul(x, B), A, initializer)
方法4: 利用tf.einsum()
A = tf.Variable(tf.random_normal(shape=(batch_size, 2, 3)))
B = tf.Variable(tf.random_normal(shape=(3, 5)))
C = tf.einsum('ijk,kl->ijl',A,B)
来源:https://blog.csdn.net/yyhhlancelot/article/details/81191923
0
投稿
猜你喜欢
- <?php date_default_timezone_set("PRC"); $host = stripslas
- 今天继续给大家介绍Python相关知识,本文主要内容是Python asyncio异步编程简单实现。一、asyncio事件循环简介async
- 写这个的目地,主要是系统理下目前产品设计的流程,提醒自己尽量去避免一些常见的问题,也能让大家系统的了解天极网的产品设计流程。当然,不保证任何
- 在向大家详细介绍Linux mysql之前,首先让大家了解下Linux mysql,然后全面介绍Linux mysql,希望对大家有用。1.
- 在命令行输入以下代码:pythonimport cv2cv2.__version__来源:https://blog.csdn.net/dlh
- 1、环境PyCharmPython 3.6pip安装的依赖包包括:requests 2.25.0、urllib3 1.26.2、docx 0
- 最近,由于工作需要统计一下文本文档中的各种不同类字符的数量。将txt文本文档中包含的的中文、英文、数字等字符数量进行统计。这当然可以使用py
- 一、引言有一定 Python 编程经验的人估计十有八九使用过异常,异常对于程序的健壮性是毋庸置疑的。二、使用异常对数据进行初始化在某些条件下
- chatGPT已经爆火一段时间了,我想大多数的开发者都在默默的在开发和测试当中,可能也是因为这个原因所以现在很难找到关于开发中遇到的一些坑或
- 1. 安装vim:# apt-get install -y vim-gnome2. 安装ctags,ctags用于支持tagli
- HP注释规范注释在写代码的过程中非常重要,好的注释能让你的代码读起来更轻松,在写代码的时候一定要注意注释的规范。“php是一门及其容易入门的
- QThread是Qt的线程类中最核心的底层类。由于PyQt的的跨平台特性,QThread要隐藏所有与平台相关的代码要使用的QThread开始
- 1.客户端的主页面:<%@LANGUAGE="VBSCRIPT" CODEPAGE="936"
- 代码如下:<% '/* 函数名称:Zxj_ReplaceHtmlClearHtml '/
- 系统环境:64位win7企业版python2.7.102016.08.16修改内容:1)read_until()函数是可以设置timeout
- 上文:栅格:从混乱到秩序Jacci Howard Bear 的英文原文:http://desktoppub.about.com/od/gri
- reflect反射首先,我们要区分两个概念——“标识名”和&
- 前言Multiprocessing.Pool可以提供指定数量的进程供用户调用,当有新的请求提交到pool中时,如果池还没有满,那么就会创建一
- String Types(字符串类型)字符串类型Mysql支持多种字符串类型的变体。 这些数据类型在4.1和5.0版本中有较大的变化, 这使
- 文章主要描述的是SQL Server聚集索引的指示(Cluster Index Indications),在实际操作中借助聚集索引来进行搜索