Keras 使用 Lambda层详解
作者:驿无边 发布时间:2021-08-11 20:12:40
标签:Keras,Lambda
我就废话不多说了,大家还是直接看代码吧!
from tensorflow.python.keras.models import Sequential, Model
from tensorflow.python.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout, Conv2DTranspose, Lambda, Input, Reshape, Add, Multiply
from tensorflow.python.keras.optimizers import Adam
def deconv(x):
height = x.get_shape()[1].value
width = x.get_shape()[2].value
new_height = height*2
new_width = width*2
x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
return x_resized
def Generator(scope='generator'):
imgs_noise = Input(shape=inputs_shape)
x = Conv2D(filters=32, kernel_size=(9,9), strides=(1,1), padding='same', activation='relu')(imgs_noise)
x = Conv2D(filters=64, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
x = Conv2D(filters=128, kernel_size=(3,3), strides=(2,2), padding='same', activation='relu')(x)
x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x)
x1 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x1)
x2 = Add()([x1, x])
x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x2)
x3 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x3)
x4 = Add()([x3, x2])
x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x4)
x5 = Conv2D(filters=128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu')(x5)
x6 = Add()([x5, x4])
x = MaxPool2D(pool_size=(2,2))(x6)
x = Lambda(deconv)(x)
x = Conv2D(filters=64, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
x = Lambda(deconv)(x)
x = Conv2D(filters=32, kernel_size=(3, 3), strides=(1,1), padding='same',activation='relu')(x)
x = Lambda(deconv)(x)
x = Conv2D(filters=3, kernel_size=(3, 3), strides=(1, 1), padding='same',activation='tanh')(x)
x = Lambda(lambda x: x+1)(x)
y = Lambda(lambda x: x*127.5)(x)
model = Model(inputs=imgs_noise, outputs=y)
model.summary()
return model
my_generator = Generator()
my_generator.compile(loss='binary_crossentropy', optimizer=Adam(0.7, decay=1e-3), metrics=['accuracy'])
补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案
一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。保存时会报
TypeError: can't pickle _thread.RLock objects
二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。
from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io
def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
if not os.path.exists(output_dir):
os.mkdir(output_dir)
h5_model = build_model()
h5_model.load_weights(h5_weight_path)
out_nodes = []
for i in range(len(h5_model.outputs)):
out_nodes.append(out_prefix + str(i + 1))
tf.identity(h5_model.output[i], out_prefix + str(i + 1))
model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
sess = K.get_session()
init_graph = sess.graph.as_graph_def()
main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
if log_tensorboard:
from tensorflow.python.tools import import_pb_to_tensorboard
import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)
def build_model():
inputs = Input(shape=(784,), name='input_img')
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
h5_model = Model(inputs=inputs, outputs=y)
return h5_model
if __name__ == '__main__':
if len(sys.argv) == 3:
# usage: python3 h5_to_pb.py h5_weight_path output_dir
h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])
来源:https://blog.csdn.net/a4775019136/article/details/99336591


猜你喜欢
- 很多朋友在论坛和留言区域问mysql在什么情况下才需要进行分库分表,以及采用何种设计方式才是最优的选择,根据这些问题,小编为大家整理了关于M
- 共存问题我之前一直使用的是SQL2012版本的数据库管理工具,为了与时俱进,我也尝试更新一下版本,当然SQLServer管理工具是可以多版本
- 安装数据可视化模块matplotlib:pip install matplotlib导入matplotlib模块下的pyplot1 折线图f
- 前言 1. 概述共享坐标轴就是几幅子图之间共享x轴或y轴,这一部分主要了解如何在利用matplotlib制图时共享坐标轴。pyplot.s
- private void Button1_Click(object sender, System.E
- Python中包含了许多内建的语言特性,它们使得代码简洁且易于理解。这些特性包括列表/集合/字典推导式,属性(property)、以及装饰器
- 1 GitHub创建作为图床的仓库1.1 在GitHub中创建一个仓库注意仓库要是public的,不然上传的图片还是无法使用的。如果不知道怎
- 本文实例讲述了Go语言对字符串进行MD5加密的方法。分享给大家供大家参考。具体实现方法如下:package mainimport (&nbs
- 一些小技巧1. 如何查出效率低的语句?在MySQL下,在启动参数中设置 --log-slow-queries=[文件名],就可以在指定的日志
- 说明本文根据https://github.com/liuchengxu/blockchain-tutorial的内容,用python实现的,
- 前言如果电脑是第一次安装MySQL,一般不会出现这样的报错。如下图所示。starting the server失败,通常是因为上次安装的该软
- 本实例使用的mysql版本为mysql-8.0.15-winx641、下载zip包官网地址:https://dev.mysql.com/do
- 使用django实现注册登录的话,注册登录都有现成的代码,主要是自带的User字段只有(email,username,password),所
- jqGrid是一个优秀的基于jQuery的DataGrid框架,想必大伙儿也不陌生,网上基于ASP的资料很少,我提供一个,数据格式是json
- 前面我们讲了 TCP 编程,我们知道 TCP 可以建立可靠连接,并且通信双方都可以以流的形式发送数据。本文我们再来介绍另一个常用的协议–UD
- Exec sp_droplinkedsrvlogin ZYB,Null --删除映射(录与链接服务器上远程登录之间的映射) Exec sp_
- 1. 问题描述水仙花数也被称为超完全数字不变数、自恋数、自幂数、阿姆斯壮数或阿姆斯特朗数,水仙花数是指一个3位数,它的每个位上的数字的3次幂
- 1、查看数据库的字符集数据库的字符集必须和Linux下设置的环境变量一致,不然会有乱码。以下两个sql语句都可以查到:select * fr
- 许多网站缺乏针对性和友好的导航设计,难以找到连接到相关网页的路径,也没有提供有助于让访客/用户找到所需信息的帮助,用户体验非常糟糕。本期薯片
- 前段时间因为忙一些其它的事情,分享的有些少,最近学习一下redis在Go语言开发中的应用。一、理论知识Redis是一个开源的、使用C语言编写