tensorflow模型文件(ckpt)转pb文件的方法(不知道输出节点名)
作者:ken_asr 发布时间:2023-11-08 23:30:21
网上关于tensorflow模型文件ckpt格式转pb文件的帖子很多,本人几乎尝试了所有方法,最后终于成功了,现总结如下。方法无外乎下面两种:
使用tensorflow.python.tools.freeze_graph.freeze_graph
使用graph_util.convert_variables_to_constants
1、tensorflow模型的文件解读
使用tensorflow训练好的模型会自动保存为四个文件,如下
checkpoint:记录近几次训练好的模型结果(名称)。
xxx.data-00000-of-00001: 模型的所有变量的值(weights, biases, placeholders,gradients, hyper-parameters etc),也就是模型训练好参数和其他值。
xxx.index :模型的元数据,二进制或者其他格式,不可直接查看 。是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和一些辅助数据等。
xxx.meta:模型的meta数据 ,二进制或者其他格式,不可直接查看,保存了TensorFlow计算图的结构信息,通俗地讲就是神经网络的网络结构。
2、最常见的ckpt转pb文件的方法
2、ckpt转pb文件(freeze_graph.freeze_graph)
此种方法尝试成功,虽然不知道输出节点名,但是只要模型代码还在就可以操作,直接上代码。
import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from model import network # network是你们自己定义的模型结构(代码结构)
# egs:
# def network(input):
# return tf.layers.softmax(input)
model_path = "model.ckpt-0000" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前
def main():
tf.reset_default_graph()
# 设置输入网络的数据维度,根据训练时的模型输入数据的维度自行修改
input_node = tf.placeholder(tf.float32, shape=(None, None, 200))
output_node = network(input_node) # 神经网络的输出
# 设置输出数据类型(特别注意,这里必须要跟输出网络参数的数据格式保持一致,不然会导致模型预测 精度或者预测能力的丢失)以及重新定义输出节点的名字(这样在后面保存pb文件以及之后使用pb文件时直接使用重新定义的节点名字即可)
flow = tf.cast(output_node , tf.float16, 'the_outputs')
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, model_path)
#保存模型图(结构),为一个json文件
tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model.pb')
#将模型参数与模型图结合,并保存为pb文件
freeze_graph.freeze_graph('output_model/pb_model/model.pb', '', False, model_path, 'the_outputs','save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', False, "")
print("done")
if __name__ == '__main__':
main()
2、ckpt转pb文件(graph_util.convert_variables_to_constants)
没有成功,因为不知道输出节点的名字,使用该方法保存后的pb文件只有几十k,无法使用,写在这里主要是为了总结。直接上代码,代码里面没有的库(函数),按提示自行import。
def freeze_graph(input_checkpoint,output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:return:
'''
# checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
# input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
output_node_names = "InceptionV3/Logits/SpatialSqueeze"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph() # 获得默认的图
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) #恢复图并得到数据
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess=sess,
input_graph_def=input_graph_def,# 等于:sess.graph_def
output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
f.write(output_graph_def.SerializeToString()) #序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
# for op in graph.get_operations():
# print(op.name, op.values())
if __name__ == '__main__':
# 输入ckpt模型路径
input_checkpoint='models/model.ckpt-10000'
# 输出pb模型的路径
out_pb_path="models/pb/frozen_model.pb"
# 调用freeze_graph将ckpt转为pb
freeze_graph(input_checkpoint,out_pb_path)
参考链接:
https://www.jb51.net/article/185209.htm
https://www.jb51.net/article/185206.htm
来源:https://blog.csdn.net/zkgoup/article/details/105657577
猜你喜欢
- 本文实例为大家分享了Python+Opencv实现图像匹配功能的具体代码,供大家参考,具体内容如下1、原理简单来说,模板匹配就是拿一个模板(
- 本文实例讲述了python使用cPickle模块序列化的方法,分享给大家供大家参考。具体方法如下:import cPickledata1 =
- 空白双边距是一个极容易误解的CSS特性.它不是CSS的bug,但如果我们一旦误解,将会给你带来很多麻烦.先看如下demo代码:<!do
- 最近Google Code推出了一个面向网站开发者的 * Google DocType。它来自于网站开发者同时又面
- 下面列出了asp远程网页数据采集程序中经常用到的函数,很实用,特别是正则表达式过滤函数。包括了使用xmlhttp采集远程网页内容,使用ado
- torch.flatten(x)等于torch.flatten(x,0)默认将张量拉成一维的向量,也就是说从第一维开始平坦化,t
- 字典与json字符串区别# python 中的字典格式,是dict类型{'a': 'sd'}如果声明a =
- 微信跳一跳辅助的python具体实现代码,供大家参考,具体内容如下这是一个 2.5D 插画风格的益智游戏,玩家可以通过按压屏幕时间的长短来控
- 文件数据读写读写文件,本质上是请求操作系统打开一个文件对象,然后,通过操作系统提供的接口从这个文件对象中读取数据(读文件),或者把数据写入这
- 一、爬虫的简单理解1. 什么是爬虫?网络爬虫也叫网络蜘蛛,如果把互联网比喻成一个蜘蛛网,那么蜘蛛就是在网上爬来爬去的蜘蛛,爬虫程序通过请求u
- 在之前介绍PyQtGraph的文章中,我们都是一次性的获取数据并将其绘制为图形。然而在很多场景中,我们都需要对实时的数据进行图形化展示,比如
- batch很好理解,就是batch size。注意在一个epoch中最后一个batch大小可能小于等于batch sizedataset.r
- 对象:是抽象的概念 如列表 元组 字典 集合 皆为对象序列化:一种方法。目的:把对象存储在磁盘上(即,将对象转换为字节数据/字符数据)。这一
- 废话不多说,直接看问题,使用过 Python 中的标准库 zipfile 解压过 zip&
- 很实用的过滤重复数据的asp代码,函数如下:<%'**************************************
- 查询一天:select * from table where to_days(column_time) = to_days(now());s
- if exists (select * from dbo.sysobjects where id = object_id(N'[db
- 本文实例分析了php字符串截取函数用法。分享给大家供大家参考。具体分析如下:php自带的截取字符串的函数只能处理英文,数字的不能截取中文混排
- 由于在Python2 中的默认编码为ASCII,但是在Python3中的默认编码为UTF-8。问题:所以在使用np.load(det.npy
- 基于python3基础课程,编写名片管理系统训练,有利于熟悉python基础代码的使用。cards_main.py#! /usr/bin/p