tensorflow模型继续训练 fineturn实例
作者:-牧野- 发布时间:2023-07-10 12:53:09
标签:tensorflow,模型,训练,fineturn
解决tensoflow如何在已训练模型上继续训练fineturn的问题。
训练代码
任务描述: x = 3.0, y = 100.0, 运算公式 x×W+b = y,求 W和b的最优解。
# -*- coding: utf-8 -*-)
import tensorflow as tf
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
# 声明变量
W = tf.Variable(tf.zeros([1, 1]),name='w')
b = tf.Variable(tf.zeros([1]),name='b')
# 操作
result = tf.matmul(x, W) + b
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
with tf.Session() as sess:
# 初始化变量
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(max_to_keep=3)
# 这里x、y给固定的值
x_s = [[3.0]]
y_s = [[100.0]]
step = 0
while (True):
step += 1
feed = {x: x_s, y: y_s}
# 通过sess.run执行优化
sess.run(train_step, feed_dict=feed)
if step % 1000 == 0:
print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
print ''
# print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
print 'final result of {0} = {1}(目标值是100.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
print ''
print("模型保存的W值 : %f" % sess.run(W))
print("模型保存的b : %f" % sess.run(b))
break
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型
训练完成之后生成模型文件:
训练输出:
step: 1000, loss: 4.89526428282e-08
step: 2000, loss: 4.89526428282e-08
step: 3000, loss: 4.89526428282e-08
step: 4000, loss: 4.89526428282e-08
step: 5000, loss: 4.89526428282e-08
final result of x×W+b = [[99.99978]](目标值是100.0)
模型保存的W值 : 29.999931
模型保存的b : 9.999982
保存在模型中的W值是 29.999931,b是 9.999982。
以下代码从保存的模型中恢复出训练状态,继续训练
任务描述: x = 3.0, y = 200.0, 运算公式 x×W+b = y,从上次训练的模型中恢复出训练参数,继续训练,求 W和b的最优解。
# -*- coding: utf-8 -*-)
import tensorflow as tf
# 声明占位变量x、y
x = tf.placeholder("float", shape=[None, 1])
y = tf.placeholder("float", [None, 1])
with tf.Session() as sess:
# 初始化变量
sess.run(tf.global_variables_initializer())
# saver = tf.train.Saver(max_to_keep=3)
saver = tf.train.import_meta_graph(r'./save_model/re-train-5000.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")
print("从保存的模型中恢复出来的W值 : %f" % sess.run("w:0"))
print("从保存的模型中恢复出来的b值 : %f" % sess.run("b:0"))
# 操作
result = tf.matmul(x, W) + b
# 损失函数
lost = tf.reduce_sum(tf.pow((result - y), 2))
# 优化
train_step = tf.train.GradientDescentOptimizer(0.0007).minimize(lost)
# 这里x、y给固定的值
x_s = [[3.0]]
y_s = [[200.0]]
step = 0
while (True):
step += 1
feed = {x: x_s, y: y_s}
# 通过sess.run执行优化
sess.run(train_step, feed_dict=feed)
if step % 1000 == 0:
print 'step: {0}, loss: {1}'.format(step, sess.run(lost, feed_dict=feed))
if sess.run(lost, feed_dict=feed) < 1e-10 or step > 4e3:
print ''
# print 'final loss is: {}'.format(sess.run(lost, feed_dict=feed))
print 'final result of {0} = {1}(目标值是200.0)'.format('x×W+b', 3.0 * sess.run(W) + sess.run(b))
print("模型保存的W值 : %f" % sess.run(W))
print("模型保存的b : %f" % sess.run(b))
break
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型
训练输出:
从保存的模型中恢复出来的W值 : 29.999931
从保存的模型中恢复出来的b值 : 9.999982
step: 1000, loss: 1.95810571313e-07
step: 2000, loss: 1.95810571313e-07
step: 3000, loss: 1.95810571313e-07
step: 4000, loss: 1.95810571313e-07
step: 5000, loss: 1.95810571313e-07
final result of x×W+b = [[199.99956]](目标值是200.0)
模型保存的W值 : 59.999866
模型保存的b : 19.999958
从保存的模型中恢复出来的W值是 29.999931,b是 9.999982,跟模型保存的值一致,说明加载成功。
总结
从头开始训练一个模型,需要通过 tf.train.Saver创建一个保存器,完成之后使用save方法保存模型到本地:
saver = tf.train.Saver(max_to_keep=3)
……
saver.save(sess, "./save_model/re-train", global_step=step) # 保存模型
在训练好的模型上继续训练,fineturn一个模型,可以使用tf.train.import_meta_graph方法加载图结构,使用restore方法恢复训练数据,最后使用同样的save方法保存到本地:
saver = tf.train.import_meta_graph(r'./save_model/re-train-10050.meta') # 加载模型图结构
saver.restore(sess, tf.train.latest_checkpoint(r'./save_model')) # 恢复数据
saver.save(sess, "./save_mode/re-train", global_step=step) # 保存模型
注:特殊情况下(如本例)需要从恢复的模型中加载出数据:
# 从保存模型中恢复变量
graph = tf.get_default_graph()
W = graph.get_tensor_by_name("w:0")
b = graph.get_tensor_by_name("b:0")
来源:https://blog.csdn.net/dcrmg/article/details/83031488


猜你喜欢
- callable函数可用于判断一个对象是否可以被调用,若对象可以被调用则返回True,反之则返回False。所谓可调用,是指代码里可以在对象
- 一、业务需求在使用Python进行业务开发的时候,需要将一些数据保存到本地文件存储,方便后面进行数据分析展示。二、需求分析通过查看需求可得出
- 这篇文章主要介绍了python Jupyter运行时间实例过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价
- 1、拆箱>>> a, b, c = 1, 2, 3>>> a, b, c(1, 2, 3)>>
- 这个原因很简单,就是你没有在相应的表单信息中写入name属性。例如:<tr> <t
- 1. base64编码简介用记事本打开exe、jpg、pdf这些文件时,我们都会看到一大堆乱码,因为二进制文件包含很多无法显示和打印的字符,
- 使用springboot开发时,默认使用内置的tomcat数据库连接池,经常碰到这种情况:运行时间一长,数据库连接中断了。所以使用c3p0连
- fallthrough:Go里面switch默认相当于每个case最后带有break,匹配成功后不会自动向下执行其他case,而是跳出整个s
- vue动态添加store,路由和国际化vue动态添加store想写组件库?用这个吧 …// store module标
- 列表列表是Python中最具灵活性的有序集合对象类型。与字符串不同的是,列表可以包含任何类型的对象:数字、字符串甚至其他列表。列表是可变对象
- 当你检查scrapy二进制文件时,你会注意到这么一段python script#!/usr/bin/pythonfrom scrapy.cm
- 本文实例为大家分享了vue实现页面添加水印的具体代码,供大家参考,具体内容如下js文件建一个watermark.js文件let setWat
- 原文地址:http://ilovetypography.com/2007/10/22/so-you-want-to-create-a-fon
- 本文实例总结了python获取外网ip地址的方法。分享给大家供大家参考。具体如下:一、利用脚本引擎库直接获取import console;i
- python代码生成API接口如果要将我们写好的Python代码生成API接口时,我们需要借助Flask框架1. 安装Flaskpip in
- 这一部分我们将探索 PyQt5 的事件和信号是如何在应用程序中实现的。Events事件所有的GUI应用程序都是事件驱动的。应用程序事件主要产
- 目标在本章中,将学习利用calib3d模块在图像中创建一些3D效果基础在上一节相机校准中,了解了相机矩阵、失真系数等。给定图案图像,可以利用
- 一、总结说明Windows环境安装:paramunittest cmd输入命令:pip install paramunittest总结说明:
- 问:如何在SQL Enterprise Manager version 6.5下操作SQL Server 6.0的服务器?答:在使用SQL
- 效果知识点:css3画气球, 自定义属性运用,随机阵列, DOM元素操作,高级回调函数与参数复传,动态布局,鼠标事件,定时器运用,CSS3新