TensorFlow固化模型的实现操作
作者:Jcme丶Ls 发布时间:2022-09-12 22:28:42
前言
TensorFlow目前在移动端是无法training的,只能跑已经训练好的模型,但一般的保存方式只有单一保存参数或者graph的,如何将参数、graph同时保存呢?
生成模型
主要有两种方法生成模型,一种是通过freeze_graph把tf.train.write_graph()生成的pb文件与tf.train.saver()生成的chkp文件固化之后重新生成一个pb文件,这一种现在不太建议使用。另一种是把变量转成常量之后写入PB文件中。我们简单的介绍下freeze_graph方法。
freeze_graph
这种方法我们需要先使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件,代码如下:
with tf.Session() as sess:
saver = tf.train.Saver()
saver.save(session, "model.ckpt")
tf.train.write_graph(session.graph_def, '', 'graph.pb')
然后使用TensorFlow源码中的freeze_graph工具进行固化操作:
首先需要build freeze_graph 工具( 需要 bazel ):
bazel build tensorflow/python/tools:freeze_graph
然后使用这个工具进行固化(/path/to/表示文件路径):
bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=/path/to/graph.pb --input_checkpoint=/path/to/model.ckpt --output_node_names=output/predict --output_graph=/path/to/frozen.pb
convert_variables_to_constants
其实在TensorFlow中传统的保存模型方式是保存常量以及graph的,而我们的权重主要是变量,如果我们把训练好的权重变成常量之后再保存成PB文件,这样确实可以保存权重,就是方法有点繁琐,需要一个一个调用eval方法获取值之后赋值,再构建一个graph,把W和b赋值给新的graph。
牛逼的Google为了方便大家使用,编写了一个方法供我们快速的转换并保存。
首先我们需要引入这个方法
from tensorflow.python.framework.graph_util import convert_variables_to_constants
在想要保存的地方加入如下代码,把变量转换成常量
output_graph_def = convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output/predict'])
这里参数第一个是当前的session,第二个为graph,第三个是输出节点名(如我的输出层代码是这样的:)
with tf.name_scope('output'):
w_out = tf.Variable(w_alpha * tf.random_normal([1024, MAX_CAPTCHA * CHAR_SET_LEN]))
tf.summary.histogram('output/weight', w_out)
b_out = tf.Variable(b_alpha * tf.random_normal([MAX_CAPTCHA * CHAR_SET_LEN]))
tf.summary.histogram('output/biases', b_out)
out = tf.add(tf.matmul(dense2, w_out), b_out)
out = tf.nn.softmax(out)
predict = tf.argmax(tf.reshape(out, [-1, 11, 36]), 2, name='predict')
由于我们采用了name_scope所以我们在predict之前需要加上output/
生成文件
with tf.gfile.FastGFile('model/CTNModel.pb', mode='wb') as f:
f.write(output_graph_def.SerializeToString())
第一个参数是文件路径,第二个是指文件操作的模式,这里指的是以二进制的方式写入文件。
运行代码,系统会生成一个PB文件,接下来我们要测试下这个模型是否能够正常的读取、运行。
测试模型
在Python环境下,我们首先需要加载这个模型,代码如下:
with open('./model/rounded_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
output = tf.import_graph_def(graph_def,
input_map={'inputs/X:0': newInput_X},
return_elements=['output/predict:0'])
由于我们原本的网络输入值是一个placeholder,这里为了方便输入我们也先定义一个新的placeholder:
newInput_X = tf.placeholder(tf.float32, [None, IMAGE_HEIGHT * IMAGE_WIDTH], name="X")
在input_map的参数填入新的placeholder。
在调用我们的网络的时候直接用这个新的placeholder接收数据,如:
text_list = sesss.run(output, feed_dict={newInput_X: [captcha_image]})
然后就是运行我们的网络,看是否可以运行吧。
来源:https://www.jianshu.com/p/091415b114e2
猜你喜欢
- 核心思想在defer出现的地方插入了指令CALL runtime.deferproc,在函数返回的地方插入了CALL runtime.def
- 目录背景认识复合索引最左匹配原则字段顺序的影响复合索引可以替代单一索引吗?小结背景最近频繁出现慢SQL导致系统性能问题,于是决定针对索引进行
- 下面为您介绍sql下用了判断各种资源是否存在的代码,需要的朋友可以参考下,希望对您学习sql的函数及数据库能够有所帮助。-- 库是否存在if
- 说明1、导入模块pyplot,并指定别名plt,以避免重复输入pyplot。模块化pyplot包含许多用于制作图表的功能。2、将绘制的直线坐
- SELECTSELECT 语句用于从表中选取数据,是 SQL 最基本的操作之一。通过 SELECT 查询的结果被存储在一个结果表中(称为结果
- 当我们使用访问一个没有声明的变量时,JS会报错;而当我们给一个没有声明的变量赋值时,JS不会报错,相反它会认为我们是要隐式申明一个全局变量。
- 今天突然有同事问起,如何在sqlserver中调试存储过程(我们公司使用的是sqlserver 2008 R2),猛地一看,和以前使用sql
- PyQt5信号与槽高级自定义信号与槽所谓高级自定义信号与槽,指的就是我们可以以自己喜欢的方式定义信号与槽函数,并传递参数,自定义信号的一般流
- 我就废话不多说了,大家还是直接看代码吧!print("thresh =",thresh)coords = np.colu
- 代码如下:create proc p_sword_getblcolumn ( @tblN
- 反射反射即想到4个内置函数分别为:getattr、hasattr、setattr、delattr 获取成员、检查成员、设置成员、
- 1、安装flask_sqlalchemy和pymysql包pip install flask-sqlalchemypip install p
- torch.Tensor类型的数据loss和acc打印时如果写成以下写法print('batch_loss: '+str(l
- 不知道您是否留意了,浏览本站时,浏览器右下角有一个标着top的黑色直角三角形,可以点击它返回到正在浏览的网页页眉。当滚动网页时,它的位置一直
- 在Visual Studio 中使用git——什么是Git(一)如果要使用git进行版本管理,其实使用git命令行工具就完全足够了,图形化工
- 在做数据分析或者统计的时候,经常需要进行数据正态性的检验,因为很多假设都是基于正态分布的基础之上的,例如:T检验。在Python中,主要有以
- Flask框架介绍Flask诞生于2010年,是Armin ronacher用Python语言基于Werkzeug工具箱编写的轻量级Web开
- 一、安装相关的模块首先第一步的话我们需要安装相关的模块,通过pip命令来安装pip install gif另外由于gif模块之后会被当做是装
- Numpy中的N维数组(ndarray)Numpy 中的数组是一个元素表(通常是数字),所有元素类型相同,由正整数元组索引。在 Numpy
- 序篇天气真的很热啊… 很想有一杯冰冰凉凉的奶茶来解渴~但是现在奶茶店这么多, 到底哪一家最好喝、性价比最高呢?数据获取