tensorflow如何继续训练之前保存的模型实例
作者:by_side_with_sun 发布时间:2023-05-22 22:54:57
标签:tensorflow,训练,模型
一:需重定义神经网络继续训练的方法
1.训练代码
import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
loss=tf.reduce_mean(tf.square(y-y_data)) #loss
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
sess.run(train)
saver.save(sess,"./save_mode",global_step=step) #保存
print("当前进行:",step)
第一次训练截图:
2.恢复上一次的训练
import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
y=weight*x_data+biases
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
sess.run(train)
saver.save(sess,r"./save_new_mode",global_step=step)
print("当前进行:",step," ",sess.run(weight),sess.run(biases))
使用上次保存下的数据进行继续训练和保存:
#最后要提一下的是:
checkpoint文件
meta保存了TensorFlow计算图的结构信息
datat保存每个变量的取值
index保存了 表
加载restore时的文件路径名是以checkpoint文件中的“model_checkpoint_path”值决定的
这个方法需要重新定义神经网络
二:不需要重新定义神经网络的方法:
在上面训练的代码中加入:tf.add_to_collection("name",参数)
import numpy as np
import tensorflow as tf
x_data=np.random.rand(100).astype(np.float32)
y_data=x_data*0.1+0.3
weight=tf.Variable(tf.random_uniform([1],-1.0,1.0),name="w")
biases=tf.Variable(tf.zeros([1]),name="b")
y=weight*x_data+biases
loss=tf.reduce_mean(tf.square(y-y_data))
optimizer=tf.train.GradientDescentOptimizer(0.5)
train=optimizer.minimize(loss)
tf.add_to_collection("new_way",train)
init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
sess.run(train)
saver.save(sess,"./save_mode",global_step=step)
print("当前进行:",step)
在下面的载入代码中加入:tf.get_collection("name"),就可以直接使用了
import numpy as np
import tensorflow as tf
sess=tf.Session()
saver=tf.train.import_meta_graph(r'save_mode-9.meta')
saver.restore(sess,tf.train.latest_checkpoint(r'./'))
print(sess.run("w:0"),sess.run("b:0"))
graph=tf.get_default_graph()
weight=graph.get_tensor_by_name("w:0")
biases=graph.get_tensor_by_name("b:0")
y=tf.get_collection("new_way")[0]
saver=tf.train.Saver(max_to_keep=0)
for step in range(10):
sess.run(y)
saver.save(sess,r"./save_new_mode",global_step=step)
print("当前进行:",step," ",sess.run(weight),sess.run(biases))
总的来说,下面这种方法好像是要便利一些
来源:https://blog.csdn.net/by_side_with_sun/article/details/79829619


猜你喜欢
- 在使用Golang的时候,不免会使用Json和结构体的相互转换,这时候常用的就是 json.Marshal 和 json.Unmarshal
- 目录问题描述大致的功能效果有如下思路分析完整代码总结问题描述teambition软件是企业办公协同软件,相信部分朋友的公司应该用过这款软件。
- 一、网络知识的一些介绍 socket 是网络连接端点。例如当你的Web浏览器请求www.jb51.net上的主页时,你的Web浏览器创建一个
- 如下所示:import matplotlib.pyplot as pltimport numpy as npa = np.array([1,
- 代码如下import pandas as pdimport matplotlib.pyplot as pltimport numpy as
- 远程连接SQL Server 2008,服务器端和客户端配置关键设置:第一步(SQL2005、SQL2008):开始-->程序--&g
- Python下一切皆对象,每个对象都有多个属性(attribute),Python对属性有一套统一的管理方案。__dict__与dir()的
- <input type=button value=刷新 onclick="window.location.reload()&
- code:f = open('yesterday','r',encoding='utf-8'
- 一、说在前面 需求:有一张长为960,宽为96的图片,需要将其分割成10张96*96的图
- 春节休息了几天,今天上班第一天,最近混twitter混得比较多,经常要压缩URL,以前做了个书签用http://is.gd/压缩,后来发现了
- Delphi连接MySQL真麻烦,研究了一天,从网上找了无数文章,下载了无数插件都没解决。最后返璞归真,老老实实用ADO来连接,发现也不是很
- @using@using 指令用于向生成的视图添加 C# using 指令:@using System.IO@{
- 这几天正在追剧,原名《大秦帝国之天下》的《大秦赋》,看着看着又想把前几部刷一遍了,但第一部《裂变》自己没有高清资源,搜了一波发现yout
- python中for循环用于针对集合中的每个元素的一个代码块,而while循环能实现满足条件下的不断运行。使用while循环时,由于whil
- 有一个需求是要在一个云监控的状态值中存储多个状态(包括可同时存在的各种异常、警告状态)使用了位运算机制在一个int型中存储。现在监控日志数据
- 故障描述percona5.6,mysqldump全备份,导入备份数据时报错Duplicate entry 'hoc_log99-it
- 支持多种编码的中文字符串截取函数! /* * @todo&
- 场景:mysql统计一个数据库里所有表的数据量,最近在做统计想查找一个数据库里基本所有的表数据量,数据量少的通过select count再加
- 目录前言示例文件文件编码空值日期错误函数映射方法1:直接使用labmda表达式方法二:使用自定义函数方法三:使用数值字典映射总结前言本文是给