tensorflow使用freeze_graph.py将ckpt转为pb文件的方法
作者:yjl9122 发布时间:2023-01-31 15:31:05
废话少说直接上代码样例如下
import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
# 本来这个model本无需解释太多,但是这么多人不能耐下心来看,那么我简单的说一下吧
# network是你们自己定义的模型结构而已
# ps:
# def network(input):
# return tf.layers.max_pooling2d(input, 2, 2)
from model import network
os.environ['CUDA_VISIBLE_DEVICES']='2' #设置GPU
model_path = "path to /model.ckpt-0000" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前
def main():
tf.reset_default_graph()
input_node = tf.placeholder(tf.float32, shape=(228, 304, 3)) #这个是你送入网络的图片大小,如果你是其他的大小自行修改
input_node = tf.expand_dims(input_node, 0)
flow = network(input_node)
flow = tf.cast(flow, tf.uint8, 'out') #设置输出类型以及输出的接口名字,为了之后的调用pb的时候使用
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, model_path)
#保存图
tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model.pb')
#把图和参数结构一起
freeze_graph.freeze_graph('output_model/pb_model/model.pb', '', False, model_path, 'out','save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', False, "")
print("done")
if __name__ == '__main__':
main()
这节是关于tensorflow的Freezing,字面意思是冷冻,可理解为整合合并;整合什么呢,就是将模型文件和权重文件整合合并为一个文件,主要用途是便于发布。
官方解释可参考:https://www.tensorflow.org/extend/tool_developers/#freezing
这里我按我的理解翻译下,不对的地方请指正:
有一点令我们为比较困惑的是,tensorflow在训练过程中,通常不会将权重数据保存的格式文件里(这里我理解是模型文件),反而是分开保存在一个叫checkpoint的检查点文件里,当初始化时,再通过模型文件里的变量Op节点来从checkoupoint文件读取数据并初始化变量。这种模型和权重数据分开保存的情况,使得发布产品时不是那么方便,所以便有了freeze_graph.py脚本文件用来将这两文件整合合并成一个文件。
freeze_graph.py是怎么做的呢?首行它先加载模型文件,再从checkpoint文件读取权重数据初始化到模型里的权重变量,再将权重变量转换成权重 常量 (因为 常量 能随模型一起保存在同一个文件里),然后再通过指定的输出节点将没用于输出推理的Op节点从图中剥离掉,再重新保存到指定的文件里(用write_graphdef或Saver)
文件目录:tensorflow/python/tools/free_graph.py
测试文件:tensorflow/python/tools/free_graph_test.py 这个测试文件很有学习价值
参数:
总共有11个参数,一个个介绍下(必选: 表示必须有值;可选: 表示可以为空):
1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)
2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。
3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False
4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。
5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。
6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all
7、filename_tensor_name:(可选)已弃用。默认:save/Const:0
8、output_graph:(必选)用来保存整合后的模型输出文件。
9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)
10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。
11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。
用法:
例:python tensorflow/python/tools/free_graph.py \
–input_graph=some_graph_def.pb \ 注意:这里的pb文件是用tf.train.write_graph方法保存的
–input_checkpoint=model.ckpt.1001 \ 注意:这里若是r12以上的版本,只需给.data-00000….前面的文件名,如:model.ckpt.1001.data-00000-of-00001,只需写model.ckpt.1001
–output_graph=/tmp/frozen_graph.pb
–output_node_names=softmax
另外,如果模型文件是.meta格式的,也就是说用saver.Save方法和checkpoint一起生成的元模型文件,free_graph.py不适用,但可以改造下:
1、copy free_graph.py为free_graph_meta.py
2、修改free_graph.py,导入meta_graph:from tensorflow.python.framework import meta_graph
3、将91行到97行换成:input_graph_def = meta_graph.read_meta_graph_file(input_graph).graph_def
这样改即可加载meta文件
来源:https://blog.csdn.net/yjl9122/article/details/78341689
猜你喜欢
- vue3 表单验证前言表单验证可以有效的过滤不合格的数据,减少服务器的开销,并提升用户的使用体验。今天我们使用 vue3 来做一个表单验证的
- 步骤:1.从php.net上面下载php5.3.x版本的源码;2.centos安装相应的扩展包:yum install libmcrypt
- pytorch常用函数torch.randn()torch.randn(*sizes, out=None) → Tensor功能:从标准正态
- 数据备份与还原第三篇,具体如下基础概念:备份,将当前已有的数据或记录另存一份;还原,将数据恢复到备份时的状态。为什么要进行数据的备份与还原?
- 前言在设计爬虫项目的时候,首先要在脑内明确人工浏览页面获得图片时的步骤一般地,我们去网上批量打开壁纸的时候一般操作如下:1、打开壁纸网页2、
- 这篇文章主要介绍了python matplotlib拟合直线的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价
- 当服务器必须提供与两个或更多个网络或网络子网的连接时,典型的方案是使用多宿主计算机。此计算机通常位于外围网络(也称为 DMZ、外围安全区域或
- 本文整理了3种鼠标经过图片,图片边框加粗或改变颜色的方法,希望大家喜欢。下面3中只是提供了一个方法,具体的鼠标经过图片的样式,你自己可以修改
- 前言前几天下载安装了最新版的MySQL 8.0.22,遇到了不少问题,参考了一些方法,最终得以解决。今天将自己的安装过程记录下来,希望对各位
- 这篇文章主要介绍了用Python画一个LinkinPark的logo代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的
- 话不多说,先上效果,一个体验非常好的拖拽缓动的效果,让页面提升一个档次。这个效果看似很简单,到也困惑了很长时间,为什么别人写出来的拖拽体验为
- 该模块主要功能是提供可存储cookie的对象。使用此模块捕获cookie并在后续连接请求时重新发送,还可以用来处理包含cookie数据的文件
- 一、进程介绍进程:正在执行的程序,由程序、数据和进程控制块组成,是正在执行的程序,程序的一次执行过程,是资源调度的基本单位。程序:没有执行的
- 问题你想定义一个函数或者方法,它的一个或多个参数是可选的并且有一个默认值。解决方案定义一个有可选参数的函数是非常简单的,直接在函数定义中给参
- 数学函数 1.绝对值 S:select abs(-1) value O:select abs(-1) value from dual 2.取
- 题目:有四个数字:1、2、3、4,能组成多少个互不相同且无重复数字的三位数?各是多少?程序分析:可填在百位、十位、个位的数字都是1、2、3、
- 一.wget https://dev.mysql.com/get/mysql57-community-release-el7-11.noar
- MySQL安全性指南(3) 作 者: 晏子2.4 不用GRANT设置用户如果你有一个早于3.22.11的MySQL版本,你不能使用GRANT
- 前言在前几天的文章中我们讲解了如何从Word表格中提取指定数据并按照格式保存到Excel中,今天我们将再次以一位读者提出的真实需求来讲解如何
- 1. 欧几里德算法欧几里德算法又称辗转相除法, 用于计算两个整数a, b的最大公约数。其计算原理依赖于下面的定理:定理: gcd(a, b)