tensorflow 保存模型和取出中间权重例子
作者:binqiang2wang 发布时间:2021-05-11 07:30:11
标签:tensorflow,保存模型,权重
下面代码的功能是先训练一个简单的模型,然后保存模型,同时保存到一个pb文件当中,后续可以从pd文件里读取权重值。
import tensorflow as tf
import numpy as np
import os
import h5py
import pickle
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
#设置使用指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#下面这段代码是在训练好之后将所有的权重名字和权重值罗列出来,训练的时候需要注释掉
reader = tf.train.NewCheckpointReader('./model.ckpt-100')
variables = reader.get_variable_to_shape_map()
for ele in variables:
print(ele)
print(reader.get_tensor(ele))
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
isTrain = False#设成True去训练模型
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
if isTrain:
for i in xrange(train_steps):
sess.run(train, feed_dict={x: x_data})
if (i + 1) % checkpoint_steps == 0:
saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
else:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b))
graph_def = tf.get_default_graph().as_graph_def()
#通过修改下面的函数,个人觉得理论上能够实现修改权重,但是很复杂,如果哪位有好办法,欢迎指教
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['Variable'])
with tf.gfile.FastGFile('./test.pb', 'wb') as f:
f.write(output_graph_def.SerializeToString())
with tf.Session() as sess:
#对应最后一部分的写,这里能够将对应的变量取出来
with gfile.FastGFile('./test.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
res = tf.import_graph_def(graph_def, return_elements=['Variable:0'])
print(sess.run(res))
print(sess.run(graph_def))
来源:https://blog.csdn.net/m0_37052320/article/details/79845537
0
投稿
猜你喜欢
- 一个不错的js效果,实现了图片预加载,并实时显示图片加载进度。<script> var l=0; var i
- 1、查找表结构,判断要加入的列是否已存在2、如果不存在,则执行添加 CREATE PROCEDURE `mysql_sp_add_
- 目录1.函数的介绍2.函数的定义和调用3.函数的参数4.参数的分类4.1.位置参数4.2.关键字参数4.3.缺省参数4.4.不定长参数1.不
- <?php// 使用Memache 作为进程锁 class lock_processlock{// key 的前缀protected
- 大家一定使用过 phpmyadmin 里面的数据库导入,导出功能,非常方便。但是在实际应用中,我发现如下几个问题:1、数据库超过一定尺寸,比
- 问题:pycharm无法调用pip安装的包原因:pycharm没有设置解析器解决方法:打开pycharm->File->Sett
- 效果图:css:<style type="text/css"> /* 带复选框的下拉框 */ ul li{
- 编写思路:把本地文件在客户端通过base64编码以后发送目的地.测试过程中,上传文件过大,导致超时不成功,后来经过改善.把编码分
- 科学设计你的网站网页:来自 Eye-Tracking研究的23节必修课 ——Christina Laun在网络设计领域关于Eye-
- Python 是支持面向对象的,很多情况下使用面向对象编程会使得代码更加容易扩展,并且可维护性更高,但是如果你写的多了或者某一对象非常复杂了
- 当我们使用访问一个没有声明的变量时,JS会报错;而当我们给一个没有声明的变量赋值时,JS不会报错,相反它会认为我们是要隐式申明一个全局变量。
- 注意: 在搭建网络的时候用carpool2D的时候,让高度和宽度方向不同池化时,用如下:nn.MaxPool2d(kernel_size=2
- Access数据库,同时操作大量记录(9500条以上)时报错。错误提示:Microsoft JET Database Engine 错误 &
- 如代码1所示: // 代码 1 // 外观层类 class LWordHomePage { // 添加留言 public function
- 要真说出来哪一个函数能够做得到,还真难。但我们可用下面的代码来进行识别,返回“假”即偶数,返回“真”则奇数: function&n
- 前言数据驱动是一种思想,让数据和代码进行分离,比如爬虫时,我们需要分页爬取数据时,我们往往把页数 page 参数化,放在 for 循环 ra
- 方法1:1.安装requests_toolbelt依赖库#代码实现def upload(self): login_
- 这篇文章主要介绍了python批量启动多线程代码实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友
- 一、介绍实现的是把某个文件夹下的所有文件名提取出来,放入一个列表,在与excel中的某列进行对比,如果一致的话,对另一列进行操作,比如我们在
- 关于在asp中不使用组件使得脚本sleep的办法还比较少见,可能比较好的办法是创建同步的xmlhttp request,直到获得的时间达到某