Keras函数式(functional)API的使用方式
作者:黄然大悟 发布时间:2023-09-19 05:52:15
多层感知器(Multilayer Perceptron)
定义了用于二分类的多层感知器模型。
模型输入32维特征,经过三个全连接层,每层使用relu线性激活函数,并且在输出层中使用sigmoid激活函数,最后用于二分类。
##------ Multilayer Perceptron ------##
from keras.models import Model
from keras.layers import Input, Dense
from keras import backend as K
K.clear_session()
# MLP model
x = Input(shape=(32,))
hidden1 = Dense(10, activation='relu')(x)
hidden2 = Dense(20, activation='relu')(hidden1)
hidden3 = Dense(10, activation='relu')(hidden2)
output = Dense(1, activation='sigmoid')(hidden3)
model = Model(inputs=x, outputs=output)
# summarize layers
model.summary()
模型的结构和参数如下:
卷积神经网络(Convolutional Neural Network)
定义用于图像分类的卷积神经网络。
该模型接收3通道的64×64图像作为输入,然后经过两个卷积和池化层的序列作为特征提取器,接着过一个全连接层,最后输出层过softmax激活函数进行10个类别的分类。
##------ Convolutional Neural Network ------##
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K
K.clear_session()
# CNN model
x = Input(shape=(64,64,3))
conv1 = Conv2D(16, (5,5), activation='relu')(x)
pool1 = MaxPooling2D((2,2))(conv1)
conv2 = Conv2D(32, (3,3), activation='relu')(pool1)
pool2 = MaxPooling2D((2,2))(conv2)
conv3 = Conv2D(32, (3,3), activation='relu')(pool2)
pool3 = MaxPooling2D((2,2))(conv3)
flat = Flatten()(pool3)
hidden1 = Dense(512, activation='relu')(flat)
output = Dense(10, activation='softmax')(hidden1)
model = Model(inputs=x, outputs=output)
# summarize layers
model.summary()
模型的结构和参数如下:
循环神经网络(Recurrent Neural Network)
定义一个用于文本序列分类的LSTM网络。
该模型需要100个时间步长作为输入,然后经过一个Embedding层,每个时间步变成128维特征表示,然后经过一个LSTM层,LSTM输出过一个全连接层,最后输出用sigmoid激活函数用于进行二分类预测。
##------ Recurrent Neural Network ------##
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense, LSTM, Embedding
from keras import backend as K
K.clear_session()
VOCAB_SIZE = 10000
EMBED_DIM = 128
x = Input(shape=(100,), dtype='int32')
embedding = Embedding(VOCAB_SIZE, EMBED_DIM, mask_zero=True)(x)
hidden1 = LSTM(64)(embedding)
hidden2 = Dense(32, activation='relu')(hidden1)
output = Dense(1, activation='sigmoid')(hidden2)
model = Model(inputs=x, outputs=output)
# summarize layers
model.summary()
模型的结构和参数如下:
Bidirectional recurrent neural network
定义一个双向循环神经网络,可以用来完成序列标注等任务,相比上面的LSTM网络,多了一个反向的LSTM,其它设置一样。
##------ Bidirectional recurrent neural network ------##
from keras.models import Model
from keras.layers import Input, Embedding
from keras.layers import Dense, LSTM, Bidirectional
from keras import backend as K
K.clear_session()
VOCAB_SIZE = 10000
EMBED_DIM = 128
HIDDEN_SIZE = 64
# input layer
x = Input(shape=(100,), dtype='int32')
# embedding layer
embedding = Embedding(VOCAB_SIZE, EMBED_DIM, mask_zero=True)(x)
# BiLSTM layer
hidden = Bidirectional(LSTM(HIDDEN_SIZE, return_sequences=True))(embedding)
# prediction layer
output = Dense(10, activation='softmax')(hidden)
model = Model(inputs=x, outputs=output)
model.summary()
模型的结构和参数如下:
共享输入层模型(Shared Input Layer Model)
定义了具有不同大小内核的多个卷积层来解释图像输入。
该模型采用尺寸为64×64像素的3通道图像。
有两个共享此输入的CNN特征提取子模型; 第一个内核大小为5x5,第二个内核大小为3x3。
把提取的特征展平为向量然后拼接成一个长向量,然后过一个全连接层,最后输出层完成10分类。
##------ Shared Input Layer Model ------##
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense, Flatten
from keras.layers import Conv2D, MaxPooling2D, Concatenate
from keras import backend as K
K.clear_session()
# input layer
x = Input(shape=(64,64,3))
# first feature extractor
conv1 = Conv2D(32, (3,3), activation='relu')(x)
pool1 = MaxPooling2D((2,2))(conv1)
flat1 = Flatten()(pool1)
# second feature extractor
conv2 = Conv2D(16, (5,5), activation='relu')(x)
pool2 = MaxPooling2D((2,2))(conv2)
flat2 = Flatten()(pool2)
# merge feature
merge = Concatenate()([flat1, flat2])
# interpretation layer
hidden1 = Dense(128, activation='relu')(merge)
# prediction layer
output = Dense(10, activation='softmax')(merge)
model = Model(inputs=x, outputs=output)
model.summary()
模型的结构和参数如下:
Shared Feature Extraction Layer
定义一个共享特征抽取层的模型,这里共享的是LSTM层的输出,具体共享参见代码
##------ Shared Feature Extraction Layer ------##
from keras.models import Model
from keras.layers import Input, Embedding
from keras.layers import Dense, LSTM, Concatenate
from keras import backend as K
K.clear_session()
# input layer
x = Input(shape=(100,32))
# feature extraction
extract1 = LSTM(64)(x)
# first interpretation model
interp1 = Dense(32, activation='relu')(extract1)
# second interpretation model
interp11 = Dense(64, activation='relu')(extract1)
interp12 = Dense(32, activation='relu')(interp11)
# merge interpretation
merge = Concatenate()([interp1, interp12])
# output layer
output = Dense(10, activation='softmax')(merge)
model = Model(inputs=x, outputs=output)
model.summary()
模型的结构和参数如下:
多输入模型(Multiple Input Model)
定义有两个输入的模型,这里测试的是输入两张图片,一个输入是单通道的64x64,另一个是3通道的32x32,两个经过卷积层、池化层后,展平拼接,最后进行二分类。
##------ Multiple Input Model ------##
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense, Flatten
from keras.layers import Conv2D, MaxPooling2D, Concatenate
from keras import backend as K
K.clear_session()
# first input model
input1 = Input(shape=(64,64,1))
conv11 = Conv2D(32, (5,5), activation='relu')(input1)
pool11 = MaxPooling2D(pool_size=(2,2))(conv11)
conv12 = Conv2D(16, (3,3), activation='relu')(pool11)
pool12 = MaxPooling2D(pool_size=(2,2))(conv12)
flat1 = Flatten()(pool12)
# second input model
input2 = Input(shape=(32,32,3))
conv21 = Conv2D(32, (5,5), activation='relu')(input2)
pool21 = MaxPooling2D(pool_size=(2,2))(conv21)
conv22 = Conv2D(16, (3,3), activation='relu')(pool21)
pool22 = MaxPooling2D(pool_size=(2,2))(conv22)
flat2 = Flatten()(pool22)
# merge input models
merge = Concatenate()([flat1, flat2])
# interpretation model
hidden1 = Dense(20, activation='relu')(merge)
output = Dense(1, activation='sigmoid')(hidden1)
model = Model(inputs=[input1, input2], outputs=output)
model.summary()
模型的结构和参数如下:
多输出模型(Multiple Output Model)
定义有多个输出的模型,以文本序列输入LSTM网络为例,一个输出是对文本的分类,另外一个输出是对文本进行序列标注。
##------ Multiple Output Model ------ ##
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense, Flatten, TimeDistributed, LSTM
from keras.layers import Conv2D, MaxPooling2D, Concatenate
from keras import backend as K
K.clear_session()
x = Input(shape=(100,1))
extract = LSTM(10, return_sequences=True)(x)
class11 = LSTM(10)(extract)
class12 = Dense(10, activation='relu')(class11)
output1 = Dense(1, activation='sigmoid')(class12)
output2 = TimeDistributed(Dense(1, activation='linear'))(extract)
model = Model(inputs=x, outputs=[output1, output2])
model.summary()
模型的结构和参数如下:
参考
[1] https://machinelearningmastery.com/keras-functional-api-deep-learning/
[2] https://keras.io/getting-started/functional-api-guide/
[3] https://tensorflow.google.cn/alpha/guide/keras/functional
来源:https://blog.csdn.net/huanghaocs/article/details/90574486


猜你喜欢
- 可以通过model.state_dict()或者model.named_parameters()函数查看现在的全部可训练参数(包括通过继承得
- Python跑循环时内存泄露今天在用Tensorflow跑回归做测试时,仅仅需要循环四千多次 (补充说一句,我在个人PC上跑的)。运行以后,
- 选定基类后,就可以创建它的子类了。是否使用基类完全由你决定。有时,你可能想创建一个不能直接使用的基类,它只是用于给子类提供通用的函数。在这种
- python字符串,元组,列表,字典互相转换直接给大家上代码实例#-*-coding:utf-8-*- #1、字典dict = {'
- 前言小程序支持webview以后,我们开发的好多h5页面,就可以直接在小程序里使用了,比如我们开发的微信商城,文章详情页,商品详情页,就可以
- 1、ModuleNotFoundError: No module named ‘scipy.spatial.transf
- 一、批量新建并保存工作簿import xlwings as xw # 导入库# 启动Excel程序,但不新建工作
- 目录方法一:直接调用函数运行方法二:使用偏函数来执行方法三:使用 eval 动态执行方法四:使用 getattr 动态获取执行方法五:使用类
- 这一篇复习一下ECMAScript规范中的基础语法,英文好的朋友可以直接阅读官方文档。JavaScript本质上也是一种类C语言,熟悉C语言
- 谷歌的potobuf不说了,它很牛B,但是对客户端对象不支持,比如JavaScript就读取不了。Jil很牛,比Newtonsoft.Jso
- 之前在豆瓣上听到有友邻在抱怨卓越的配送速度慢得跟蜗牛一样,超过配送时间期限几天还没送到,当时不太相信,因为此前在卓越网上购买的物品基本上是在
- update()方法添加键 - 值对到字典dict2。此函数不返回任何值。语法以下是update()方法的语法:dict.upd
- ASP与MySQL的连接ASP和MySQL连接目前有两种方法:一种方法是使用MySQLX之类的组件,不过这种连接方法需要支付一定的费用;另外
- 进行访问MySQL数据库的方法有很多种,下面将向大家介绍一些很简单实用的用的方法和示例与大家一起分享。方法一:使用MYSQL推出的MySQL
- 本文实例为大家分享了Python3实现汉语转换为汉语拼音的具体代码,供大家参考,具体内容如下工具: Python3.6.2,pycharm1
- FlashPaper 是Macromedia推出的一款电子文档类工具,通过使用本程序,你可以将需要的文档通过简单的设置转换为SWF格式的Fl
- Function closeUBB(strContent) '*************************
- 本文实例讲述了python集合的创建、添加及删除操作。分享给大家供大家参考,具体如下:集合时无序可变的序列,集合中的元素放在{}内,集合中的
- 线程池的概念是什么?在面向对象编程中,创建和销毁对象是很费时间的,因为创建一个对象要获取内存资源或者其它更多资源。在Java中更是 如此,虚
- codecs在读取文件时,发生错误:UnicodeDecodeError: 'utf-8' codec can't