网络编程
位置:首页>> 网络编程>> Python编程>> keras模型保存为tensorflow的二进制模型方式

keras模型保存为tensorflow的二进制模型方式

作者:Eileng  发布时间:2022-12-21 07:20:08 

标签:keras,模型保存,tensorflow,二进制

最近需要将使用keras训练的模型移植到手机上使用, 因此需要转换到tensorflow的二进制模型。

折腾一下午,终于找到一个合适的方法,废话不多说,直接上代码:


# coding=utf-8
import sys

from keras.models import load_model
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
"""
Freezes the state of a session into a prunned computation graph.

Creates a new computation graph where variable nodes are replaced by
constants taking their current value in the session. The new graph will be
prunned so subgraphs that are not neccesary to compute the requested
outputs are removed.
@param session The TensorFlow session to be frozen.
@param keep_var_names A list of variable names that should not be frozen,
      or None to freeze all the variables in the graph.
@param output_names Names of the relevant graph outputs.
@param clear_devices Remove the device directives from the graph for better portability.
@return The frozen graph definition.
"""
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
 freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
 output_names = output_names or []
 output_names += [v.op.name for v in tf.global_variables()]
 input_graph_def = graph.as_graph_def()
 if clear_devices:
  for node in input_graph_def.node:
   node.device = ""
 frozen_graph = convert_variables_to_constants(session, input_graph_def,
             output_names, freeze_var_names)
 return frozen_graph

input_fld = sys.path[0]
weight_file = 'your_model.h5'
output_graph_name = 'tensor_model.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])

from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=False)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

上面代码实现保存到当前目录的tensor_model目录下。

验证:


import tensorflow as tf
import numpy as np
import PIL.Image as Image
import cv2

def recognize(jpg_path, pb_file_path):
with tf.Graph().as_default():
 output_graph_def = tf.GraphDef()

with open(pb_file_path, "rb") as f:
  output_graph_def.ParseFromString(f.read())
  tensors = tf.import_graph_def(output_graph_def, name="")
  print tensors

with tf.Session() as sess:
  init = tf.global_variables_initializer()
  sess.run(init)

op = sess.graph.get_operations()

for m in op:
   print(m.values())

input_x = sess.graph.get_tensor_by_name("convolution2d_1_input:0") #具体名称看上一段代码的input.name
  print input_x

out_softmax = sess.graph.get_tensor_by_name("activation_4/Softmax:0") #具体名称看上一段代码的output.name

print out_softmax

img = cv2.imread(jpg_path, 0)
  img_out_softmax = sess.run(out_softmax,
         feed_dict={input_x: 1.0 - np.array(img).reshape((-1,28, 28, 1)) / 255.0})

print "img_out_softmax:", img_out_softmax
  prediction_labels = np.argmax(img_out_softmax, axis=1)
  print "label:", prediction_labels

pb_path = 'tensorflow_model/constant_graph_weights.pb'
img = 'test/6/8_48.jpg'
recognize(img, pb_path)

补充知识:如何将keras训练好的模型转换成tensorflow的.pb的文件并在TensorFlow serving环境调用

首先keras训练好的模型通过自带的model.save()保存下来是 .model (.h5) 格式的文件

模型载入是通过 my_model = keras . models . load_model( filepath )

要将该模型转换为.pb 格式的TensorFlow 模型,代码如下:


# -*- coding: utf-8 -*-
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.embeddings import Embedding
from keras.layers.recurrent import LSTM
from keras.layers import Dropout
from keras.layers.wrappers import Bidirectional
from keras.models import Sequential,load_model
from keras.preprocessing import sequence
from sklearn.model_selection import train_test_split
import collections
from collections import defaultdict
import jieba
import numpy as np
import sys
reload(sys)
sys.setdefaultencoding('utf-8')
import tensorflow as tf
import os
import os.path as osp
from keras import backend as K
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
from tensorflow.python.framework.graph_util import convert_variables_to_constants
graph = session.graph
with graph.as_default():
 freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
 output_names = output_names or []
 output_names += [v.op.name for v in tf.global_variables()]
 input_graph_def = graph.as_graph_def()
 if clear_devices:
  for node in input_graph_def.node:
   node.device = ""
 frozen_graph = convert_variables_to_constants(session, input_graph_def,
             output_names, freeze_var_names)
 return frozen_graph
input_fld = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/'
weight_file = 'biLSTM_brand_recognize.model'
output_graph_name = 'tensor_model_v3.pb'

output_fld = input_fld + '/tensorflow_model/'
if not os.path.isdir(output_fld):
os.mkdir(output_fld)
weight_file_path = osp.join(input_fld, weight_file)

K.set_learning_phase(0)
net_model = load_model(weight_file_path)

print('input is :', net_model.input.name)
print ('output is:', net_model.output.name)

sess = K.get_session()

frozen_graph = freeze_session(K.get_session(), output_names=[net_model.output.op.name])
from tensorflow.python.framework import graph_io

graph_io.write_graph(frozen_graph, output_fld, output_graph_name, as_text=True)

print('saved the constant graph (ready for inference) at: ', osp.join(output_fld, output_graph_name))

然后模型就存成了.pb格式的文件

问题就来了,这样存下来的.pb格式的文件是frozen model

如果通过TensorFlow serving 启用模型的话,会报错:

E tensorflow_serving/core/aspired_versions_manager.cc:358] Servable {name: mnist version: 1} cannot be loaded: Not found: Could not find meta graph def matching supplied tags: { serve }. To inspect available tag-sets in the SavedModel, please use the SavedModel CLI: `saved_model_cli`

因为TensorFlow serving 希望读取的是saved model

于是需要将frozen model 转化为 saved model 格式,解决方案如下:


from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

export_dir = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/saved_model'
graph_pb = '/data/codebase/Keyword-fenci/brand_recogniton_biLSTM/tensorflow_model/tensor_model.pb'

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.gfile.GFile(graph_pb, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())

sigs = {}

with tf.Session(graph=tf.Graph()) as sess:
# name="" is important to ensure we don't get spurious prefixing
tf.import_graph_def(graph_def, name="")
g = tf.get_default_graph()
inp = g.get_tensor_by_name(net_model.input.name)
out = g.get_tensor_by_name(net_model.output.name)

sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
 tf.saved_model.signature_def_utils.predict_signature_def(
  {"in": inp}, {"out": out})

builder.add_meta_graph_and_variables(sess,
          [tag_constants.SERVING],
          signature_def_map=sigs)
builder.save()

于是保存下来的saved model 文件夹下就有两个文件:

saved_model.pb variables

其中variables 可以为空

于是将.pb 模型导入serving再读取,成功!

来源:https://blog.csdn.net/liuguangsuiyue/article/details/78298602

0
投稿

猜你喜欢

  • 目前很多软件都限制单实例,大多数软件都是用Mutex来实现的 而这个东西咱们可以用handle去干掉它,并且不影响使用。 钉钉也是一样的步骤
  • 第一招、mysql服务的启动和停止net stop mysqlnet start mysql第二招、登陆mysql语法如下: mysql -
  • 酝酿了将近一个春夏秋冬的腾讯网首页终于亮剑!反响热烈!让我们来分享它成功背后的酸甜苦辣吧。腾讯网首页改版终于开花结果。于2008年3月25日
  • 我们经常会发现网页中的许多数据并不是写死在HTML中的,而是通过js动态载入的。所以也就引出了什么是动态数据的概念,动态数据在这里指的是网页
  • 1、如何放弃正在输入的命令。 在输入一条比较长的命令时,出现打字错误是在所难免的。在这种情况下,放弃正在输入的命令重头再来往往会是更好的选择
  • 这里我推荐大家使用pycharm百度输入关键词:pycharm,点击如图所示网站进入pycharm官网选择电脑系统版本,这里我们选择Wind
  • 链接中的例子是一些脚本攻击相关的内容,有时间的朋友可以点开看看。 1.不要相信Request.QueryString: 相信在asp时代,这
  • redis-pyredis-py是Python操作Redis的第三方库,它提供了与Redis服务器交互的API。GitHub地址:https
  • MySQL支持大量的列类型,它可以被分为3类:数字类型、日期和时间类型以及字符串(字符)类型。本节首先给出可用类型的一个概述,并且总结每个列
  • 本文实例为大家分享了python点球小游戏的具体代码,供大家参考,具体内容如下1.游戏要求: 设置球的方向:左中右三个方向,射门或者扑救动作
  • 一、引言属性将值与类,结构体,枚举进行关联。Swift中的属性分为存储属性和计算属性两种,存储属性用于存储一个值,其只能用于类与结构体,计算
  • PHP是一种面向对象的编程语言,它允许开发者使用面向对象的编程技术来构建复杂的应用程序。下面是一些关于PHP面向对象编程的讲解:类与对象类是
  • asp代码 <% Dim Rs,Conn Set Conn=Server.CreateObject("Adodb.Conne
  • 前言前面在 BeanShell 里面是通过 java 脚本实现请求的预处理,jmeter里面也可以调用python的脚本,需安装 jytho
  • 很久没用sqlserver了,今天想打开sqlserver,导入数据做一下数据分析,但当我打开sqlserver工具后连接数据库后,居然报错
  • 【名称】Abs【类别】数学函数【原形】Abs(number)【参数】必选的。Number参数是一个任何有效的数值型表达式【返回值】同numb
  • CSS布局中可以用javascript判断浏览器版本看如下的javascript脚本: if (window.XMLHt
  • 我查了一下解决这个问题的办法,一般是设定全局变量,今天介绍一种新办法上代码difrouters.pyfrom flask import Fl
  • 比如:现需要向学生表中插入新的学生数据。但在插入学生数据的时,需要同时检查老师表里的数据。如果插入学生的老师不在老师表里,则先向老师表中插入
  • 在接触python时最开始接触的代码,取长方形的长和宽,定义一个长方形类,然后设置长方形的长宽属性,通过实例化的方式调用长和宽,像如下代码一
手机版 网络编程 asp之家 www.aspxhome.com