python神经网络Densenet模型复现详解
作者:Bubbliiiing 发布时间:2022-02-13 06:43:13
什么是Densenet
据说Densenet比Resnet还要厉害,我决定好好学一下。
ResNet模型的出现使得深度学习神经网络可以变得更深,进而实现了更高的准确度。
ResNet模型的核心是通过建立前面层与后面层之间的短路连接(shortcuts),这有助于训练过程中梯度的反向传播,从而能训练出更深的CNN网络。
DenseNet模型,它的基本思路与ResNet一致,也是建立前面层与后面层的短路连接,不同的是,但是它建立的是前面所有层与后面层的密集连接。
DenseNet还有一个特点是实现了特征重用。
这些特点让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能。
DenseNet示意图如下:
代码下载
Densenet
1、Densenet的整体结构
如图所示Densenet由DenseBlock和中间的间隔模块Transition Layer组成。
1、DenseBlock:DenseBlock指的就是DenseNet特有的模块,如下图所示,前面所有层与后面层的具有密集连接,在同一个DenseBlock当中,特征层的高宽不会发生改变,但是通道数会发生改变。
2、Transition Layer:Transition Layer是将不同DenseBlock之间进行连接的模块,主要功能是整合上一个DenseBlock获得的特征,并且缩小上一个DenseBlock的宽高,在Transition Layer中,一般会使用一个步长为2的AveragePooling2D缩小特征层的宽高。
2、DenseBlock
DenseBlock的实现示意图如图所示:
以前获得的特征会在保留后不断的堆叠起来。
以一个简单例子来表现一下具体的DenseBlock的流程:
假设输入特征层为X0。
1、对x0进行一次1x1卷积调整通道数到4*32后,再利用3x3卷积获得一个32通道的特征层,此时会获得一个shape为(h,w,32)的特征层x1。
2、将获得的x1和初始的x0堆叠,获得一个新的特征层,这个特征层会同时保留初始x0的特征也会保留经过卷积处理后的特征。
3、反复经过步骤1、2的处理,原始的特征会一直得到保留,经过卷积处理后的特征也会得到保留。当网络程度不断加深,就可以实现前面所有层与后面层的具有密集连接。
实现代码为:
def dense_block(x, blocks, name):
for i in range(blocks):
x = conv_block(x, 32, name=name + '_block' + str(i + 1))
return x
def conv_block(x, growth_rate, name):
bn_axis = 3
x1 = layers.BatchNormalization(axis=bn_axis,
epsilon=1.001e-5,
name=name + '_0_bn')(x)
x1 = layers.Activation('relu', name=name + '_0_relu')(x1)
x1 = layers.Conv2D(4 * growth_rate, 1,
use_bias=False,
name=name + '_1_conv')(x1)
x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
name=name + '_1_bn')(x1)
x1 = layers.Activation('relu', name=name + '_1_relu')(x1)
x1 = layers.Conv2D(growth_rate, 3,
padding='same',
use_bias=False,
name=name + '_2_conv')(x1)
x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])
return x
3、Transition Layer
Transition Layer将不同DenseBlock之间进行连接的模块,主要功能是整合上一个DenseBlock获得的特征,并且缩小上一个DenseBlock的宽高,在Transition Layer中,一般会使用一个步长为2的AveragePooling2D缩小特征层的宽高。
实现代码为:
def transition_block(x, reduction, name):
bn_axis = 3
x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
name=name + '_bn')(x)
x = layers.Activation('relu', name=name + '_relu')(x)
x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1,
use_bias=False,
name=name + '_conv')(x)
x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x)
return x
网络实现代码
from keras.preprocessing import image
from keras.models import Model
from keras import layers
from keras.applications import imagenet_utils
from keras.applications.imagenet_utils import decode_predictions
from keras.utils.data_utils import get_file
from keras import backend
import numpy as np
BASE_WEIGTHS_PATH = (
'https://github.com/keras-team/keras-applications/'
'releases/download/densenet/')
DENSENET121_WEIGHT_PATH = (
BASE_WEIGTHS_PATH +
'densenet121_weights_tf_dim_ordering_tf_kernels.h5')
DENSENET169_WEIGHT_PATH = (
BASE_WEIGTHS_PATH +
'densenet169_weights_tf_dim_ordering_tf_kernels.h5')
DENSENET201_WEIGHT_PATH = (
BASE_WEIGTHS_PATH +
'densenet201_weights_tf_dim_ordering_tf_kernels.h5')
def dense_block(x, blocks, name):
for i in range(blocks):
x = conv_block(x, 32, name=name + '_block' + str(i + 1))
return x
def conv_block(x, growth_rate, name):
bn_axis = 3
x1 = layers.BatchNormalization(axis=bn_axis,
epsilon=1.001e-5,
name=name + '_0_bn')(x)
x1 = layers.Activation('relu', name=name + '_0_relu')(x1)
x1 = layers.Conv2D(4 * growth_rate, 1,
use_bias=False,
name=name + '_1_conv')(x1)
x1 = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
name=name + '_1_bn')(x1)
x1 = layers.Activation('relu', name=name + '_1_relu')(x1)
x1 = layers.Conv2D(growth_rate, 3,
padding='same',
use_bias=False,
name=name + '_2_conv')(x1)
x = layers.Concatenate(axis=bn_axis, name=name + '_concat')([x, x1])
return x
def transition_block(x, reduction, name):
bn_axis = 3
x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
name=name + '_bn')(x)
x = layers.Activation('relu', name=name + '_relu')(x)
x = layers.Conv2D(int(backend.int_shape(x)[bn_axis] * reduction), 1,
use_bias=False,
name=name + '_conv')(x)
x = layers.AveragePooling2D(2, strides=2, name=name + '_pool')(x)
return x
def DenseNet(blocks,
input_shape=None,
classes=1000,
**kwargs):
img_input = layers.Input(shape=input_shape)
bn_axis = 3
# 224,224,3 -> 112,112,64
x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)
x = layers.Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)
x = layers.BatchNormalization(
axis=bn_axis, epsilon=1.001e-5, name='conv1/bn')(x)
x = layers.Activation('relu', name='conv1/relu')(x)
# 112,112,64 -> 56,56,64
x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
x = layers.MaxPooling2D(3, strides=2, name='pool1')(x)
# 56,56,64 -> 56,56,64+32*block[0]
# Densenet121 56,56,64 -> 56,56,64+32*6 == 56,56,256
x = dense_block(x, blocks[0], name='conv2')
# 56,56,64+32*block[0] -> 28,28,32+16*block[0]
# Densenet121 56,56,256 -> 28,28,32+16*6 == 28,28,128
x = transition_block(x, 0.5, name='pool2')
# 28,28,32+16*block[0] -> 28,28,32+16*block[0]+32*block[1]
# Densenet121 28,28,128 -> 28,28,128+32*12 == 28,28,512
x = dense_block(x, blocks[1], name='conv3')
# Densenet121 28,28,512 -> 14,14,256
x = transition_block(x, 0.5, name='pool3')
# Densenet121 14,14,256 -> 14,14,256+32*block[2] == 14,14,1024
x = dense_block(x, blocks[2], name='conv4')
# Densenet121 14,14,1024 -> 7,7,512
x = transition_block(x, 0.5, name='pool4')
# Densenet121 7,7,512 -> 7,7,256+32*block[3] == 7,7,1024
x = dense_block(x, blocks[3], name='conv5')
x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name='bn')(x)
x = layers.Activation('relu', name='relu')(x)
x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
x = layers.Dense(classes, activation='softmax', name='fc1000')(x)
inputs = img_input
if blocks == [6, 12, 24, 16]:
model = Model(inputs, x, name='densenet121')
elif blocks == [6, 12, 32, 32]:
model = Model(inputs, x, name='densenet169')
elif blocks == [6, 12, 48, 32]:
model = Model(inputs, x, name='densenet201')
else:
model = Model(inputs, x, name='densenet')
return model
def DenseNet121(input_shape=[224,224,3],
classes=1000,
**kwargs):
return DenseNet([6, 12, 24, 16],
input_shape, classes,
**kwargs)
def DenseNet169(input_shape=[224,224,3],
classes=1000,
**kwargs):
return DenseNet([6, 12, 32, 32],
input_shape, classes,
**kwargs)
def DenseNet201(input_shape=[224,224,3],
classes=1000,
**kwargs):
return DenseNet([6, 12, 48, 32],
input_shape, classes,
**kwargs)
def preprocess_input(x):
x /= 255.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
x[..., 0] -= mean[0]
x[..., 1] -= mean[1]
x[..., 2] -= mean[2]
if std is not None:
x[..., 0] /= std[0]
x[..., 1] /= std[1]
x[..., 2] /= std[2]
return x
if __name__ == '__main__':
# model = DenseNet121()
# weights_path = get_file(
# 'densenet121_weights_tf_dim_ordering_tf_kernels.h5',
# DENSENET121_WEIGHT_PATH,
# cache_subdir='models',
# file_hash='9d60b8095a5708f2dcce2bca79d332c7')
model = DenseNet169()
weights_path = get_file(
'densenet169_weights_tf_dim_ordering_tf_kernels.h5',
DENSENET169_WEIGHT_PATH,
cache_subdir='models',
file_hash='d699b8f76981ab1b30698df4c175e90b')
# model = DenseNet201()
# weights_path = get_file(
# 'densenet201_weights_tf_dim_ordering_tf_kernels.h5',
# DENSENET201_WEIGHT_PATH,
# cache_subdir='models',
# file_hash='1ceb130c1ea1b78c3bf6114dbdfd8807')
model.load_weights(weights_path)
model.summary()
img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
print('Input image shape:', x.shape)
preds = model.predict(x)
print(np.argmax(preds))
print('Predicted:', decode_predictions(preds))
来源:https://blog.csdn.net/weixin_44791964/article/details/105472196
![](https://www.aspxhome.com/images/zang.png)
![](https://www.aspxhome.com/images/jiucuo.png)
猜你喜欢
- 时间久了,注册用户和朋友数据库里的废记录渐渐多了起来,尤其是电子邮件地址,请问有什么好的办法可以快速安全地将它们删除吗?试试下面这个办法,它
- 如题:只需要给定输出特征图的大小就好,其中通道数前后不发生变化。具体如下:AdaptiveAvgPool2d CLASStorch.nn.A
- 前言最近在看测试相关的内容,发现自动化测试很好玩,就决定做一个自动回复QQ消息的脚本(我很菜)1、需要安装的模块这个自动化脚本需要用到3个模
- 最近在用python连接sqlserver读取数据库,读取数据时候在本机电脑正常,但是把程序部署到服务器运行时一直报错“未发现数据源名称并且
- 本文实例讲述了Python使用matplotlib实现的图像读取、切割裁剪功能。分享给大家供大家参考,具体如下:# -*- coding:u
- 前言当今,随着计算机技术的发展,摄像头已经成为了人们生活中不可或缺的一部分。而Python作为一种流行的编程语言,也可以轻松地控制和操作摄像
- 下面是asp代码实现列出sql数据库中存储过程的功能,可自行添加其它功能:< HTML >< 
- 下面是一段产生log-normal分布的代码,以此进行说明。clear all;clc;for t=1:100 Traffic(t) =cu
- 前言如果你以前没有接触过面向对象的编程语言,那你可能需要先了解一些面向对象语言的一些基本特征,在头脑里头形成一个基本的面向对象的概念,这样有
- 1、操作步骤(1)打开文件读取整个文件函数open返回一个表示文件的对象,对象存储在infile中。关键字with在不需要访问文件时将其自动
- Matplotlib官网 如果想了解更多可查看官网。import numpy as np import matplotlib.py
- 前言一个非常神秘的魔术方法。这个方法非常不起眼,用途狭窄,我几乎从未注意过它,然而,当发现它可能是上述“定律”的唯一例外情况时,我认为值得再
- 目录实验环境依赖项安装编程实现浏览器有一个可以用于展示网页的窗口代码总结实验环境操作系统:Linux Mint编辑器:vim编程语言:pyt
- 如果要用某个开源框架,需要安装多个依赖包可以如下操作:如依赖文件形式如下(可以不要版本号):txt文件名为requirements.txt,
- 拿去给自己所思所念之人from turtle import *import timesetup(500, 500, startx=None,
- 一、图示上面为pdf截图内容,下面为转化后的word截图内容接下来,我们试试自己动作写这个工具吧!二、前期准备由于我们采用的是python进
- 准备工作首先是准备工作,导入需要使用的库,读取并创建数据表取名为loandata。import numpy as npimport pand
- 前面介绍过vSQLAlchemy中的 Engine 和 Connection,这两个对象用在row SQL (原生的sql语句)上操作,而
- 本文实例讲述了Python使用add_subplot与subplot画子图操作。分享给大家供大家参考,具体如下:子图:就是在一张figure
- 前言大家应该都有所了解,下面就简单介绍下Numpy,NumPy(Numerical Python)是一个用于科学计算第三方的Python包。