python使用tensorflow保存、加载和使用模型的方法
作者:LordofRobots 发布时间:2021-01-25 13:19:26
使用Tensorflow进行深度学习训练的时候,需要对训练好的网络模型和各种参数进行保存,以便在此基础上继续训练或者使用。介绍这方面的博客有很多,我发现写的最好的是这一篇官方英文介绍:
http://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/
我对这篇文章进行了整理和汇总。
首先是模型的保存。直接上代码:
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut1_save.py
#Author: Wang
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 11:04:25
############################
import tensorflow as tf
# prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.Variable(tf.random_normal(shape = [2]), name = 'w1') # name is very important in restoration
w2 = tf.Variable(tf.random_normal(shape = [2]), name = 'w2')
b1 = tf.Variable(2.0, name = 'bias1')
feed_dict = {w1:[10,3], w2:[5,5]}
# define a test operation that will be restored
w3 = tf.add(w1, w2) # without name, w3 will not be stored
w4 = tf.multiply(w3, b1, name = "op_to_restore")
#saver = tf.train.Saver()
saver = tf.train.Saver(max_to_keep = 4, keep_checkpoint_every_n_hours = 1)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print sess.run(w4, feed_dict)
#saver.save(sess, 'my_test_model', global_step = 100)
saver.save(sess, 'my_test_model')
#saver.save(sess, 'my_test_model', global_step = 100, write_meta_graph = False)
需要说明的有以下几点:
1. 创建saver的时候可以指明要存储的tensor,如果不指明,就会全部存下来。在这里也可以指明最大存储数量和checkpoint的记录时间。具体细节看英文博客。
2. saver.save()函数里面可以设定global_step和write_meta_graph,meta存储的是网络结构,只在开始运行程序的时候存储一次即可,后续可以通过设置write_meta_graph = False加以限制。
3. 这个程序执行结束后,会在程序目录下生成四个文件,分别是.meta(存储网络结构)、.data和.index(存储训练好的参数)、checkpoint(记录最新的模型)。
下面是如何加载已经保存的网络模型。这里有两种方法,第一种是saver.restore(sess, 'aaaa.ckpt'),这种方法的本质是读取全部参数,并加载到已经定义好的网络结构上,因此相当于给网络的weights和biases赋值并执行tf.global_variables_initializer()。这种方法的缺点是使用前必须重写网络结构,而且网络结构要和保存的参数完全对上。第二种就比较高端了,直接把网络结构加载进来(.meta),上代码:
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut2_import.py
#Author: Wang
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 14:16:38
############################
import tensorflow as tf
sess = tf.Session()
new_saver = tf.train.import_meta_graph('my_test_model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
print sess.run('w1:0')
使用加载的模型,输入新数据,计算输出,还是直接上代码:
#!/usr/bin/env python
#-*- coding:utf-8 -*-
############################
#File Name: tut3_reuse.py
#Author: Wang
#Mail: wang19920419@hotmail.com
#Created Time:2017-08-30 14:33:35
############################
import tensorflow as tf
sess = tf.Session()
# First, load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
# Second, access and create placeholders variables and create feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name('w1:0')
w2 = graph.get_tensor_by_name('w2:0')
feed_dict = {w1:[-1,1], w2:[4,6]}
# Access the op that want to run
op_to_restore = graph.get_tensor_by_name('op_to_restore:0')
print sess.run(op_to_restore, feed_dict) # ouotput: [6. 14.]
在已经加载的网络后继续加入新的网络层:
import tensorflow as tf
sess=tf.Session()
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}
#Now, access the op that you want to run.
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)
print sess.run(add_on_op,feed_dict)
#This will print 120.
对加载的网络进行局部修改和处理(这个最麻烦,我还没搞太明白,后续会继续补充):
......
......
saver = tf.train.import_meta_graph('vgg.meta')
# Access the graph
graph = tf.get_default_graph()
## Prepare the feed_dict for feeding data for fine-tuning
#Access the appropriate output for fine-tuning
fc7= graph.get_tensor_by_name('fc7:0')
#use this if you only want to change gradients of the last layer
fc7 = tf.stop_gradient(fc7) # It's an identity function
fc7_shape= fc7.get_shape().as_list()
new_outputs=2
weights = tf.Variable(tf.truncated_normal([fc7_shape[3], num_outputs], stddev=0.05))
biases = tf.Variable(tf.constant(0.05, shape=[num_outputs]))
output = tf.matmul(fc7, weights) + biases
pred = tf.nn.softmax(output)
# Now, you run this with fine-tuning data in sess.run()
有了这样的方法,无论是自行训练、加载模型继续训练、使用经典模型还是finetune经典模型抑或是加载网络跑前项,效果都是杠杠的。
来源:http://blog.csdn.net/LordofRobots/article/details/77719020
猜你喜欢
- 可实现功能:1.随机生成一个整数。2.随机生成任意范围内的一个整数。3.随机生成指定长度的整数组4.随机生成指定长度的任意范围的整数组5.随
- 一、用 ftplib 模块连接远程服务器ftplib模块常用方法ftp登陆连接from ftplib import FTP #加
- PHP原型模式Prototype Pattern是什么原型模式是一种创建型模式,它可以通过复制现有对象来创建新的对象,而无需知道具体的创建过
- memcache 的工作就是在专门的机器的内存里维护一张巨大的hash表,来存储经常被读写的一些数组与文件,从而极大的提高网站的运行效率,减
- 目录分析问题音频url搜索urlJS代码实现分析问题音频url点入某个音乐的播放界面,通过F12-Network,分析数据,可以看到有一个i
- 一.背景一道ctf题,通过破解2048游戏获得flag游戏的规则很简单,需要控制所有方块向同一个方向运动,两个相同数字方块撞在一起之后合并成
- 人的大脑通过双眼来辨别视觉图形获取信息。大脑根据储存的经验,将所看到的视觉图形建立起优先级。由此可见,一个良好的视觉设计可以帮助大脑迅速有效
- Fabric 是使用 Python 开发的一个自动化运维和部署项目的一个好工具,可以通过 SSH 的方式与远程服务器进行自动化交互,例如将本
- 1、打开本地企业管理器,先创建一个SQL Server注册来远程连接服务器端口SQL Server。步骤如下图:图1:2、弹出窗口后输入内容
- 对于DBA来说,监控磁盘使用情况是必要的工作,然后没有比较简单的方法能获取到磁盘空间使用率信息,下面总结下这些年攒下的脚本:最常用的查看磁盘
- 最近需要将实验数据画图出来,由于使用python进行实验,自然使用到了matplotlib来作图。下面的代码可以作为画图的模板代码,代码中有
- vue单页开发时经常需要父子组件之间传值,自己用过但是不是很熟练,这里我抽空整理了一下思路,写写自己的总结。GitHub地址:https:/
- QTableWidget介绍QTableWidget是Qt程序中常用的显示数据表格的控件,类似于c#中的DataGrid。QTableWid
- SQL Server中加密是层级的,每一个上层为下提供保护。如图:实例:/** SMK(Service Master Key)在SQL Se
- 一、VS2008工程设置工作 首先,建立一个windows应用程序的工程,将C/C++->预处理器->预处理器定义下的_WIND
- 在JavaScript中,我们应该尽可能的用局部变量来代替全局变量,这句话所有人都知道,可是这句话是谁先说的?为什么要这么做?有什么根据么?
- MySQL由于它本身的小巧和操作的高效, 在数据库应用中越来越多的被采用.我在开发一个P2P应用的时候曾经使用MySQL来保存P2P节点,由
- 本文实例为大家分享了mysql5.6.29的shell脚本,供大家参考,具体内容如下创建脚本mysql.sh,直接运行sh mysql.sh
- 阅读上一篇:javascript面向对象编程(一)[javascript模拟传统OOP]javascript是一种非常灵活的语言,它的灵活度
- 本文实例讲述了python中sleep函数用法。分享给大家供大家参考。具体如下:Python中的sleep用来暂停线程执行,单位为秒#---