使用keras时input_shape的维度表示问题说明
作者:vita2dolce 发布时间:2023-05-10 02:52:25
Keras提供了两套后端,Theano和Tensorflow,不同的后端使用时维度顺序dim_ordering会有冲突。
对于一张224*224的彩 * 片表示问题,theano使用的是th格式,维度顺序是(3,224,224),即通道维度在前,Caffe采取的也是这种方式。而Tensorflow使用的是tf格式,维度顺序是(224,224,3),即通道维度在后。
Keras默认使用的是Tensorflow。我们在导入模块的时候可以进行查看,也可以切换后端。
为了代码可以在两种后端兼容,可以通过data_format参数进行维度顺序的设定,data_format='channels_first',对应“th”,data_format='channels_last',对应“tf”。
补充知识:Tensorflow Keras 中input_shape引发的维度顺序冲突问题(NCHW与NHWC)
以tf.keras.Sequential构建卷积层为例:
tf.keras.layers.Conv2D(10, 3, input_shape=(2, 9, 9),padding='same',activation=tf.nn.relu,kernel_initializer='glorot_normal', bias_initializer='glorot_normal'),
这是一个简单的卷积层的定义,主要看input_shape参数:
这是用来指定卷积层输入形状的参数,由于Keras提供了两套后端,Theano和Tensorflow,不同的后端使用时对该参数所指代的维度顺序dim_ordering会有冲突。
Theano(th):
NCHW:顺序是 [batch, in_channels, in_height, in_width]
Tensorflow(tf):keras默认使用这种方式
NHWC:顺序是 [batch, in_height, in_width, in_channels]
即对于上述input_shape=(2, 9, 9)来说:我们先忽略batch,2会被解析为通道数,矩阵大小为9*9,符合我们预期。而tf会将矩阵大小解析为2 * 9 ,且最后一位9代表通道数,与预期不符。
解决
法一:
在卷积层定义中加入参数来让keras在两种后端之间切换:
data_format='channels_first':代表th
data_format='channels_last':代表tf
但是该法在某些时候不成功会报错:
或许是cpu电脑导致的,只支持NHWC即tf模式。
只能修改相应文件的配置来使其支持NCHW,参考这里
法二:(推荐)
使用tf.transpose函数进行高维数据的转置(维度大于2,轴的转换)
如将上述(2,9,9)转为(9,9,2)并且是以2为通道数,即矩阵为9*9,而不是像reshape函数简单的调整维度,若使用reshape函数来转换,只会得到通道数为9,矩阵为9 * 2的数据。
tf.transpose(待转矩阵,(1,2,0))
解释:
其中0,1,2…是原矩阵维度从左到右轴的标号,即(2,9,9)中三个维度分别对应标号0,1,2。而调整过后将标号顺序变为1,2,0 即是把表通道数的轴置于最后,这样转置后的矩阵就满足了keras的默认tf后端。即可正常训练。
来源:https://blog.csdn.net/hgfgfdfdff/article/details/88555423
猜你喜欢
- 起步Python 提供的多线程模型中并没有提供读写锁,读写锁相对于单纯的互斥锁,适用性更高,可以多个线程同时占用读模式的读写锁,但是只能一个
- 测试题defer有一些规则,如果不了解,代码实现的最终结果会与预期不一致。对于这些规则,你了解吗?这是关于defer使用的代码,可以先考虑一
- 用XMlhttp生成html页面,相关函数如下:<% ’定义xmlhttp function Get
- 不久前因业务需要,我在自己的笔记本中安装了搜霸。当时一个做平面的朋友过来和我做一些设计交流,我在笔记本前准备输入一个网址,他靠近我的电脑,大
- 在这篇入门教程中,我们假定你已经有了PHP语言程序、MySQL数据库、计算机网络通讯及XML语言基础。如果你还没有,那么请先学习相关知识。我
- 在 Python 2 中 xrange() 创建迭代对象的用法是非常流行的。比如: for 循环或者是列表/集合/字典推导式。这个表现十分像
- FLASH 全屏有二类四种:1、不用浏览器直接用FLASH播放器播放的类型:A、不显示FLASH播放器菜单栏的全屏(类似屏保效果),在第一帧
- 一个客户提供一个股价的信息,要求放在页面上,显示一些数据,需要从远程获取xml,然后解析写在网页上,开始不会觉得很难,其实蛮简单的,先用ja
- 本文实例讲述了Python使用sklearn库实现的各种分类算法简单应用。分享给大家供大家参考,具体如下:KNNfrom sklearn.n
- 废话真的一句也不想多说,直接看代码吧!# -*- coding: utf-8 -*- import numpy from sklearn i
- 大家可能有这样的体验,好比在程序里面我明明写了app.run(port=8001),结果程序还是在5000端口输出,我们右键点击py程序,直
- object.OpenTextFile(filename[, iomode[, create[, format]]]) 参数 object
- 作用:用ASP程序将页面中的电话号码生成图片格式。 代码如下:<% Call Com_CreatValidCode
- 本文通过实例代码介绍了如何在jscript和vbscript中使用操作FileSystemObject(fso)对象模式来编程.
- 前言Golang语言有诸多优点:静态编译、协程、堪比c语言的高性能。但是也有一些令人发指的地方 —— 经常被人调侃 五行代码,三行错误处理
- 所使用python环境为最新的3.6版本一、安装pdfminer模块 安装anaconda后,直接可以通过pip安装pip install
- 阅读上一片:微软建议的ASP性能优化28条守则(1)技巧 3:将数据和 HTML 缓存在 Web 服务器的磁盘上有时,数据可能太多,无法都缓
- 一、把一个字符串的内容提取出来,并放到字典中流程如下: 1、得到字符串s,通过分割提取得到s1(是个列表) s=”name=lyy&
- 每位SQL Server开发员都有自己的首选操作方法。我的方法叫做分子查询。这些是由原子查询组合起来的查询,通过它们我可以处理一个表格。将原
- 利用seek监控文件内容,并打印出变化内容:#/usr/bin/env python#-*- coding=utf-8 -*-pos = 0