TensorFlow入门使用 tf.train.Saver()保存模型
作者:永永夜 发布时间:2023-06-07 15:29:45
关于模型保存的一点心得
saver = tf.train.Saver(max_to_keep=3)
在定义 saver 的时候一般会定义最多保存模型的数量,一般来说,如果模型本身很大,我们需要考虑到硬盘大小。如果你需要在当前训练好的模型的基础上进行 fine-tune,那么尽可能多的保存模型,后继 fine-tune 不一定从最好的 ckpt 进行,因为有可能一下子就过拟合了。但是如果保存太多,硬盘也有压力呀。如果只想保留最好的模型,方法就是每次迭代到一定步数就在验证集上计算一次 accuracy 或者 f1 值,如果本次结果比上次好才保存新的模型,否则没必要保存。
如果你想用不同 epoch 保存下来的模型进行融合的话,3到5 个模型已经足够了,假设这各融合的模型成为 M,而最好的一个单模型称为 m_best, 这样融合的话对于M 确实可以比 m_best 更好。但是如果拿这个模型和其他结构的模型再做融合的话,M 的效果并没有 m_best 好,因为M 相当于做了平均操作,减少了该模型的“特性”。
但是又有一种新的融合方式,就是利用调整学习率来获取多个局部最优点,就是当 loss 降不下了,保存一个 ckpt, 然后开大学习率继续寻找下一个局部最优点,然后用这些 ckpt 来做融合,还没试过,单模型肯定是有提高的,就是不知道还会不会出现上面再与其他模型融合就没提高的情况。
如何使用 tf.train.Saver() 来保存模型
之前一直出错,主要是因为坑爹的编码问题。所以要注意文件的路径绝对不不要出现什么中文呀。
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)
Model saved in file: ./ckpt/test-model.ckpt-1
注意,在上面保存完了模型之后。应该把 kernel restart 之后才能使用下面的模型导入。否则会因为两次命名 “v1” 而导致名字错误。
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")
print sess.run(v1)
print sess.run(v2)
INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1. 2.29999995]
55.5
导入模型之前,必须重新再定义一遍变量。
但是并不需要全部变量都重新进行定义,只定义我们需要的变量就行了。
也就是说,你所定义的变量一定要在 checkpoint 中存在;但不是所有在checkpoint中的变量,你都要重新定义。
import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")
print sess.run(v1)
INFO:tensorflow:Restoring parameters from ./ckpt/test-model.ckpt-1
Model restored.
[ 1. 2.29999995]
tf.Saver([tensors_to_be_saved]) 中可以传入一个 list,把要保存的 tensors 传入,如果没有给定这个list的话,他会默认保存当前所有的 tensors。一般来说,tf.Saver 可以和 tf.variable_scope() 巧妙搭配,可以参考: 【迁移学习】往一个已经保存好的模型添加新的变量并进行微调
来源:https://blog.csdn.net/Jerr__y/article/details/78594494


猜你喜欢
- docx2txt的Github地址docx2txt是基于python的从docx文件中提取文本和图片的库。代码是从python-docx中获
- APSchedulerAPScheduler 四个组件分别为:调度器(scheduler)、触发器(trigger),作业存储(job st
- 之前有看过一个博文写的是白社会的设计很好但运营却有些遭,因为对每一个WebGame的推出时间把握不准,会有几个应用同时上线造成影响力的冲突,
- fileno()方法返回所使用的底层实现,要求从操作系统I/O操作的整数文件描述符。语法以下是fileno()方法的语法:fil
- 本文实例讲述了python写日志文件操作类与应用。分享给大家供大家参考,具体如下:项目的开发过程中,日志文件是少不了的,通过写日志文件,可以
- (一)RabbitMQ的简介RabbitMq 是实现了高级消息队列协议(AMQP)的开源消息代理中间件。消息队列是一种应用程序对应用程序的通
- 最近跟着OpenCV2-Python-Tutorials在学习python_opencv中直方图的反向投影时,第一种方法是使用numpy实现
- 常用快捷键全部快捷键1、编辑(Editing)2、查找/替换(Search/Replace)3、运行(Running)4、调试(Debugg
- 在 Go 里面的协程执行实际上默认是没有严格的先后顺序的。由于 Go 语言 GPM 模型的设计理念,真正执行实际工作的实际上是 GPM 中的
- 语法格式:row_number() over(partition by 分组列 order by 排序列 desc)row_num
- 区别与联系: 1、get是从服务器上获取数据,post则是向服务器传送数据; 2、get将表单中数据的按照variable=value的 形
- 最近在学习python著名的绘图包matplotlib时发现,有时候图例等设置无法正常显示中文,于是就想把这个问题解决了。PS:本文仅针对W
- import模块时有错误红线的解决 前情提要概念:在一个文件中代码越长越不容易维护,为了编写可维护的代码,我们把很多函数分组,分别
- 用过mac的朋友都反映很好用,不仅美观,性能好,关键是他的系统底层对于开发人员来说,无疑就是一个最大的好处,用习惯linux的人就知道mac
- list/tuple转置:以二维grid[][]为例:grid = [[row[i] for row in grid] for i in r
- 线程实现Python中线程有两种方式:函数或者用类来包装线程对象。threading模块中包含了丰富的多线程支持功能:threading.c
- 需求:项目中需要把链接地址生成二维码,用户扫描二维码就可以打开页面实现如下:使用了vue-qriously插件使用步骤:安装npm inst
- 引言通常,您可能希望在 Pandas DataFrame 中插入一个新列。幸运的是,使用 pandas insert()函数很容易做到这一点
- Django的View一个视图函数(类),简称视图,是一个简单的Python 函数(类),它接受Web请求并且返回Web响应。响应可以是一张
- 序列化是将对象状态转换为可保持或传输的格式的过程。与序列化相对的是反序列化,它将流转换为对象。这两个过程结合起来,可以轻松地存储和传输数据方