Tensorflow之Saver的用法详解
作者:陈浅墨 发布时间:2023-10-01 22:40:07
Saver的用法
1. Saver的背景介绍
我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。
Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。
只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。
2. Saver的实例
下面以一个例子来讲述如何使用Saver类
import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
if isTrain:
for i in xrange(train_steps):
sess.run(train, feed_dict={x: x_data})
if (i + 1) % checkpoint_steps == 0:
saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
else:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b))
isTrain:用来区分训练阶段和测试阶段,True表示训练,False表示测试
train_steps:表示训练的次数,例子中使用100
checkpoint_steps:表示训练多少次保存一下checkpoints,例子中使用50
checkpoint_dir:表示checkpoints文件的保存路径,例子中使用当前路径
2.1 训练阶段
使用Saver.save()方法保存模型:
sess:表示当前会话,当前会话记录了当前的变量值
checkpoint_dir + 'model.ckpt':表示存储的文件名
global_step:表示当前是第几步
训练完成后,当前目录底下会多出5个文件。
打开名为“checkpoint”的文件,可以看到保存记录,和最新的模型存储位置。
2.1测试阶段
测试阶段使用saver.restore()方法恢复变量:
sess:表示当前会话,之前保存的结果将被加载入这个会话
ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。
运行结果如下图所示,加载了之前训练的参数w和b的结果
来源:https://blog.csdn.net/u011500062/article/details/51728830


猜你喜欢
- 多进程共享变量和获得结果由于工程需求,要使用多线程来跑一个程序。但是因为听说python的多线程是假的,于是使用多进程,反正任务需要共享的参
- 多个值合并展示现在我们有如图一到图二的需求怎么做?如下sql:SELECT id,GROUP_CONCAT(DISTINCT str) as
- 翻译整理:Young.J;官方网站:http://jquery.comjQuery是一款同prototype一样优秀js开发库类,特别是对c
- 前言本文介绍在 pandas 中如何读取数据行列的方法。数据由行和列组成,在数据库中,一般行被称作记录 (record),列被称作字段 (f
- 本文实例讲述了Centos7.4环境安装lamp-php7.0的方法。分享给大家供大家参考,具体如下:一. 环境准备桥接模式能访问外网#pi
- 正则表达式是一个特殊的字符序列,它能帮助你方便的检查一个字符串是否与某种模式匹配。Python 自1.5版本起增加了re 模块,它提供 Pe
- 在实际的项目中,我们一般都会建立三个环境:开发、测试和生产环境,这三种环境会使用不同的配置组合,为了能方便地切换配置,我们可以为不同的环境创
- 笔者通过一周的时间,询问了许多设计人员真实用户,以便确保这六个方面确实是大多数用户所不喜并且有非常大的概率普遍存在于众多的医疗网站之中。那么
- 方式一:数据库用的是SQL 2008,数据表中存放的是图片的二进制数据,现在把图片以一种图片格式(如.jpg)导出,然后存放于指定的文件夹中
- 在 Python 中一切都可以看作为对象。每个对象都有各自的 id, type 和 value。id: 当一个对象被创建后,它的 id 就不
- 前言由于总是出错,记录一下连接MySQL数据库的过程。连接过程1.下载MySQL并安装,这里的版本是8.0.182.下载MySQL的jdbc
- 库的管理1、库的管理创建、修改、删除1、库的创建CREATE DATABASE UF NOT EXISTS books;2、库的修改库名一般
- Vue动态创建组件实例并挂载到body方式一import Vue from 'vue'/** * @param Compon
- 1、先说结论:使用xml-rpc的机制可以很方便的实现服务器间的RPC调用。2、试验结果如下:3、源码如下:服务器端的源代码如下:impor
- 前言我们都知道时区,标准时区是UTC时区,django默认使用的就是UTC时区,所以我们存储在数据库中的时间是UTC的时间,但是当我们做的网
- 本文实例讲述了Python中绑定与未绑定的类方法。分享给大家供大家参考,具体如下:像函数一样,Python中的类方法也是一种对象。由于既可以
- 一、报错信息:【file】【Default Settint】---Project Interpreter 点击搜索suds安装模块报错解决:
- 首先对图片进行预处理,是图片的分配率大小在合适的范围内,避免图片太大占满整个电脑屏幕。from PIL import Imagedef pr
- 一、使用等号查询可以像普通查询使用等号进行查询,但必须查询时间必须和字段对应时间完全相等,比如我要查下面这个值sql如下:SELECT id
- 最近想把word密码文件的服务器密码信息归档到mysql数据库,心想着如果直接在里面写明文密码会不会不安全,如果用sha这些不可逆的算法又没