python深度学习TensorFlow神经网络模型的保存和读取
作者:零尾 发布时间:2022-03-18 06:49:01
之前的笔记里实现了softmax回归分类、简单的含有一个隐层的神经网络、卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西。
TensorFlow提供了一个非常简单的API,即tf.train.Saver
类来保存和还原一个神经网络模型。
下面代码给出了保存TensorFlow模型的方法:
import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
sess.run(init_op)
print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
print("v2:", sess.run(v2))
saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
print("Model saved in file:", saver_path)
注:Saver方法已经发生了更改,现在是V2版本,tf.train.Saver(write_version=tf.train.SaverDef.V1)括号里加入该参数可继续使用V1,但会报warning,可忽略。若使用saver = tf.train.Saver()则默认使用当前的版本(V2),保存后在save这个文件夹中会出现4个文件,比V1版多出model.ckpt.data-00000-of-00001
这个文件,这点感谢评论里那位朋友指出。至于这个文件的含义到目前我仍不是很清楚,也没查到具体资料,TensorFlow15年底开源到现在很多类啊函数都一直发生着变动,或被更新或被弃用,可能一些代码在当时是没问题的,但过了一大段时间后再跑可能就会报错,在此注明事件时间:2017.4.30
这段代码中,通过saver.save
函数将TensorFlow模型保存到了save/model.ckpt文件中,这里代码中指定路径为"save/model.ckpt"
,也就是保存到了当前程序所在文件夹里面的save
文件夹中。
TensorFlow模型会保存在后缀为.ckpt
的文件中。保存后在save这个文件夹中会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。
checkpoint
文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver
类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint
文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.
model.ckpt.meta
文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。
model.ckpt
文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader
类来查看model.ckpt
文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,自查。
下面代码给出了加载TensorFlow模型的方法:
可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?
import tensorflow as tf
# 使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
saver.restore(sess, "save/model.ckpt") # 即将固化到硬盘中的Session从保存路径再读取出来
print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的进行对比
print("v2:", sess.run(v2))
print("Model Restored")
运行结果:
v1: [[ 0.76705766 1.82217288]]
v2: [[-0.98012197 1.2369734 0.5797025 ]
[ 2.50458145 0.81897354 0.07858191]]
Model Restored
这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了TensorFlow计算图上所有的运算,并声明了一个tf.train.Saver
类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。
也就是说使用TensorFlow完成了一次模型的保存和读取的操作。
如果不希望重复定义图上的运算,也可以直接加载已经持久化的图:
import tensorflow as tf
# 在下面的代码中,默认加载了TensorFlow计算图上定义的全部变量
# 直接加载持久化的图
saver = tf.train.import_meta_graph("save/model.ckpt.meta")
with tf.Session() as sess:
saver.restore(sess, "save/model.ckpt")
# 通过张量的名称来获取张量
print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))
运行程序,输出:
[[ 0.76705766 1.82217288]]
有时可能只需要保存或者加载部分变量。
比如,可能有一个之前训练好的5层神经网络模型,但现在想写一个6层的神经网络,那么可以将之前5层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。
为了保存或者加载部分变量,在声明tf.train.Saver
类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train.Saver([v1])
命令来构建tf.train.Saver类,那么只有变量v1会被加载进来。
来源:https://blog.csdn.net/lwplwf/article/details/62419087
![](https://www.aspxhome.com/images/zang.png)
![](https://www.aspxhome.com/images/jiucuo.png)
猜你喜欢
- 桥接模式(Bridge Pattern)是什么桥接模式是一种结构型模式,它将抽象部分与实现部分分离开来,使它们可以独立地变化。在桥接模式中,
- 我就废话不多说,直接上代码吧!from PIL import ImageGrabimport timeimport scheduleimpo
- 主程序mainaddfunc.pyfrom flask import Flask, render_template, request, ur
- 本文介绍了使用python wasmtime来访问rust库的便捷方法,步骤极其简练,可以在生产环境中使用。安装rust target wa
- 简介CSS Sprites并没有一个确定的中文翻译,通常被意译为“CSS图像拼合”或“CSS贴图定位”。CSS Sprites并不是一门新技
- 基于Python2.7的版本环境,Python实现的数据库跨服务器(跨库)迁移, 每以5000条一查询一提交,代码中可以自行更改
- python中,A object = B object 是一种赋值操作,赋的值不是一个对象在内存中的空间,而只是这个
- 检测是否注册成功<% Set Jpeg =Server.CreateObject("Persi
- 本文实例为大家分享了Django1.11自带分页器Django的具体使用方法,供大家参考,具体内容如下接下来我编写一个 views ,名cl
- python提高图像质量概述调研了一些提高图像质量的方式深度学习方法,如微软的Bringing-Old-Photos-Back-to-Lif
- 流动网页设计有很多好处,但也只有在正确使用的时候。合适的技巧会使页面在大屏幕、小屏幕抑、PDA小屏幕上都能得到良好的呈现。但是,糟糕的代码结
- 在数组中搜索一个特定值,如果找到返回TRUE否则返回FALSE boolean in_array(mixed needle,array ha
- 超酷的js图片轮换/轮播 渐变效果··来自腾讯刚刚在腾讯女性频道上看到一个很酷的图片渐变轮换效果·····于是乎····抠下来了···分享·
- 前言:随着编程语言的发展,Go 还很年轻。它于 2009 年 11 月 10 日首次发布。其创建者Robert Griesemer Rob
- 从CNNIC在2009年的报告中可以看到,超过80%的网民购物之前都要看评论(包括本站、其他站评论),超过80%的网民都比较信任口碑(包括网
- firefox不支持text-overflow一直让人很折腾。。不过还好有大虾为我们提供解决方案。。text-overflow: ellip
- 浏览带有下拉菜单的网页时,我们经常会注意到当更改显示器分辨率时,其下拉菜单的位置并没有改变,这也是我们设计网页时容易忽略的一个问题,其实通过
- 写了个JavaScript版的DateAdd、DateDiff、IsDate函数,大家评评!需要说明的是,JavaScript中IsDate
- 这里用Python逼近函数y = exp(x);同样使用泰勒函数去逼近:exp(x) = 1 + x + (x)^2/(2!) + .. +
- UCD介绍UCD是Unicode字符数据库(Unicode Character DataBase)的缩写。UCD由一些描述Unicode字符