Keras模型转成tensorflow的.pb操作
作者:VickyD1023 发布时间:2023-12-22 13:10:34
标签:Keras,tensorflow,.pb
Keras的.h5模型转成tensorflow的.pb格式模型,方便后期的前端部署。直接上代码
from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.mobilenet import MobileNet
from keras.applications.mobilenet import preprocess_input
from keras.preprocessing.image import load_img, img_to_array
import tensorflow as tf
from keras import backend as K
import os
base_model = MobileNet((None, None, 3), alpha=1, include_top=False, pooling='avg', weights=None)
x = Dropout(0.75)(base_model.output)
x = Dense(10, activation='softmax')(x)
model = Model(base_model.input, x)
model.load_weights('mobilenet_weights.h5')
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
output_names = output_names or []
output_names += [v.op.name for v in tf.global_variables()]
input_graph_def = graph.as_graph_def()
if clear_devices:
for node in input_graph_def.node:
node.device = ""
frozen_graph = convert_variables_to_constants(session, input_graph_def,
output_names, freeze_var_names)
return frozen_graph
output_graph_name = 'NIMA.pb'
output_fld = ''
#K.set_learning_phase(0)
print('input is :', model.input.name)
print ('output is:', model.output.name)
sess = K.get_session()
frozen_graph = freeze_session(K.get_session(), output_names=[model.output.op.name])
from tensorflow.python.framework import graph_io
graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)
print('saved the constant graph (ready for inference) at: ', os.path.join(output_fld, output_graph_name))
补充知识:keras h5 model 转换为tflite
在移动端的模型,若选择tensorflow或者keras最基本的就是生成tflite文件,以本文记录一次转换过程。
环境
tensorflow 1.12.0
python 3.6.5
h5 model saved by `model.save('tf.h5')`
直接转换
`tflite_convert --output_file=tf.tflite --keras_model_file=tf.h5`
output
`TypeError: __init__() missing 2 required positional arguments: 'filters' and 'kernel_size'`
先转成pb再转tflite
```
git clone git@github.com:amir-abdi/keras_to_tensorflow.git
cd keras_to_tensorflow
python keras_to_tensorflow.py --input_model=path/to/tf.h5 --output_model=path/to/tf.pb
tflite_convert \
--output_file=tf.tflite \
--graph_def_file=tf.pb \
--input_arrays=convolution2d_1_input \
--output_arrays=dense_3/BiasAdd \
--input_shape=1,3,448,448
```
参数说明,input_arrays和output_arrays是model的起始输入变量名和结束变量名,input_shape是和input_arrays对应
官网是说需要用到tenorboard来查看,一个比较trick的方法
先执行上面的命令,会报convolution2d_1_input找不到,在堆栈里面有convert_saved_model.py文件,get_tensors_from_tensor_names()这个方法,添加`print(list(tensor_name_to_tensor))` 到 tensor_name_to_tensor 这个变量下面,再执行一遍,会打印出所有tensor的名字,再根据自己的模型很容易就能判断出实际的name。
来源:https://blog.csdn.net/q6324266/article/details/85262438
0
投稿
猜你喜欢
- 在数据分析中经常需要从csv格式的文件中存取数据以及将数据写书到csv文件中。将csv文件中的数据直接读取为 dict 类型和 DataFr
- 本文侧重于如何使用Python语言实现SIFT算法所有程序已打包:基于OpenCV-Python的SIFT算法的实现一、什么是SIFT算法
- 与抓取预定义好的页面集合不同,抓取一个网站的所有内链会带来一个 挑战,即你不知道会获得什么。好在有几种基本的方法可以识别页面类型。通过URL
- 什么是pyecharts?pyecharts 是一个用于生成 Echarts 图表的类库。echarts是百度开源的一个数据可视化 JS 库
- 自动上次ymPrompt组件发布,自己就曾发现在IE8下遮罩的半透明滤镜有时无效的问题,后来也有网友提出过这个问题,但自己一直也没有太多关注
- 每个产品诞生的背后都凝结着一位或是多位设计师的心血,在产品的诞生过程中文化、科技、环保、创意等这些方方面面的细节集结成一个绚丽的故事,因为有
- 概述🌱记住日期是有点困难,但我们是程序员,使困难的事情更容易是我们唯一的工作,所以我们不记得日期为什么不自动化这个任务。在这篇文章中,我们将
- 1. 背景最近在爬取某个站点时,发现在POST数据时,使用的数据格式是request payload,有别于之前常见的 POST数据格式(F
- 本文实例讲述了Python实现测试磁盘性能的方法。分享给大家供大家参考。具体如下:该代码做了如下工作:create 300000 files
- Pygame的mixer 模块可以依据命令播放一个或多个声音,并且也可以将这些声音混合在一起。而获得声音需要四个步骤:一、启动mixer进程
- <script language="javascript"> function disableRightCl
- 1、引言小 * 丝:鱼哥,你说咱们发快递时填写的地址信息,到后台怎么能看清楚写的对不对呢?小鱼:这种事情还要问? 你没在电商行业混过??小 * 丝:
- 1.数据的容量:1-3年内会大概多少条数据,每条数据大概多少字节; 2.数据项:是否有大字段,那些字段的值是否经常被更新; 3.数据查询SQ
- 本文实例讲述了Python GUI编程学习笔记之tkinter中messagebox、filedialog控件用法。分享给大家供大家参考,具
- pytest-playwright 是一个 Python 包,它允许您使用 Microsoft 的 Playwright 库在 Python
- <html> <body> &nbs
- vue配置element-ui遇到的坑注意:本文章参照element-ui官方文档,快速上手部分,的部分教程步骤1.npm安装npm i e
- 1、创建mysite测试站点:django-admin.py startproject mysite 2、创建测试页:hello.py,内容
- 问题怎样实现一个按优先级排序的队列? 并且在这个队列上面每次 pop 操作总是返回优先级最高的那个元素解决方案下面的类利用 heapq 模块
- 在多数的现代语音识别系统中,人们都会用到频域特征。梅尔频率倒谱系数(MFCC),首先计算信号的功率谱,然后用滤波器和离散余弦变换的变换来提取