对TensorFlow中的variables_to_restore函数详解
作者:修炼之路 发布时间:2022-09-11 00:49:19
标签:TensorFlow,variables,to,restore
variables_to_restore函数,是TensorFlow为滑动平均值提供。之前,也介绍过通过使用滑动平均值可以让神经网络模型更加的健壮。我们也知道,其实在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。
1、滑动平均值模型文件的保存
import tensorflow as tf
if __name__ == "__main__":
v = tf.Variable(0.,name="v")
#设置滑动平均模型的系数
ema = tf.train.ExponentialMovingAverage(0.99)
#设置变量v使用滑动平均模型,tf.all_variables()设置所有变量
op = ema.apply([v])
#获取变量v的名字
print(v.name)
#v:0
#创建一个保存模型的对象
save = tf.train.Saver()
sess = tf.Session()
#初始化所有变量
init = tf.initialize_all_variables()
sess.run(init)
#给变量v重新赋值
sess.run(tf.assign(v,10))
#应用平均滑动设置
sess.run(op)
#保存模型文件
save.save(sess,"./model.ckpt")
#输出变量v之前的值和使用滑动平均模型之后的值
print(sess.run([v,ema.average(v)]))
#[10.0, 0.099999905]
上面的代码,是如何来保存一个滑动平均值的模型文件,之前有介绍过滑动平均值和模型文件的保存,所以这里就不再重复了。
2、滑动平均值模型文件的读取
v = tf.Variable(1.,name="v")
#定义模型对象
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
sess = tf.Session()
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999
对于模型文件的读取,在上一篇博客中有介绍过,这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。是不是感觉使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。
3、variables_to_restore函数的使用
v = tf.Variable(1.,name="v")
#滑动模型的参数的大小并不会影响v的值
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
sess = tf.Session()
saver = tf.train.Saver(ema.variables_to_restore())
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999
通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。
来源:https://blog.csdn.net/sinat_29957455/article/details/78508793


猜你喜欢
- <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN&
- 一、环境配置需要 pillow 和 pytesseract 这两个库,pip install 安装就好了。install pillow -i
- 靓丽的网页是怎样生成的?也许您会脱口而出,当然是自己设计出来的。没错!不过这其中也有网页制作工具的一部分功劳,因为功能强大的网页制作工具可以
- 在 MySQL 中,数据库和表对应于那些目录下的目录和文件。因而,操作系统的敏感性决定数据库和表命名的大小写敏感。这就意味着数据库和表名在
- 弄个随机数的东西,直接从网上找了一个现成的,简单看了两眼,感觉算法应该是对的,但今天测试下来,是不对的;网上大多数人用的写法是这样的:fun
- 1.保存变量先创建(在tf.Session()之前)saversaver = tf.train.Saver(tf.global_variab
- 一、下载SQL Server 2008 R2安装文件cn_sql_server_2008_r2_enterprise_x86_x64_ia6
- 1.背景项目需求,要求获得github的repo的api,以便可以提取repo的数据进行分析。研究了一天,终于解决了这个问题,虽然效率还是比
- 前言本文主要记录python下音频常用的操作,以.wav格式文件为例。其实网上有很多现成的音频工具包,如果仅仅调用,工具包是更方便的。更多p
- 一、什么是跨域?跨域问题的出现是因为浏览器的同源策略问题。所谓同源就是必须有以下三个相同点:协议相同、主机相同、端口相同。如果其中有一项不同
- Python 可以通过各种库去解析我们常见的数据。其中 csv 文件以纯文本形式存储表格数据,以某字符作为分隔值,通常为逗号;xml 可拓展
- 1、如何放弃正在输入的命令。 在输入一条比较长的命令时,出现打字错误是在所难免的。在这种情况下,放弃正在输入的命令重头再来往往会是更好的选择
- 识别验证码OCR(Optical Character Recognition)即光学字符识别技术,专门用于对图片文字进行识别,并获取文本。字
- <?php //包含一个计数器,一个提醒语句,用户ip以及自己的广告图片。 //给浏览器发送头,说我是张图片 Header
- sql 使用系统存储过程 sp_send_dbmail 发送电子邮件语法:sp_send_dbmail [ [ @profile_name
- D:document 文档 浏览器加载的页面 DOM O:object 对象 页面及页面中的任何元素都是对象 M:module 模型 页面中
- 实现效果:方法一:1 print "+"+"-"*8+"+"+"-&q
- 1.文件的写入和读取#!/usr/bin/python # -*- coding: utf-8 -*- # Filename: using_
- 安装过程询问一般 y 就可以了1 安装1.1 下载wget https://dev.mysql.com/get/mysql-apt-conf
- 目录1.1 题目1.2 思路1.2.1 发送请求1.2.2 解析网页1.2.3 获取结点1.2.4 数据保存 (单线程)1.2.4 数据保存