使用keras实现孪生网络中的权值共享教程
作者:Master He 发布时间:2022-10-21 13:52:34
首先声明,这里的权值共享指的不是CNN原理中的共享权值,而是如何在构建类似于Siamese Network这样的多分支网络,且分支结构相同时,如何使用keras使分支的权重共享。
Functional API
为达到上述的目的,建议使用keras中的Functional API,当然Sequential 类型的模型也可以使用,本篇博客将主要以Functional API为例讲述。
keras的多分支权值共享功能实现,官方文档介绍
上面是官方的链接,本篇博客也是基于上述官方文档,实现的此功能。(插一句,keras虽然有中文文档,但中文文档已停更,且中文文档某些函数介绍不全,建议直接看英文官方文档)
不共享参数的模型
以MatchNet网络结构为例子,为方便显示,将卷积模块个数减为2个。首先是展示不共享参数的模型,以便观看完整的网络结构。
整体的网络结构如下所示:
代码包含两部分,第一部分定义了两个函数,FeatureNetwork()生成特征提取网络,ClassiFilerNet()生成决策网络或称度量网络。网络结构的可视化在博客末尾。在ClassiFilerNet()函数中,可以看到调用了两次FeatureNetwork()函数,keras.models.Model也被使用的两次,因此生成的input1和input2是两个完全独立的模型分支,参数是不共享的。
from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model
# ---------------------函数功能区-------------------------
def FeatureNetwork():
"""生成特征提取网络"""
"""这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
models = Activation('relu')(models)
models = MaxPool2D(pool_size=(3, 3))(models)
models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
# models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
models = Activation('relu')(models)
models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
models = Activation('relu')(models)
models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
models = Activation('relu')(models)
models = Flatten()(models)
models = Dense(512)(models)
models = Activation('relu')(models)
model = Model(inputs=inp, outputs=models)
return model
def ClassiFilerNet(): # add classifier Net
"""生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""
input1 = FeatureNetwork() # 孪生网络中的一个特征提取
input2 = FeatureNetwork() # 孪生网络中的另一个特征提取
for layer in input2.layers: # 这个for循环一定要加,否则网络重名会出错。
layer.name = layer.name + str("_2")
inp1 = input1.input
inp2 = input2.input
merge_layers = concatenate([input1.output, input2.output]) # 进行融合,使用的是默认的sum,即简单的相加
fc1 = Dense(1024, activation='relu')(merge_layers)
fc2 = Dense(1024, activation='relu')(fc1)
fc3 = Dense(2, activation='softmax')(fc2)
class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
return class_models
# ---------------------主调区-------------------------
matchnet = ClassiFilerNet()
matchnet.summary() # 打印网络结构
plot_model(matchnet, to_file='G:/csdn攻略/picture/model.png') # 网络结构输出成png图片
共享参数的模型
FeatureNetwork()的功能和上面的功能相同,为方便选择,在ClassiFilerNet()函数中加入了判断是否使用共享参数模型功能,令reuse=True,便使用的是共享参数的模型。
关键地方就在,只使用的一次Model,也就是说只创建了一次模型,虽然输入了两个输入,但其实使用的是同一个模型,因此权重共享的。
from keras.models import Sequential
from keras.layers import merge, Conv2D, MaxPool2D, Activation, Dense, concatenate, Flatten
from keras.layers import Input
from keras.models import Model
from keras.utils import np_utils
import tensorflow as tf
import keras
from keras.datasets import mnist
import numpy as np
from keras.utils import np_utils
from keras.callbacks import EarlyStopping, ModelCheckpoint, TensorBoard, ReduceLROnPlateau
from keras.utils.vis_utils import plot_model
# ----------------函数功能区-----------------------
def FeatureNetwork():
"""生成特征提取网络"""
"""这是根据,MNIST数据调整的网络结构,下面注释掉的部分是,原始的Matchnet网络中feature network结构"""
inp = Input(shape = (28, 28, 1), name='FeatureNet_ImageInput')
models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
models = Activation('relu')(models)
models = MaxPool2D(pool_size=(3, 3))(models)
models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
# models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
models = Activation('relu')(models)
models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
models = Activation('relu')(models)
models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
models = Activation('relu')(models)
# models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
# models = Activation('relu')(models)
# models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
models = Flatten()(models)
models = Dense(512)(models)
models = Activation('relu')(models)
model = Model(inputs=inp, outputs=models)
return model
def ClassiFilerNet(reuse=False): # add classifier Net
"""生成度量网络和决策网络,其实maychnet是两个网络结构,一个是特征提取层(孪生),一个度量层+匹配层(统称为决策层)"""
if reuse:
inp = Input(shape=(28, 28, 1), name='FeatureNet_ImageInput')
models = Conv2D(filters=24, kernel_size=(3, 3), strides=1, padding='same')(inp)
models = Activation('relu')(models)
models = MaxPool2D(pool_size=(3, 3))(models)
models = Conv2D(filters=64, kernel_size=(3, 3), strides=1, padding='same')(models)
# models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
models = Activation('relu')(models)
models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
models = Activation('relu')(models)
models = Conv2D(filters=96, kernel_size=(3, 3), strides=1, padding='valid')(models)
models = Activation('relu')(models)
# models = Conv2D(64, kernel_size=(3, 3), strides=2, padding='valid')(models)
# models = Activation('relu')(models)
# models = MaxPool2D(pool_size=(3, 3), strides=(2, 2))(models)
models = Flatten()(models)
models = Dense(512)(models)
models = Activation('relu')(models)
model = Model(inputs=inp, outputs=models)
inp1 = Input(shape=(28, 28, 1)) # 创建输入
inp2 = Input(shape=(28, 28, 1)) # 创建输入2
model_1 = model(inp1) # 孪生网络中的一个特征提取分支
model_2 = model(inp2) # 孪生网络中的另一个特征提取分支
merge_layers = concatenate([model_1, model_2]) # 进行融合,使用的是默认的sum,即简单的相加
else:
input1 = FeatureNetwork() # 孪生网络中的一个特征提取
input2 = FeatureNetwork() # 孪生网络中的另一个特征提取
for layer in input2.layers: # 这个for循环一定要加,否则网络重名会出错。
layer.name = layer.name + str("_2")
inp1 = input1.input
inp2 = input2.input
merge_layers = concatenate([input1.output, input2.output]) # 进行融合,使用的是默认的sum,即简单的相加
fc1 = Dense(1024, activation='relu')(merge_layers)
fc2 = Dense(1024, activation='relu')(fc1)
fc3 = Dense(2, activation='softmax')(fc2)
class_models = Model(inputs=[inp1, inp2], outputs=[fc3])
return class_models
如何看是否真的是权值共享呢?直接对比特征提取部分的网络参数个数!
不共享参数模型的参数数量:
共享参数模型的参数总量
共享参数模型中的特征提取部分的参数量为:
由于截图限制,不共享参数模型的特征提取网络参数数量不再展示。其实经过计算,特征提取网络部分的参数数量,不共享参数模型是共享参数的两倍。两个网络总参数量的差值就是,共享模型中,特征提取部分的参数的量
网络结构可视化
不共享权重的网络结构
共享参数的网络结构,其中model_1代表的就是特征提取部分。
来源:https://blog.csdn.net/qq_35826213/article/details/86313469
猜你喜欢
- 目录前言创建对象方式一:方式二:更新对象方式一:方式二:方式三:查询检索全部对象:条件过滤:方式一:方式二:检索单个对象:总结前言上篇已经介
- 基于Python中求和函数sum的用法详解今天在看《集体编程智慧》这本书的时候,看到一段Python代码,当时是百思不得其解,总觉得是书中排
- SQL Server2000中,如果数据库文件(非系统数据库文件)遇到错误的时候,我们该怎么办。以下是笔者以前的笔记。仅适用于非master
- 导语哈喽!我是木木子,又到了今日更新时刻!我们来看看写什么呢?小编有个好兄弟最近在追妹子,跟妹子打得火热!就差临门一脚了,这一jio我帮忙补
- return 语句就是讲结果返回到调用的地方,并把程序的控制权一起返回程序运行到所遇到的第一个return即返回(退出def块),不会再运行
- 用 ASP (using jscript) 在服务端创建 GUID 的代码如下:function GUID(){ ret
- 你有没有觉得你的CSS样式表文件过于臃肿?其实如果你注意并培养一些比较好的CSS书写习惯,我想你的CSS样式表过于”肥胖”的问题会得到很好的
- 本文介绍了python的构建工具setup.py,分享个大家,具体如下:一、构建工具setup.py的应用场景在安装python的相关模块和
- 这里分析了php的简单防盗链实现方法。分享飞大家供大家参考。具体如下:index.php页面如下:<html><head&
- 本文实例讲述了微信小程序module.exports模块化操作。分享给大家供大家参考,具体如下:文件 目录如上图:看到网上写的模块化都比较复
- 当Python中用到双重for循环设计的时候我一般会使用循环的嵌套,但是在Python中其实还存在另一种技巧——for复合语句。简单写一个小
- python是免费的么?python是免费的,也就是开源的。编程软件的盈利方式就是你使用它, 用的人越多越值钱。注:Python 是一个高层
- SQL Server vNext CTP 1.2安装教程:此安装过程参考微软官方的安装文档:https://docs.micro
- 在开发C/S结构的大型数据库应用软件时,一般情况下,软件开发人员和数据库设计人员并不是同一个人,这就需要协商好一些即可由程序设
- 前提是已设置ANDROID_HOME环境变量,使用aapt工具获取apk的信息,保存至脚本所在目录下的PackageInfo.txt文件中:
- 外部直接执行python文件时,我们有时需要获得命令行的参数获得命令行参数的两种方式1、通过sys.argvsys.argv:获得一个参数列
- 今天在群里,熊猫君提议整理一个帖子,一方面为初学者提供一个入门指南,另一方面也象借此和已经在从事这个行业进行一点交流。下面是我从事这个行当多
- 1、update delete insert 这种语句都需要commit或者直接在连接数据库的时候加上autocommit=Trueimpo
- 一、效果图二、必要工具Python3.7pycharm2019再然后配置它的文件,设置游戏屏幕的大小,图片路径。代码如下''
- 引入:通常,钓鱼网站本质是本质搭建一个跟正常网站一模一样的页面,用户在该页面上完成转账功能转账的请求确实是朝着正常网站的服务端提交,唯一不同