Keras实现将两个模型连接到一起
作者:木盏 发布时间:2021-07-10 07:24:08
标签:Keras,模型,连接
神经网络玩得越久就越会尝试一些网络结构上的大改动。
先说意图
有两个模型:模型A和模型B。模型A的输出可以连接B的输入。将两个小模型连接成一个大模型,A-B,既可以同时训练又可以分离训练。
流行的算法里经常有这么关系的两个模型,对GAN来说,生成器和判别器就是这样子;对VAE来说,编码器和解码器就是这样子;对目标检测网络来说,backbone和整体也是可以拆分的。所以,应用范围还是挺广的。
实现方法
首先说明,我的实现方法不一定是最佳方法。也是实在没有借鉴到比较好的方法,所以才自己手动写了一个。
第一步,我们有现成的两个模型A和B;我们想把A的输出连到B的输入,组成一个整体C。
第二步, 重构新模型C;我的方法是:读出A和B各有哪些layer,然后一层一层重新搭成C。
可以看一个自编码器的代码(本人所编写):
class AE:
def __init__(self, dim, img_dim, batch_size):
self.dim = dim
self.img_dim = img_dim
self.batch_size = batch_size
self.encoder = self.encoder_construct()
self.decoder = self.decoder_construct()
def encoder_construct(self):
x_in = Input(shape=(self.img_dim, self.img_dim, 3))
x = x_in
x = Conv2D(self.dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = Conv2D(self.dim, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(x)
x = BatchNormalization()(x)
x = LeakyReLU(0.2)(x)
x = GlobalAveragePooling2D()(x)
encoder = Model(x_in, x)
return encoder
def decoder_construct(self):
map_size = K.int_shape(self.encoder.layers[-2].output)[1:-1]
# print(type(map_size))
z_in = Input(shape=K.int_shape(self.encoder.output)[1:])
z = z_in
z_dim = self.dim
z = Dense(np.prod(map_size) * z_dim)(z)
z = Reshape(map_size + (z_dim,))(z)
z = Conv2DTranspose(z_dim // 2, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(z_dim // 4, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(z_dim // 8, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(z_dim // 16, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = BatchNormalization()(z)
z = Activation('relu')(z)
z = Conv2DTranspose(3, kernel_size=(5, 5), strides=(2, 2), padding='SAME')(z)
z = Activation('tanh')(z)
decoder = Model(z_in, z)
return decoder
def build_ae(self):
input_x = Input(shape=(self.img_dim, self.img_dim, 3))
x = input_x
for i in range(1, len(self.encoder.layers)):
x = self.encoder.layers[i](x)
for j in range(1, len(self.decoder.layers)):
x = self.decoder.layers[j](x)
y = x
auto_encoder = Model(input_x, y)
return auto_encoder
模型A就是这里的encoder,模型B就是这里的decoder。所以,连接的精髓在build_ae()函数,直接用for循环读出各层,然后一层一层重新构造新的模型,从而实现连接效果。因为keras也是基于图的框架,这个操作并不会很费时,因为没有实际地计算。
补充知识:keras得到每层的系数
使用keras搭建好一个模型,训练好,怎么得到每层的系数呢:
weights = np.array(model.get_weights())
print(weights)
print(weights[0].shape)
print(weights[1].shape)
这样系数就被存放到一个np中了。
来源:https://blog.csdn.net/leviopku/article/details/83510927


猜你喜欢
- 本文解决问题:批量删除多行txt文本中的内容。思路:1.找出需要删除行的 id(就是需要删除那些行,把这是第几行给记录下来。)2.将原文本内
- 一、前言我打开4399小游戏网,点开了一个不知名的游戏,唔,做寿司的,有材料在一边,客人过来后说出他们的要求,你按照菜单做好端给他便好~要怎
- 有一个比较有意思的传参方式:比如在 demo1.py 中指定 action='store_true'的时候:parser.a
- 最近帮人做了个贪吃蛇的游戏(交作业用),很简单,界面如下:开始界面:游戏中界面:是不是很简单、朴素。(欢迎大家访问GitHub)游戏是基于P
- 前言日常生活中,手残党们经常会把一些照片拍歪,比如拍个证件、试卷、PPT什么的,比如下面这本书的封面原本是个矩形,随手一拍就成了不规则四边形
- 前言本文主要给大家介绍了关于golang分页算法的相关内容,分享出来供大家参考学习,下面话不多说了,来一起看看详细的介绍吧示例代码如下://
- 本文实例讲述了Go语言中的匿名结构体用法。分享给大家供大家参考。具体实现方法如下:package main  
- 一、序言在分布式并发系统中,数据库与缓存数据一致性是一项富有挑战性的技术难点。本文将讨论数据库与缓存数据一致性问题,并提供通用的解决方案。假
- 本文实例总结了PHP session会话操作技巧。分享给大家供大家参考,具体如下:会话技术session将会话数据存储与服务器端,同时使会话
- Update Tb_Garden1 G Set Steward = (Select Id From Zyq.Tb_User U Where
- VSCode配置python调试环境很久之前的一个东东,翻出来看看VSCode配置python调试环境 * 1.下载p
- 1.作用域在python中,作用域分为两种:全局作用域和局部作用域。全局作用域是定义在文件级别的变量,函数名。而局部作用域,则是定义函数内部
- sorted 用于对集合进行排序(这里集合是对可迭代对象的一个统称,他们可以是列表、字典、set、甚至是字符串),它的功能非常强大1、对列表
- 本文实例讲述了python实现中文分词FMM算法。分享给大家供大家参考。具体分析如下:FMM算法的最简单思想是使用贪心算法向前找n个,如果这
- 0x00 识别涉及技术验证码识别涉及很多方面的内容。入手难度大,但是入手后,可拓展性又非常广泛,可玩性极强,成就感也很足。验证码图像处理验证
- var long2="1988-0w-07";alert(long2.substring(0,4)+"----
- 1.介绍Go语言中的测试依赖go test命令。编写测试代码和编写普通的Go代码过程是类似的,并不需要学习新的语法、规则或工具; go te
- 当用cmd命令行运行python文件时,我们知道可以通过>python pyfile.py来运行python文件,此时的输出会直接打印
- range()反向遍历的几种表达for i in range(10,0,-2):#有10 print(i)prin
- 一、先描述一下问题吧如下创建表时候报错了CREATE TABLE `xxx` ( `id` bigint(20) NOT NUL