Tensorflow 使用pb文件保存(恢复)模型计算图和参数实例详解
作者:brooknew 发布时间:2023-12-13 05:42:06
标签:Tensorflow,pb,保存模型,计算图
一、保存:
graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量
参数2:GraphDef 对象,它描述了计算网络
参数3:Graph图中需要输出的节点的名称的列表
返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:
constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())
需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。
二、恢复:
恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:
graph0 = tf.GraphDef()
with open( pbName, mode='rb') as f:
graph0.ParseFromString( f.read() )
tf.import_graph_def( graph0 , name = '' )
三、代码:
import tensorflow as tf
from tensorflow.python.framework import graph_util
pbName = 'graphA.pb'
def graphCreate() :
with tf.Session() as sess :
var1 = tf.placeholder ( tf.int32 , name='var1' )
var2 = tf.Variable( 20 , name='var2' )#实参name='var2'指定了操作名,该操作返回的张量名是在
#'var2'后面:0 ,即var2:0 是返回的张量名,也就是说变量
# var2的名称是'var2:0'
var3 = tf.Variable( 30 , name='var3' )
var4 = tf.Variable( 40 , name='var4' )
var4op = tf.assign( var4 , 1000 , name = 'var4op1' )
sum = tf.Variable( 4, name='sum' )
sum = tf.add ( var1 , var2, name = 'var1_var2' )
sum = tf.add( sum , var3 , name='sum_var3' )
sumOps = tf.add( sum , var4 , name='sum_operation' )
oper = tf.get_default_graph().get_operations()
with open( 'operation.csv','wt' ) as f:
s = 'name,type,output\n'
f.write( s )
for o in oper:
s = o.name
s += ','+ o.type
inp = o.inputs
oup = o.outputs
for iip in inp :
s #s += ','+ str(iip)
for iop in oup :
s += ',' + str(iop)
s += '\n'
f.write( s )
for var in tf.global_variables():
print('variable=> ' , var.name) #张量是tf.Variable/tf.Add之类操作的结果,
#张量的名字使用操作名加:0来表示
init = tf.global_variables_initializer()
sess.run( init )
sess.run( var4op )
print('sum_operation result is Tensor ' , sess.run( sumOps , feed_dict={var1:1}) )
constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())
def graphGet() :
print("start get:" )
with tf.Graph().as_default():
graph0 = tf.GraphDef()
with open( pbName, mode='rb') as f:
graph0.ParseFromString( f.read() )
tf.import_graph_def( graph0 , name = '' )
with tf.Session() as sess :
init = tf.global_variables_initializer()
sess.run(init)
v1 = sess.graph.get_tensor_by_name('var1:0' )
v2 = sess.graph.get_tensor_by_name('var2:0' )
v3 = sess.graph.get_tensor_by_name('var3:0' )
v4 = sess.graph.get_tensor_by_name('var4:0' )
sumTensor = sess.graph.get_tensor_by_name("sum_operation:0")
print('sumTensor is : ' , sumTensor )
print( sess.run( sumTensor , feed_dict={v1:1} ) )
graphCreate()
graphGet()
四、保存pb函数代码里的操作名称/类型/返回的张量:
operation name | operation type | output | ||
var1 | Placeholder | Tensor("var1:0" | dtype=int32) | |
var2/initial_value | Const | Tensor("var2/initial_value:0" | shape=() | dtype=int32) |
var2 | VariableV2 | Tensor("var2:0" | shape=() | dtype=int32_ref) |
var2/Assign | Assign | Tensor("var2/Assign:0" | shape=() | dtype=int32_ref) |
var2/read | Identity | Tensor("var2/read:0" | shape=() | dtype=int32) |
var3/initial_value | Const | Tensor("var3/initial_value:0" | shape=() | dtype=int32) |
var3 | VariableV2 | Tensor("var3:0" | shape=() | dtype=int32_ref) |
var3/Assign | Assign | Tensor("var3/Assign:0" | shape=() | dtype=int32_ref) |
var3/read | Identity | Tensor("var3/read:0" | shape=() | dtype=int32) |
var4/initial_value | Const | Tensor("var4/initial_value:0" | shape=() | dtype=int32) |
var4 | VariableV2 | Tensor("var4:0" | shape=() | dtype=int32_ref) |
var4/Assign | Assign | Tensor("var4/Assign:0" | shape=() | dtype=int32_ref) |
var4/read | Identity | Tensor("var4/read:0" | shape=() | dtype=int32) |
var4op1/value | Const | Tensor("var4op1/value:0" | shape=() | dtype=int32) |
var4op1 | Assign | Tensor("var4op1:0" | shape=() | dtype=int32_ref) |
sum/initial_value | Const | Tensor("sum/initial_value:0" | shape=() | dtype=int32) |
sum | VariableV2 | Tensor("sum:0" | shape=() | dtype=int32_ref) |
sum/Assign | Assign | Tensor("sum/Assign:0" | shape=() | dtype=int32_ref) |
sum/read | Identity | Tensor("sum/read:0" | shape=() | dtype=int32) |
var1_var2 | Add | Tensor("var1_var2:0" | dtype=int32) | |
sum_var3 | Add | Tensor("sum_var3:0" | dtype=int32) | |
sum_operation | Add | Tensor("sum_operation:0" | dtype=int32) |
来源:https://blog.csdn.net/brooknew/article/details/83063512
0
投稿
猜你喜欢
- 在做语义分割项目时,标注的图片不合标准,而且类型是RGBA型,且是A的部分表示的类别,因此需要将该图片转化为RGB图片# -*- codin
- 这次哀悼,网页设计方面除了应用CSS灰度配色和滤镜,还用到正计时代码,就象汶川大地震已过去了多少天。下面这段代码,是从网易页面提取出来的,具
- 测试环境 硬件:CPU 酷睿双核T5750 内存:2G 软件:Windows server 2003 + sql server 2005 O
- 大家都知道在Dreamwerver中可以很方便地实现记录集的分页显示,但是生成的代码的确很庞大,影响了网页的显示速度,看起来条理也不是很清晰
- 最近刚刚接触深度学习,并尝试学习制作数据集,制作过程中发现了一个问题,现在跟大家分享一下。问题是这样的,在制作voc数据集时,我采集的是灰度
- 自己写了一下,适用而已,不太好,应该还能优化。先自己记录一下。不说废话了,直接贴代码最好:/* * 获得时间差,时间格式为 年-月
- Python实现Mysql数据统计的实例代码如下所示:import pymysqlimport xlwtexcel=xlwt.Workboo
- 今天发现有一个程序插入的时间不对,而该字段是配置的默认值 CURRENT_TIMESTAMP,初步判断是数据库的时区设置问题。查看时区登录数
- 前言本人曾对 Vuex 作过详细介绍,但是今天去回顾的时候发现文章思路有些繁琐,不容易找到重点。于是,在下班前几分钟,我对其重新梳理了一遍。
- 前言在文档对象模型 (DOM) 中,每个节点都是一个对象。DOM 节点有三个重要的属性 :1. nodeName : 节点的名称2. nod
- Flask是一个轻量级的Web框架。虽然是轻量级的,但是对于组件一个大型的、模块化应用也是能够实现的,“蓝图”就是这样一种实现。对于模块化应
- 本文实例为大家分享了python傅里叶变换FFT绘制频谱图的具体代码,供大家参考,具体内容如下频谱图的横轴表示的是 频率, 纵轴表
- Usuage: go run kNN.go --file="data.txt"关键是向量点的选择和阈值的判定
- 这篇文章主要介绍了Python matplotlib画曲线例题解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价
- 有三种方法,一是用微软提供的扩展库win32com来操作IE,二是用selenium的webdriver,三是用python自带的HTMLP
- 在一个网页中,不仅仅只有一个html骨架,还需要css样式文件,js执行文件以及一些图片等。因此在DTL中加载静态文件是一个必须要解决的问题
- 在我的印象里面进制互相转换确实是很常见的问题,所以在Python中,自然也少不了把下面这些代码收为util。这是从网上搜索的一篇也的还可以的
- 前一段时间有发过一个简单的JMAIL邮件发邮件的代码,今天就把这个代码做一个具体的注解,并增加了另外两个格式的代码,并举几个简单
- 一.灰度线性变换图像的灰度线性变换是通过建立灰度映射来调整原始图像的灰度,从而改善图像的质量,凸显图像的细节,提高图像的对比度。灰度线性变换
- 如下所示:import osdef anyTrue(predicate, sequence):return True in map(pred