TensorFlow 模型载入方法汇总(小结)
作者:叠加态的猫 发布时间:2022-11-09 00:05:42
一、TensorFlow常规模型加载方法
保存模型
tf.train.Saver()类,.save(sess, ckpt文件目录)方法
参数名称 | 功能说明 | 默认值 |
var_list | Saver中存储变量集合 | 全局变量集合 |
reshape | 加载时是否恢复变量形状 | True |
sharded | 是否将变量轮循放在所有设备上 | True |
max_to_keep | 保留最近检查点个数 | 5 |
restore_sequentially | 是否按顺序恢复变量,模型较大时顺序恢复内存消耗小 | True |
var_list是字典形式{变量名字符串: 变量符号},相对应的restore也根据同样形式的字典将ckpt中的字符串对应的变量加载给程序中的符号。
如果Saver给定了字典作为加载方式,则按照字典来,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否则每个变量寻找自己的name属性在ckpt中的对应值进行加载。
加载模型
当我们基于checkpoint文件(ckpt)加载参数时,实际上我们使用Saver.restore取代了initializer的初始化
checkpoint文件会记录保存信息,通过它可以定位最新保存的模型:
ckpt = tf.train.get_checkpoint_state('./model/')
print(ckpt.model_checkpoint_path)
![]()
.meta文件保存了当前图结构
.index文件保存了当前参数名
.data文件保存了当前参数值
tf.train.import_meta_graph函数给出model.ckpt-n.meta的路径后会加载图结构,并返回saver对象
ckpt = tf.train.get_checkpoint_state('./model/')
tf.train.Saver函数会返回加载默认图的saver对象,saver对象初始化时可以指定变量映射方式,根据名字映射变量(『TensorFlow』滑动平均)
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
saver.restore函数给出model.ckpt-n的路径后会自动寻找参数名-值文件进行加载
saver.restore(sess,'./model/model.ckpt-0')
saver.restore(sess,ckpt.model_checkpoint_path)
1.不加载图结构,只加载参数
由于实际上我们参数保存的都是Variable变量的值,所以其他的参数值(例如batch_size)等,我们在restore时可能希望修改,但是图结构在train时一般就已经确定了,所以我们可以使用tf.Graph().as_default()新建一个默认图(建议使用上下文环境),利用这个新图修改和变量无关的参值大小,从而达到目的。
'''
使用原网络保存的模型加载到自己重新定义的图上
可以使用python变量名加载模型,也可以使用节点名
'''
import AlexNet as Net
import AlexNet_train as train
import random
import tensorflow as tf
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
with tf.Graph().as_default() as g:
x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3])
y = Net.inference_1(x, N_CLASS=5, train=False)
with tf.Session() as sess:
# 程序前面得有 Variable 供 save or restore 才不报错
# 否则会提示没有可保存的变量
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('./model/')
img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
img = sess.run(tf.expand_dims(tf.image.resize_images(
tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0))
if ckpt and ckpt.model_checkpoint_path:
print(ckpt.model_checkpoint_path)
saver.restore(sess,'./model/model.ckpt-0')
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
res = sess.run(y, feed_dict={x: img})
print(global_step,sess.run(tf.argmax(res,1)))
2.加载图结构和参数
'''
直接使用使用保存好的图
无需加载python定义的结构,直接使用节点名称加载模型
由于节点形状已经定下来了,所以有不便之处,placeholder定义batch后单张传会报错
现阶段不推荐使用,以后如果理解深入了可能会找到使用方法
'''
import AlexNet_train as train
import random
import tensorflow as tf
IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg'
ckpt = tf.train.get_checkpoint_state('./model/') # 通过检查点文件锁定最新的模型
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta') # 载入图结构,保存在.meta文件中
with tf.Session() as sess:
saver.restore(sess,ckpt.model_checkpoint_path) # 载入参数,参数保存在两个文件中,不过restore会自己寻找
img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read()
img = sess.run(tf.image.resize_images(
tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)))
imgs = []
for i in range(128):
imgs.append(img)
print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs}))
'''
img = sess.run(tf.expand_dims(tf.image.resize_images(
tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0))
print(img)
imgs = []
for i in range(128):
imgs.append(img)
print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'),
feed_dict={'Placeholder:0':img}))
注意,在所有两种方式中都可以通过调用节点名称使用节点输出张量,节点.name属性返回节点名称。
3.简化版本
# 连同图结构一同加载
ckpt = tf.train.get_checkpoint_state('./model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta')
with tf.Session() as sess:
saver.restore(sess,ckpt.model_checkpoint_path)
# 只加载数据,不加载图结构,可以在新图中改变batch_size等的值
# 不过需要注意,Saver对象实例化之前需要定义好新的图结构,否则会报错
saver = tf.train.Saver()
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('./model/')
saver.restore(sess,ckpt.model_checkpoint_path)
二、TensorFlow二进制模型加载方法
这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作
# 新建空白图
self.graph = tf.Graph()
# 空白图列为默认图
with self.graph.as_default():
# 二进制读取模型文件
with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f:
# 新建GraphDef文件,用于临时载入模型中的图
graph_def = tf.GraphDef()
# GraphDef加载模型中的图
graph_def.ParseFromString(f.read())
# 在空白图中加载GraphDef中的图
tf.import_graph_def(graph_def,name='')
# 在图中获取张量需要使用graph.get_tensor_by_name加张量名
# 这里的张量可以直接用于session的run方法求值了
# 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name)
self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]
来源:https://www.cnblogs.com/hellcat/p/6925757.html


猜你喜欢
- 之前安装mysql时未做总结,换新电脑,补上安装记录,安装的时候,找了些网友的安装记录,发现好多坑1、mysql-5.7.12-winx64
- 一、mysqlcheck简介mysqlcheck客户端可以检查和修复MyISAM表。它还可以优化和分析表。mysqlcheck的功能类似my
- 异常可以防止出现一些不友好的信息返回给用户,有助于提升程序的可用性,在java中通过try ... catch ... finally来处理
- 简介🤔看了一圈,大家对 ts 封装 axios 都各有见解。但都不是我满意的吧,所以自己封装了一个💪。至于为什么敢叫最佳实践,因为我满意,就
- 一、前言作为一个数据库爱好者,自己动手写过简单的SQL解析器以及存储引擎,但感觉还是不够过瘾。<<事务处理-概念与技术>&
- 问题描述MySQL函数或者存储过程中使用group_concat()函数导致数据字符过长而报错CREATE DEFINER=`root`@`
- 前几天小芳同学一直在群发起一些加速的话题,我已经把聊天记录抽出来,正打算整理出份像样的,没想到小芳同学非常速度的出了这篇。我的就省掉了,挖哈
- 1.函数对象前面我们学习了关于Python中的变量类型,例如int、str、bool、list等等…&hell
- 脉冲星假信号频率的相对路径论证。首先看一下演示结果:实例代码:import numpy as npimport matplotlib.pyp
- 一、为什么要搭建爬虫代理池在众多的网站防爬措施中,有一种是根据ip的访问频率进行限制,即在某一时间段内,当某个ip的访问次数达到一定的阀值时
- 一:分组函数的语句顺序 1 SELECT ... 2 FROM ...
- 本文实例讲述了php测试kafka项目。分享给大家供大家参考,具体如下:概述Kafka是最初由Linkedin公司开发,是一个分布式、分区的
- 一、数据备份1、使用mysqldump命令备份mysqldump命令将数据库中的数据备份成一个文本文件。表的结构和表中的数据将存储在生成的文
- argparse 是python自带的命令行参数解析包,可以用来方便地读取命令行参数。一、传入一个参数import argpars
- 在使用Python做开发的时候,时不时会给自己编写了一些小工具辅助自己的工作,但是由于开发依赖环境问题,多数只能在自己电脑上运行,拿到其它电
- 之前用小程序做项目,因为后台使用的java开发,一切顺利,但切换成django做RESTful API接口时,在登陆注册时一直出现问题,网上
- 有时候在网上办理一些业务时有些需要填写银行卡号码,当胡乱填写时会立即报错,但是并没有发现向后端发送请求,那么这个效果是怎么实现的呢。对于银行
- 前言使用的pyecharts是v1.0这里需要注意,pyecharts0.5的版本和v1.0以上的版本完全不一样,可以说是两个包该包能够方便
- 字符串在 Python 中创建字符串对象非常容易。只要将所需的文本放入一对引号中,就完成了一个新字符串的创建(参见清单 1)。如果稍加思考的
- 前言地图定位这个功能大家都很熟悉吧,那微信小程序中要怎么实现地图定位呢,其实非常简单,没有大家想象中那么难,看完本篇文章,你也可以轻松实现这