在keras下实现多个模型的融合方式
作者:小风风12580 发布时间:2023-06-03 17:14:59
标签:keras,模型,融合
在网上搜过发现关于keras下的模型融合框架其实很简单,奈何网上说了一大堆,这个东西官方文档上就有,自己写了个demo:
# Function:基于keras框架下实现,多个独立任务分类
# Writer: PQF
# Time: 2019/9/29
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
import tensorflow as tf
# 生成训练集
dataset_size = 128*3
rdm = np.random.RandomState(1)
X = rdm.rand(dataset_size,2)
Y1 = [[int(x1+x2<1)] for (x1,x2) in X]
Y2 = [[int(x1+x2*x2<0.5)] for (x1,x2) in X]
X_train = X[:-2]
Y_train1 = Y1[:-2]
Y_train2 = Y2[:-2]
X_test = X[-2:dataset_size]
Y_test1 = Y1[-2:dataset_size]
Y_test2 = Y2[-2:dataset_size]
#网络一
input = Input(shape=(2,))
x = Dense(units=16,activation='relu')(input)
output = Dense(units=1,activation='sigmoid',name='output1')(x)
#网络二
input2 = Input(shape=(2,))
x2 = Dense(units=16,activation='relu')(input2)
output2 = Dense(units=1,activation='sigmoid',name='output2')(x2)
#模型合并
model = Model(inputs=[input,input2],outputs=[output,output2])
model.summary()
model.compile(optimizer='rmsprop',loss='binary_crossentropy',loss_weights=[1.0,1.0])
model.fit([X_train,X_train],[Y_train1,Y_train2],batch_size=48,epochs=200)
print('x_test is :\n')
print(X_test)
print('y_test1 is :\n')
print(Y_test1)
print('y_test2 is :\n')
print(Y_test2)
predict = model.predict([X_test,X_test])
print('prediction is : \n')
print(predict[0])
print(predict[1])
补充知识:keras的融合层使用理解
最近开始研究U-net网络,其中接触到了融合层的概念,做个笔记。
上图为U-net网络,其中上采样层(绿色箭头)需要与下采样层池化层(红色箭头)层进行融合,要求每层的图片大小一致,维度依照融合的方式可以不同,融合之后输出的图片相较于没有融合层的网络,边缘处要清晰很多!
这时候就要用到keras的融合层概念(Keras中文文档https://keras.io/zh/)
文档中分别讲述了加减乘除的四中融合方式,这种方式要求两层之间shape必须一致。
重点讲述一下Concatenate(拼接)方式
拼接方式默认依照最后一维也就是通道来进行拼接
如同上图(128*128*64)与(128*128*128)进行Concatenate之后的shape为128*128*192
ps:
中文文档为老版本,最新版本的keras.layers.merge方法进行了整合
上图为新版本整合之后的方法,具体使用方法一看就懂,不再赘述。
来源:https://blog.csdn.net/weixin_43392276/article/details/101757173


猜你喜欢
- 看youa的源码发现的,原来flash可以有fallback content:<object type="applicati
- 目录1. 柱状图概述1.1什么是柱状图1.2柱状图使用场景1.3柱状图绘制步骤1.3案例展示2. 柱状图属性2.1柱状体颜色填充2.2状描边
- 一、地理编码与逆编码地理编码与逆编码表示的是地名地址与地理坐标(经纬度)互相转换的过程。其中,将地址信息映射为地理坐标的过程称之为地理编码;
- 最近研究微信API,发现个非常好用的python库:wxpy。wxpy基于itchat,使用了 Web 微信的通讯协议,实现了微信登录、收发
- 我就废话不多说了,大家还是直接看代码吧~one = tf.ones_like(label)zero = tf.zeros_like(labe
- 一、选取网址进行爬虫本次我们选取pixabay图片网站url=https://pixabay.com/二、选择图片右键选择查看元素来寻找图片
- <?php # 设置 $domain 为你的域名 (注意没有www) $domain = "aspxhome.com&quo
- lambda表达式python中形如:lambda parameters: expression称为lambda表达式,用于创建匿名函数,该
- 本文是对《Python Qt GUI快速编程》的第10章的例子剪贴板用Python3+PyQt5进行改写,分别对文本,图片和html文本的复
- 实例如下:#! /usr/bin/python# -*- coding: utf-8 -*-import osdef del_dir_tre
- 命令首先数据库迁移的两大命令: python manage.py makemigrations & python manage.py
- 一、安装1.从官网下载Linux版的Pycharm官网链接:https://www.jetbrains.com/pycharm/downlo
- 1 np.arange(),类似于range,通过指定开始值,终值和步长来创建表示等差数列的一维数组,注意该函数和range一样结果不包含终
- 现在大多数Centos6.x版本的系统python都是2.x,现因开发需求需要安装前端代码的构建工具glue,故必须要做python版本的升
- 将图片读入到Dom中,并将其存为xml文件1、需要命名空间using System.Text;using System.IO;using S
- BLOG地址:http://www.planabc.net/article.asp?id=107学习标准的朋友,一般都会在学习的过程中接触到
- 有时候我们用的一些pdf资料是没有目录的,这样找寻我们想到的东西比较麻烦。本篇文章就为大家带来python来生成pdf目录书签的方法。首先,
- 本文实例总结了JavaScript数组去重的方法。分享给大家供大家参考,具体如下:数组去重,一般都是在面试的时候才会碰到,一般是要求手写数组
- 我先给一个初步的表格吧,大家如果有什么意见,或有补充,欢迎提出。有些我没有用过,先不写了。 以下是我使用过的python IDE: 除了Py
- 今天在下脚本的时候遇到一个问题,比如有这样的一个字符串 t = "book123456",想把尾部的数字全部去掉,只留下