网络编程
位置:首页>> 网络编程>> Python编程>> Tensorflow进行多维矩阵的拆分与拼接实例

Tensorflow进行多维矩阵的拆分与拼接实例

作者:张叫张大卫  发布时间:2021-11-29 22:20:58 

标签:Tensorflow,多维,矩阵

最近在使用tensorflow进行网络训练的时候,需要提取出别人训练好的卷积核的部分层的数据。由于tensorflow中的tensor和python中的list不同,无法直接使用加法进行拼接,后来发现一个函数可以完成tensor的拼接。

函数形式如下:


tf.concat(concat_dim,values,name='concat')

其中,第一个参数表示需要拼接的多维tensor,并且可以将多个tensor同事拼接,第二个表示按照哪一个维度拼接(从数字0开始)。

例子:创建一个三维的tensor,然后分别取出最后一个维度(注意:tensor支持与python中list相似的切片操作,可以使用这种方式进行拆分),然后在拼接在一起。


import tensorflow as tf

weights=tf.Variable(tf.truncated_normal([2,3,4],dtype=tf.float32,stddev=1e-1),name='weights')

weight1=weights[0:2,0:3,1:2]
weight2=weights[0:2,0:3,2:3]
weight3=weights[0:2,0:3,1:2]
weight4=tf.concat([weight1,weight2,weight3],2) #2表示最后一个维度

with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(weights))
print("****************")
print(sess.run(weight4))

Tensorflow进行多维矩阵的拆分与拼接实例

来源:https://blog.csdn.net/weixin_40100431/article/details/82858085

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com