python深度学习TensorFlow神经网络模型的保存和读取
作者:零尾 发布时间:2022-03-18 06:49:01
之前的笔记里实现了softmax回归分类、简单的含有一个隐层的神经网络、卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西。
TensorFlow提供了一个非常简单的API,即tf.train.Saver
类来保存和还原一个神经网络模型。
下面代码给出了保存TensorFlow模型的方法:
import tensorflow as tf
# 声明两个变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部变量
saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
sess.run(init_op)
print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比
print("v2:", sess.run(v2))
saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件
print("Model saved in file:", saver_path)
注:Saver方法已经发生了更改,现在是V2版本,tf.train.Saver(write_version=tf.train.SaverDef.V1)括号里加入该参数可继续使用V1,但会报warning,可忽略。若使用saver = tf.train.Saver()则默认使用当前的版本(V2),保存后在save这个文件夹中会出现4个文件,比V1版多出model.ckpt.data-00000-of-00001
这个文件,这点感谢评论里那位朋友指出。至于这个文件的含义到目前我仍不是很清楚,也没查到具体资料,TensorFlow15年底开源到现在很多类啊函数都一直发生着变动,或被更新或被弃用,可能一些代码在当时是没问题的,但过了一大段时间后再跑可能就会报错,在此注明事件时间:2017.4.30
这段代码中,通过saver.save
函数将TensorFlow模型保存到了save/model.ckpt文件中,这里代码中指定路径为"save/model.ckpt"
,也就是保存到了当前程序所在文件夹里面的save
文件夹中。
TensorFlow模型会保存在后缀为.ckpt
的文件中。保存后在save这个文件夹中会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。
checkpoint
文件保存了一个目录下所有的模型文件列表,这个文件是tf.train.Saver
类自动生成且自动维护的。在 checkpoint文件中维护了由一个tf.train.Saver类持久化的所有TensorFlow模型文件的文件名。当某个保存的TensorFlow模型文件被删除时,这个模型所对应的文件名也会从checkpoint
文件中删除。checkpoint中内容的格式为CheckpointState Protocol Buffer.
model.ckpt.meta
文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
TensorFlow通过元图(MetaGraph)来记录计算图中节点的信息以及运行计算图中节点所需要的元数据。TensorFlow中元图是由MetaGraphDef Protocol Buffer定义的。MetaGraphDef 中的内容构成了TensorFlow持久化时的第一个文件。保存MetaGraphDef 信息的文件默认以.meta为后缀名,文件model.ckpt.meta中存储的就是元图数据。
model.ckpt
文件保存了TensorFlow程序中每一个变量的取值,这个文件是通过SSTable格式存储的,可以大致理解为就是一个(key,value)列表。model.ckpt文件中列表的第一行描述了文件的元信息,比如在这个文件中存储的变量列表。列表剩下的每一行保存了一个变量的片段,变量片段的信息是通过SavedSlice Protocol Buffer定义的。SavedSlice类型中保存了变量的名称、当前片段的信息以及变量取值。TensorFlow提供了tf.train.NewCheckpointReader
类来查看model.ckpt
文件中保存的变量信息。如何使用tf.train.NewCheckpointReader类这里不做说明,自查。
下面代码给出了加载TensorFlow模型的方法:
可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?
import tensorflow as tf
# 使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型
with tf.Session() as sess:
saver.restore(sess, "save/model.ckpt") # 即将固化到硬盘中的Session从保存路径再读取出来
print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的进行对比
print("v2:", sess.run(v2))
print("Model Restored")
运行结果:
v1: [[ 0.76705766 1.82217288]]
v2: [[-0.98012197 1.2369734 0.5797025 ]
[ 2.50458145 0.81897354 0.07858191]]
Model Restored
这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了TensorFlow计算图上所有的运算,并声明了一个tf.train.Saver
类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。
也就是说使用TensorFlow完成了一次模型的保存和读取的操作。
如果不希望重复定义图上的运算,也可以直接加载已经持久化的图:
import tensorflow as tf
# 在下面的代码中,默认加载了TensorFlow计算图上定义的全部变量
# 直接加载持久化的图
saver = tf.train.import_meta_graph("save/model.ckpt.meta")
with tf.Session() as sess:
saver.restore(sess, "save/model.ckpt")
# 通过张量的名称来获取张量
print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))
运行程序,输出:
[[ 0.76705766 1.82217288]]
有时可能只需要保存或者加载部分变量。
比如,可能有一个之前训练好的5层神经网络模型,但现在想写一个6层的神经网络,那么可以将之前5层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。
为了保存或者加载部分变量,在声明tf.train.Saver
类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train.Saver([v1])
命令来构建tf.train.Saver类,那么只有变量v1会被加载进来。
来源:https://blog.csdn.net/lwplwf/article/details/62419087


猜你喜欢
- SOAP.py 客户机和服务器SOAP.py 包含的是一些基本的东西。没有 Web 服务描述语言(Web Services Descript
- 1.在百度地图申请密钥: http://lbsyun.baidu.com/ 将<script type="tex
- Object.freeze()Object.freeze() 方法可以冻结一个对象。一个被冻结的对象再也不能被修改;冻结了一个对象则不能向这
- 一、安装软件包并创建项目$sudo pip install django$sudo python -c "import djang
- 本文实例讲述了Python smallseg分词用法。分享给大家供大家参考。具体分析如下:#encoding=utf-8 #import p
- 本文的文字及图片来源于网络,仅供学习、交流使用,不具有任何商业用途,版权归原作者所有,如有问题请及时联系我们以作处理。以下文章来源于Pyth
- 1. 实例描述在平时编程的过程中,会经常在网上翻译一些单词,本文使用Python制作一款翻译小工具,不仅可以自己用,还可以嵌入到程序当中。运
- 本文研究的主要是Python编程argparse的相关内容,具体介绍如下。#aaa.py#version 3.5import os &nbs
- 应用 MySQL 时,会遇到不能创建函数的情况。出现如下错误信息:ERROR 1418 : This function has none o
- 一、出错情况 有些时候当你重启了数据库服务,会发现有些数据库变成了正在恢复、置疑、可疑等情况,这个时候DBA就会很紧张了,下面是一些在实践中
- 我使用Pytorch进行模型训练时发现真正模型本身对于显存的占用并不明显,但是对应的转换为tensorflow后(权重也进行了转换),发现P
- asp之家注:有时候我们想让程序运行变慢下来,asp中该怎么做呢?原理很简单就是在运行程序前运行一段无关紧要的程序就可以了,要实现加长程序的
- 1、settings.INSTALLED_APPS下添加:django.contrib.staticfiles2、settings.py下添
- Python 内置的四种常用数据结构:列表(list)、元组(tuple)、字典(dict)以及集合(set)。这四种数据结构一但都可用于保
- 爬取过程:你好,李焕英 短评的URL:https://movie.douban.com/subject/34841067/comments?
- 在python中利用numpy array进行数据处理,经常需要找出符合某些要求的数据位置,有时候还需要对这些位置重新赋值。这里总结了几种找
- 需求:取文件1中的一行,和文件2中所有的数据进行比较,有相同的保存起来,否则删除。#!/usr/bin/perl#use strict;op
- 一、推荐方法 CURL获取<?php$c = curl_init();$url = 'www.jb51.net';cu
- 目录1. 加载保存好的模型2. 使用flask起服务3. 发送请求并得到结果4. 效果呈现1. 加载保存好的模型为了方便起见,这里我们就使用
- 支付宝或者微信支付导出的收款二维码,除了二维码部分,还有很大一块背景图案,例如下面就是微信支付的收款二维码:有时候我们仅仅只想要图片中间的方