解决Keras 中加入lambda层无法正常载入模型问题
作者:机器玄学实践者 发布时间:2022-02-21 03:41:11
标签:Keras,lambda,载入,模型
刚刚解决了这个问题,现在记录下来
问题描述
当使用lambda层加入自定义的函数后,训练没有bug,载入保存模型则显示Nonetype has no attribute 'get'
问题解决方法:
这个问题是由于缺少config信息导致的。lambda层在载入的时候需要一个函数,当使用自定义函数时,模型无法找到这个函数,也就构建不了。
m = load_model(path,custom_objects={"reduce_mean":self.reduce_mean,"slice":self.slice})
其中,reduce_mean 和slice定义如下
def slice(self,x, turn):
""" Define a tensor slice function
"""
return x[:, turn, :, :]
def reduce_mean(self, X):
return K.mean(X, axis=-1)
补充知识:含有Lambda自定义层keras模型,保存遇到的问题及解决方案
一,许多应用,keras含有的层已经不能满足要求,需要透过Lambda自定义层来实现一些layer,这个情况下,只能保存模型的权重,无法使用model.save来保存模型。
保存时会报
TypeError: can't pickle _thread.RLock objects
二,解决方案,为了便于后续的部署,可以转成tensorflow的PB进行部署。
from keras.models import load_model
import tensorflow as tf
import os, sys
from keras import backend as K
from tensorflow.python.framework import graph_util, graph_io
def h5_to_pb(h5_weight_path, output_dir, out_prefix="output_", log_tensorboard=True):
if not os.path.exists(output_dir):
os.mkdir(output_dir)
h5_model = build_model()
h5_model.load_weights(h5_weight_path)
out_nodes = []
for i in range(len(h5_model.outputs)):
out_nodes.append(out_prefix + str(i + 1))
tf.identity(h5_model.output[i], out_prefix + str(i + 1))
model_name = os.path.splitext(os.path.split(h5_weight_path)[-1])[0] + '.pb'
sess = K.get_session()
init_graph = sess.graph.as_graph_def()
main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
if log_tensorboard:
from tensorflow.python.tools import import_pb_to_tensorboard
import_pb_to_tensorboard.import_to_tensorboard(os.path.join(output_dir, model_name), output_dir)
def build_model():
inputs = Input(shape=(784,), name='input_img')
x = Dense(64, activation='relu')(inputs)
x = Dense(64, activation='relu')(x)
y = Dense(10, activation='softmax')(x)
h5_model = Model(inputs=inputs, outputs=y)
return h5_model
if __name__ == '__main__':
if len(sys.argv) == 3:
# usage: python3 h5_to_pb.py h5_weight_path output_dir
h5_to_pb(h5_weight_path=sys.argv[1], output_dir=sys.argv[2])
来源:https://blog.csdn.net/weixin_39673686/article/details/90697587
0
投稿
猜你喜欢
- 通过python+splinter,实现在12306网站刷票并自动购票流程(无法自动识别验证码)。此类程序只是提高了12306网站的 <
- 需求:根据医保中心的文档和提供的dll动态库调用相关接口下载医保中心的账单。文档:对调用dll动态库的描述,调用哪个dll文件,同时了解清楚
- 保护你的ASP页面的两种办法 有时候你只想让人们从你的站点来访问你的某些页面, 而不允许他们从其它站点的非法链接中到达这些页面。在你想保护的
- 简单实现了一个在函数执行出现异常时自动重试的装饰器,支持控制最多重试次数,每次重试间隔,每次重试间隔时间递增。最新的代码可以访问从githu
- 最近有Win10系统用户反映,由于自己的电脑安装有两个python软件,所以想要卸载掉其中一个,不过在卸载的时候却发现无法卸载,并且出现提示
- 先说结论:变量赋值属于浅拷贝(关于深拷贝和浅拷贝的区别可以自己了解下)。故如果是可变类型变量(如a是list类型,a=b)赋值,修改a会牵连
- 本文实例讲述了Python高级编程之继承问题。分享给大家供大家参考,具体如下:多继承问题1.单独调用父类: 一个子类同时继承自多个父类,又称
- 快速回顾一下RabbitMQ服务器的安装:sudo apt-get install rabbitmq-serverPython使用Rabbi
- 对一名开发者来说最糟糕的情况,莫过于要弄清楚一个不熟悉的应用为何不工作。有时候,你甚至不知道系统运行,是否跟原始设计一致。在线运行的应用就是
- 是否曾经有过这样的经历:把一个元素置于另一个元素之上,而希望下面的那个元素成为可点击的?现在,利用css的pointer-events属性即
- 我很久前在YAHOO上扣的代码,兼容性很好,在Windows下的主流浏览器中可以正常运行。大家先不要急着下载代码,你随时都可以下,我们来分
- 前言在上一节我们通过使用NumPy的数组分割成功的在我们的图像上画了一个绿色的方块,但是如果我们想画一个单一的线条或者圆圈该怎么办呢?Num
- 在用django写项目时,遇到了许多场景,关于ORM操作获取数据的,但是不好描述出来,百度搜索关键词都不知道该怎么搜,导致一个人鼓捣了好久。
- EF Core 是一个ORM(对象关系映射),它使 .NET 开发人员可以使用 .NET对象操作数据库,避免了像ADO.NET访问数据库的代
- 去空格函数有如下两种:·LTRIM()LTRIM() 函数把字符串头部(左)的空格去掉,其语法如下:LTRIM (<character
- 很多文章都有提到关于使用phpExcel实现Excel数据的导入导出,大部分文章都差不多,或者就是转载的,都会出现一些问题,下面是本人研究p
- 验证码制作#string模块自带数字、字母、特殊字符变量集合,不需要我们手写集合import stringimport randomimpo
- python提高图像质量概述调研了一些提高图像质量的方式深度学习方法,如微软的Bringing-Old-Photos-Back-to-Lif
- 本文实例分析了CI框架出现mysql数据库连接资源无法释放的解决方法。分享给大家供大家参考,具体如下:使用ci框架提供的类查询数据:$thi
- np.nonzero函数是numpy中用于得到数组array中非零元素的位置(数组索引)的函数。一般来说,通过help(np.nonzero