在keras下实现多个模型的融合方式
作者:小风风12580 发布时间:2023-06-03 17:14:59
标签:keras,模型,融合
在网上搜过发现关于keras下的模型融合框架其实很简单,奈何网上说了一大堆,这个东西官方文档上就有,自己写了个demo:
# Function:基于keras框架下实现,多个独立任务分类
# Writer: PQF
# Time: 2019/9/29
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
import tensorflow as tf
# 生成训练集
dataset_size = 128*3
rdm = np.random.RandomState(1)
X = rdm.rand(dataset_size,2)
Y1 = [[int(x1+x2<1)] for (x1,x2) in X]
Y2 = [[int(x1+x2*x2<0.5)] for (x1,x2) in X]
X_train = X[:-2]
Y_train1 = Y1[:-2]
Y_train2 = Y2[:-2]
X_test = X[-2:dataset_size]
Y_test1 = Y1[-2:dataset_size]
Y_test2 = Y2[-2:dataset_size]
#网络一
input = Input(shape=(2,))
x = Dense(units=16,activation='relu')(input)
output = Dense(units=1,activation='sigmoid',name='output1')(x)
#网络二
input2 = Input(shape=(2,))
x2 = Dense(units=16,activation='relu')(input2)
output2 = Dense(units=1,activation='sigmoid',name='output2')(x2)
#模型合并
model = Model(inputs=[input,input2],outputs=[output,output2])
model.summary()
model.compile(optimizer='rmsprop',loss='binary_crossentropy',loss_weights=[1.0,1.0])
model.fit([X_train,X_train],[Y_train1,Y_train2],batch_size=48,epochs=200)
print('x_test is :\n')
print(X_test)
print('y_test1 is :\n')
print(Y_test1)
print('y_test2 is :\n')
print(Y_test2)
predict = model.predict([X_test,X_test])
print('prediction is : \n')
print(predict[0])
print(predict[1])
补充知识:keras的融合层使用理解
最近开始研究U-net网络,其中接触到了融合层的概念,做个笔记。
上图为U-net网络,其中上采样层(绿色箭头)需要与下采样层池化层(红色箭头)层进行融合,要求每层的图片大小一致,维度依照融合的方式可以不同,融合之后输出的图片相较于没有融合层的网络,边缘处要清晰很多!
这时候就要用到keras的融合层概念(Keras中文文档https://keras.io/zh/)
文档中分别讲述了加减乘除的四中融合方式,这种方式要求两层之间shape必须一致。
重点讲述一下Concatenate(拼接)方式
拼接方式默认依照最后一维也就是通道来进行拼接
如同上图(128*128*64)与(128*128*128)进行Concatenate之后的shape为128*128*192
ps:
中文文档为老版本,最新版本的keras.layers.merge方法进行了整合
上图为新版本整合之后的方法,具体使用方法一看就懂,不再赘述。
来源:https://blog.csdn.net/weixin_43392276/article/details/101757173
![](https://www.aspxhome.com/images/zang.png)
![](https://www.aspxhome.com/images/jiucuo.png)
猜你喜欢
- python发送icmp echo requesy请求import socketimport structdef checksum(sour
- 目前可实现:MD5算法、SHA256算法、先MD5后SHA256、先SHA256后MD5、两次MD5、两次SHA256、前8位MD5算法后8
- SQL Server有几个版本都在使用中——4.2, 6.0, 6.5, 7.0, 2000,以及2
- 方法1: 用file_get_contents 以get方式获取内容:<?php$url='https://www.aspxh
- 本文实例讲述了PHP实现无限极分类的两种方式。分享给大家供大家参考,具体如下:面试的时候被问到无限极分类的设计和实现,比较常见的做法是在建表
- 网上有这样一道题目:一个字符串String=“adadfdfseffserfefsefseetsdg”,找出里面出现次数最多的字母和出现的次
- 列表是Python中最基本的数据结构,列表是最常用的Python数据类型,列表的数据项不需要具有相同的类型。列表中的每个元素都分配一个数字
- 当一个页面上有一百个表单项,你是怎么获取上面的值勤的?这是一段简单的代码,你试试这段代码,试过后,欢迎留言说一下你的想法?index.asp
- 下面示例代码是防止用网页刷新过快,如果多个页面使用,最好将<%...%>代码存为一个asp文件,在需要的页面最前面include
- 以前写过《 10条影响CSS渲染速度的写法与建议》,今天放些数据出来,供参考;首先说明一点,CSS对网页的最后渲染出来的速度影响非
- 昨天我问过这个问题怎么用ADODB.Stream来读取或写入文件,而不是用fso,不过没人回答到点上,今天搞定了.贴出来给觉得有用的朋友,希
- 【原文地址】 Fixes for Common VS 2008 and .NET 3.5 Beta2 Issu
- 1005:创建表失败1006:创建数据库失败1007:数据库已存在,创建数据库失败1008:数据库不存在,删除数据库失败1009:不能删除数
- 阅读上一篇:W3C优质网页小贴士(一) 使用 alt 属性描述每幅图像alt 属性有什么用?alt 属性可以在一系列标签中使用(如
- 当你使用UPDATE, INSERT, DELETE语句更新数据的时候,你就改变了两个地方的数据:log buffer和data buffe
- 首先,了解下原理。1,提供文本框进行查询内容的输入2,将查询信息提交页面程序处理3,程序页主要作用:接受查询信息,根据此信息调用特定的SQL
- 当你在浏览网页时,看到一个很漂亮的特效,你查看源代码时看到的是一队乱码,那多扫兴呀!根据本人的研究,总结出了三种解密方法,与大家分享!!方法
- 前言损失函数在机器学习中用于表示预测值与真实值之间的差距。一般而言,大多数机器学习模型都会通过一定的优化器来减小损失函数从而达到优化预测机器
- 在客户端,Get方式在通过URL提交数据,数据在URL中可以看到;POST方式,数据放置在HTML HEADER内提交。GET方式提交的数据
- 一、意义:当我们使用一个数据库时,总希望数据库的内容是可靠的、正确的,但由于计算机系统的故障(硬件故障、网络故障、进程故障和系统故障)影响数