python神经网络Xception模型复现详解
作者:Bubbliiiing 发布时间:2021-01-08 21:35:40
Xception是继Inception后提出的对Inception v3的另一种改进,学一学总是好的
什么是Xception模型
Xception是谷歌公司继Inception后,提出的InceptionV3的一种改进模型,其改进的主要内容为采用depthwise separable convolution来替换原来Inception v3中的多尺寸卷积核特征响应操作。
在讲Xception模型之前,首先要讲一下什么是depthwise separable convolution(深度可分离卷积块)。
深度可分离卷积块由两个部分组成,分别是深度可分离卷积和1x1普通卷积,深度可分离卷积的卷积核大小一般是3x3的,便于理解的话我们可以把它当作是特征提取,1x1的普通卷积可以完成通道数的调整。
下图为深度可分离卷积块的结构示意图:
深度可分离卷积块的目的是使用更少的参数来代替普通的3x3卷积。
我们可以进行一下普通卷积和深度可分离卷积块的对比:
假设有一个3×3大小的卷积层,其输入通道为16、输出通道为32。具体为,32个3×3大小的卷积核会遍历16个通道中的每个数据,最后可得到所需的32个输出通道,所需参数为16×32×3×3=4608个。
应用深度可分离卷积,用16个3×3大小的卷积核分别遍历16通道的数据,得到了16个特征图谱。在融合操作之前,接着用32个1×1大小的卷积核遍历这16个特征图谱,所需参数为16×3×3+16×32×1×1=656个。
可以看出来depthwise separable convolution可以减少模型的参数。
通俗地理解深度可分离卷积结构块,就是3x3的卷积核厚度只有一层,然后在输入张量上一层一层地滑动,每一次卷积完生成一个输出通道,当卷积完成后,再利用1x1的卷积调整厚度。
(视频中有些许错误,感谢zl960929的提醒,Xception使用的深度可分离卷积块SeparableConv2D也就是先深度可分离卷积再进行1x1卷积。)
对于Xception模型而言,其一共可以分为3个flow,分别是Entry flow、Middle flow、Exit flow;
分为14个block,其中Entry flow中有4个、Middle flow中有8个、Exit flow中有2个。
具体结构如下:
其内部主要结构就是残差卷积网络搭配SeparableConv2D层实现一个个block,在Xception模型中,常见的两个block的结构如下。这个主要在Entry flow和Exit flow中:
这个主要在Middle flow中:
Xception网络部分实现代码
#-------------------------------------------------------------#
# Xception的网络部分
#-------------------------------------------------------------#
from keras.preprocessing import image
from keras.models import Model
from keras import layers
from keras.layers import Dense,Input,BatchNormalization,Activation,Conv2D,SeparableConv2D,MaxPooling2D
from keras.layers import GlobalAveragePooling2D,GlobalMaxPooling2D
from keras import backend as K
from keras.applications.imagenet_utils import decode_predictions
def Xception(input_shape = [299,299,3],classes=1000):
img_input = Input(shape=input_shape)
#--------------------------#
# Entry flow
#--------------------------#
#--------------------#
# block1
#--------------------#
# 299,299,3 -> 149,149,64
x = Conv2D(32, (3, 3), strides=(2, 2), use_bias=False, name='block1_conv1')(img_input)
x = BatchNormalization(name='block1_conv1_bn')(x)
x = Activation('relu', name='block1_conv1_act')(x)
x = Conv2D(64, (3, 3), use_bias=False, name='block1_conv2')(x)
x = BatchNormalization(name='block1_conv2_bn')(x)
x = Activation('relu', name='block1_conv2_act')(x)
#--------------------#
# block2
#--------------------#
# 149,149,64 -> 75,75,128
residual = Conv2D(128, (1, 1), strides=(2, 2), padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv1')(x)
x = BatchNormalization(name='block2_sepconv1_bn')(x)
x = Activation('relu', name='block2_sepconv2_act')(x)
x = SeparableConv2D(128, (3, 3), padding='same', use_bias=False, name='block2_sepconv2')(x)
x = BatchNormalization(name='block2_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block2_pool')(x)
x = layers.add([x, residual])
#--------------------#
# block3
#--------------------#
# 75,75,128 -> 38,38,256
residual = Conv2D(256, (1, 1), strides=(2, 2),padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block3_sepconv1_act')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv1')(x)
x = BatchNormalization(name='block3_sepconv1_bn')(x)
x = Activation('relu', name='block3_sepconv2_act')(x)
x = SeparableConv2D(256, (3, 3), padding='same', use_bias=False, name='block3_sepconv2')(x)
x = BatchNormalization(name='block3_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block3_pool')(x)
x = layers.add([x, residual])
#--------------------#
# block4
#--------------------#
# 38,38,256 -> 19,19,728
residual = Conv2D(728, (1, 1), strides=(2, 2),padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block4_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv1')(x)
x = BatchNormalization(name='block4_sepconv1_bn')(x)
x = Activation('relu', name='block4_sepconv2_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block4_sepconv2')(x)
x = BatchNormalization(name='block4_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block4_pool')(x)
x = layers.add([x, residual])
#--------------------------#
# Middle flow
#--------------------------#
#--------------------#
# block5--block12
#--------------------#
# 19,19,728 -> 19,19,728
for i in range(8):
residual = x
prefix = 'block' + str(i + 5)
x = Activation('relu', name=prefix + '_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv1')(x)
x = BatchNormalization(name=prefix + '_sepconv1_bn')(x)
x = Activation('relu', name=prefix + '_sepconv2_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv2')(x)
x = BatchNormalization(name=prefix + '_sepconv2_bn')(x)
x = Activation('relu', name=prefix + '_sepconv3_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name=prefix + '_sepconv3')(x)
x = BatchNormalization(name=prefix + '_sepconv3_bn')(x)
x = layers.add([x, residual])
#--------------------------#
# Exit flow
#--------------------------#
#--------------------#
# block13
#--------------------#
# 19,19,728 -> 10,10,1024
residual = Conv2D(1024, (1, 1), strides=(2, 2),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)
x = Activation('relu', name='block13_sepconv1_act')(x)
x = SeparableConv2D(728, (3, 3), padding='same', use_bias=False, name='block13_sepconv1')(x)
x = BatchNormalization(name='block13_sepconv1_bn')(x)
x = Activation('relu', name='block13_sepconv2_act')(x)
x = SeparableConv2D(1024, (3, 3), padding='same', use_bias=False, name='block13_sepconv2')(x)
x = BatchNormalization(name='block13_sepconv2_bn')(x)
x = MaxPooling2D((3, 3), strides=(2, 2), padding='same', name='block13_pool')(x)
x = layers.add([x, residual])
#--------------------#
# block14
#--------------------#
# 10,10,1024 -> 10,10,2048
x = SeparableConv2D(1536, (3, 3), padding='same', use_bias=False, name='block14_sepconv1')(x)
x = BatchNormalization(name='block14_sepconv1_bn')(x)
x = Activation('relu', name='block14_sepconv1_act')(x)
x = SeparableConv2D(2048, (3, 3), padding='same', use_bias=False, name='block14_sepconv2')(x)
x = BatchNormalization(name='block14_sepconv2_bn')(x)
x = Activation('relu', name='block14_sepconv2_act')(x)
x = GlobalAveragePooling2D(name='avg_pool')(x)
x = Dense(classes, activation='softmax', name='predictions')(x)
inputs = img_input
model = Model(inputs, x, name='xception')
model.load_weights("xception_weights_tf_dim_ordering_tf_kernels.h5")
return model
图片预测
建立网络后,可以用以下的代码进行预测。
def preprocess_input(x):
x /= 255.
x -= 0.5
x *= 2.
return x
if __name__ == '__main__':
model = Xception()
img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(299, 299))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
print('Input image shape:', x.shape)
preds = model.predict(x)
print(np.argmax(preds))
print('Predicted:', decode_predictions(preds))
预测所需的已经训练好的Xception模型可以在https://github.com/fchollet/deep-learning-models/releases下载。非常方便。
预测结果为:
Predicted: [[('n02504458', 'African_elephant', 0.47570863), ('n01871265', 'tusker', 0.3173351), ('n02504013', 'Indian_elephant', 0.030323735), ('n02963159', 'cardigan', 0.0007877756), ('n02410509', 'bison', 0.00075616257)]]
来源:https://blog.csdn.net/weixin_44791964/article/details/102813517
猜你喜欢
- 首先祝大家国庆节日快乐,这个假期因为我老婆要考注会,我也跟着天天去图书馆学了几天,学习的感觉还是非常不错的,这是一篇总结。这篇博客准备讲解一
- Javascript中的变量同样支持自由类型转换成为适用(或者要求)的内容以便于使用。 弱类型的Javascript不会按照程序员的愿望从实
- 1. 首先看要设置登陆的界面 book/view.py@user_util.my_login #相当于 select_all=my_logi
- Python是一门简单而文字简约的语言。阅读好的Python程序感觉就像阅读英语,尽管是非常严格的英语。Python的这种伪代码特性是其最大
- 踩坑记录:用pandas来做csv的缺失值处理时候发现奇怪BUG,就是excel打开csv文件,明明有的格子没有任何东西,当然,我就想到用p
- 先附上官方文档说明:https://pytorch.org/docs/stable/nn.functional.htmltorch.nn.f
- 用新云还不是很熟,一点点学习中。今天遇到一个文章列表前有小圆点的问题,把去除方法记一下。文章列表前有小圆点有这么几种情况:1、li的默认样式
- 模块导入方式: import osos模块是Python标准库中的一个用于访问操作系统相关功能的模块,os模块提供了一种可移植的使
- 代码很简洁,功能很实用,这里就不多废话了,直接奉上:<?php/** * 获取客户端IP * @param&nbs
- clear()方法将删除字典中的所有项目(清空字典)语法以下是clear()方法的语法:dict.clear()参数
- 多线程爬虫:即程序中的某些程序段并行执行,合理地设置多线程,可以让爬虫效率更高糗事百科段子普通爬虫和多线程爬虫分析该网址链接得出:https
- 本文实例讲述了python用来获得图片exif信息的库用法。分享给大家供大家参考。具体分析如下:exif-py是一个纯python实现的获取
- 背景:由于最近公司项目好像有点受不住并发压力了,优化迫在眉睫。由于当前系统是单数据库系统原因,能优化的地方也尽力优化了但是数据库瓶颈还是严重
- 了解如何在sublime编辑器中安装python软件包,以实现自动完成等功能,并在sublime编辑器本身中运行build。安装Sublim
- 问题: 将u'\u810f\u4e71'转换为'\u810f\u4e71'
- 前言在AI领域,来快速实现一个idea:前后端开发+部署+展现,如果走传统的前后端分离开发+服务器docker部署等方式,会很重且入门成本很
- 最近将Jesse James Garrett的《用户体验的要素》一书读了两遍,做一些简要的摘录并添加一些个人注释。当然,一本好书绝对不是简单
- 除了使用 sys.exc_info() 方法获取更多的异常信息之外,还可以使用 traceback 模块,该模块可以用来查看异常的传播轨迹,
- TCP 客户端一个使用TCP协议实现可连续对话的客户端示例代码:import socket# 客户端配置HOST = 'localh
- 如果我们的web应用有大量的异步请求,而这些异步请求是在web服务器认证的情况下,那当我们请求发生在服务器认证失效下,服务器自动302到登录