使用keras时input_shape的维度表示问题说明
作者:vita2dolce 发布时间:2023-05-10 02:52:25
Keras提供了两套后端,Theano和Tensorflow,不同的后端使用时维度顺序dim_ordering会有冲突。
对于一张224*224的彩色图片表示问题,theano使用的是th格式,维度顺序是(3,224,224),即通道维度在前,Caffe采取的也是这种方式。而Tensorflow使用的是tf格式,维度顺序是(224,224,3),即通道维度在后。
Keras默认使用的是Tensorflow。我们在导入模块的时候可以进行查看,也可以切换后端。
为了代码可以在两种后端兼容,可以通过data_format参数进行维度顺序的设定,data_format='channels_first',对应“th”,data_format='channels_last',对应“tf”。
补充知识:Tensorflow Keras 中input_shape引发的维度顺序冲突问题(NCHW与NHWC)
以tf.keras.Sequential构建卷积层为例:
tf.keras.layers.Conv2D(10, 3, input_shape=(2, 9, 9),padding='same',activation=tf.nn.relu,kernel_initializer='glorot_normal', bias_initializer='glorot_normal'),
这是一个简单的卷积层的定义,主要看input_shape参数:
这是用来指定卷积层输入形状的参数,由于Keras提供了两套后端,Theano和Tensorflow,不同的后端使用时对该参数所指代的维度顺序dim_ordering会有冲突。
Theano(th):
NCHW:顺序是 [batch, in_channels, in_height, in_width]
Tensorflow(tf):keras默认使用这种方式
NHWC:顺序是 [batch, in_height, in_width, in_channels]
即对于上述input_shape=(2, 9, 9)来说:我们先忽略batch,2会被解析为通道数,矩阵大小为9*9,符合我们预期。而tf会将矩阵大小解析为2 * 9 ,且最后一位9代表通道数,与预期不符。
解决
法一:
在卷积层定义中加入参数来让keras在两种后端之间切换:
data_format='channels_first':代表th
data_format='channels_last':代表tf
但是该法在某些时候不成功会报错:
或许是cpu电脑导致的,只支持NHWC即tf模式。
只能修改相应文件的配置来使其支持NCHW,参考这里
法二:(推荐)
使用tf.transpose函数进行高维数据的转置(维度大于2,轴的转换)
如将上述(2,9,9)转为(9,9,2)并且是以2为通道数,即矩阵为9*9,而不是像reshape函数简单的调整维度,若使用reshape函数来转换,只会得到通道数为9,矩阵为9 * 2的数据。
tf.transpose(待转矩阵,(1,2,0))
解释:
其中0,1,2…是原矩阵维度从左到右轴的标号,即(2,9,9)中三个维度分别对应标号0,1,2。而调整过后将标号顺序变为1,2,0 即是把表通道数的轴置于最后,这样转置后的矩阵就满足了keras的默认tf后端。即可正常训练。
来源:https://blog.csdn.net/hgfgfdfdff/article/details/88555423


猜你喜欢
- 关于ref和$refs的用法及讲解,vue.js中文社区( https://cn.vuejs.org/v2/api/#ref )是这么讲解的
- 利用python进行求解,求解的要求是不能使用python内部封装好的函数例如:maxway1:def findmax(data,n): i
- 除了在Matlab中使用PRTools工具箱中的svm算法,Python中一样可以使用支持向量机做分类。因为Python中的sklearn库
- sql2000的服务器版本是8.0,sql2005是9.0首先要读安装必须配置(见后记)1.我是先装2000的,安装好后打上sp4补丁,(s
- 一.图像金字塔图像金字塔是指由一组图像且不同分别率的子图集合,它是图像多尺度表达的一种,以多分辨率来解释图像的结构,主要用于图像的分割或压缩
- 很久之前就对jQuery.animate的实现非常感兴趣,不过前段时间很忙,直到前几天端午假期才有时间去研究。jQuery.animate的
- 一旦你已经为MySQL实例管理器设置了一个密码文件并且IM正在运行,你可以连接它。你可以使用mysql客户端工具通过标准MySQL API来
- 在上一篇文章中,我介绍了MySQL对XML支持的部分功能,包括--xml命令行选项,以及MySQL 5.1.5中开始引入的新功能。今天我将介
- PHP在运行时, 针对严重程度不同的错误,会给以不同的提示。 eg:在$a没声明时,直接相加,值为NULL,相加时当成0来算.但是,却提示N
- 一、首先要确保你的电脑上opencv的环境和visual studio上的环境都配置好了,测试的时候通过了没有问题。二、那么只要在你项目里面
- 一、 只复制一个表结构,不复制数据 select top 0&
- 1.背景 sysbench是一款压力测试工具,可以测试系统的硬件性能,也可以用来对数据库进行基准测试。sysbench 支持的测试
- 【尊重原创,转载请注明出处】https://blog.csdn.net/guyuealian/article/details/7967225
- 前言由于一直用Linux系统,对于词典的支持特别不好,对于我这英语渣渣的人来说,当看英文文档就一直卡壳,之前用惯了有道词典,感觉很不错,虽然
- 具体方法:首先打开命令提示符;然后执行【mysql -u root -p】命令进入mysql;最后执行如下命令即可:select SUBST
- 在用ThinkPHP做tags标签的时候,出现了一个问题,就是能获取到参数,但是查不出相应的结果。查看数据库发现数据是存在的。问题出在哪了呢
- 本文总结了一些简单基本的输出格式化形式,下面话不多说了,来看看详细的介绍吧。一、打印字符串>>> print "
- 所以呢,在引用js文档的时候,要设置被引用的文档是什么编码的。 如:一个utf-8的页面引用一个gb2312的js文档,那么就要这么写 &l
- 不论是打开网页或者爬取一些资料的时候,我们想要的是计算机能在最短的时间内运行出结果,不然等待的时间过长会影响下一步工作的计划。这时候我们可以
- 字符串 -- 不可改变的序列如同大多数高级编程语言一样,变长字符串是 Python 中的基本类型。Python 在“后台”分配内存以保存字符