TensorFlow 滑动平均的示例代码
作者:叠加态的猫 发布时间:2023-10-25 15:41:28
滑动平均会为目标变量维护一个影子变量,影子变量不影响原变量的更新维护,但是在测试或者实际预测过程中(非训练时),使用影子变量代替原变量。
1、滑动平均求解对象初始化
ema = tf.train.ExponentialMovingAverage(decay,num_updates)
参数decay
`shadow_variable = decay * shadow_variable + (1 - decay) * variable`
参数num_updates
`min(decay, (1 + num_updates) / (10 + num_updates))`
2、添加/更新变量
添加目标变量,为之维护影子变量
注意维护不是自动的,需要每轮训练中运行此句,所以一般都会使用tf.control_dependencies使之和train_op绑定,以至于每次train_op都会更新影子变量
ema.apply([var0, var1])
3、获取影子变量值
这一步不需要定义图中,从影子变量集合中提取目标值
sess.run(ema.average([var0, var1]))
4、保存&载入影子变量
我们知道,在TensorFlow中,变量的滑动平均值都是由影子变量所维护的,如果你想要获取变量的滑动平均值需要获取的是影子变量而不是变量本身。
保存影子变量
建立tf.train.ExponentialMovingAverage对象后,Saver正常保存就会存入影子变量,命名规则是"v/ExponentialMovingAverage"对应变量”v“
import tensorflow as tf
if __name__ == "__main__":
v = tf.Variable(0.,name="v")
#设置滑动平均模型的系数
ema = tf.train.ExponentialMovingAverage(0.99)
#设置变量v使用滑动平均模型,tf.all_variables()设置所有变量
op = ema.apply([v])
#获取变量v的名字
print(v.name)
#v:0
#创建一个保存模型的对象
save = tf.train.Saver()
sess = tf.Session()
#初始化所有变量
init = tf.initialize_all_variables()
sess.run(init)
#给变量v重新赋值
sess.run(tf.assign(v,10))
#应用平均滑动设置
sess.run(op)
#保存模型文件
save.save(sess,"./model.ckpt")
#输出变量v之前的值和使用滑动平均模型之后的值
print(sess.run([v,ema.average(v)]))
#[10.0, 0.099999905]
载入影子变量并映射到变量
利用了Saver载入模型的变量名映射功能,实际上对所有的变量都可以如此操作『TensorFlow』模型载入方法汇总
v = tf.Variable(1.,name="v")
#定义模型对象
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
sess = tf.Session()
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999
这里特别需要注意的一个地方就是,在使用tf.train.Saver函数中,所传递的模型参数是{"v/ExponentialMovingAverage":v}而不是{"v":v},如果你使用的是后面的参数,那么你得到的结果将是10而不是0.09,那是因为后者获取的是变量本身而不是影子变量。
使用这种方式来读取模型文件的时候,还需要输入一大串的变量名称。
variables_to_restore函数的使用
v = tf.Variable(1.,name="v")
#滑动模型的参数的大小并不会影响v的值
ema = tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
#{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
sess = tf.Session()
saver = tf.train.Saver(ema.variables_to_restore())
saver.restore(sess,"./model.ckpt")
print(sess.run(v))
#0.0999999
variables_to_restore会识别网络中的变量,并自动生成影子变量名。
通过使用variables_to_restore函数,可以使在加载模型的时候将影子变量直接映射到变量的本身,所以我们在获取变量的滑动平均值的时候只需要获取到变量的本身值而不需要去获取影子变量。
5、官方文档例子
官方文档中将每次apply更新就会自动训练一边模型,实际上可以反过来两者关系,《tf实战google》P128中有示例
| Example usage when creating a training model:
|
| ```python
| # Create variables.
| var0 = tf.Variable(...)
| var1 = tf.Variable(...)
| # ... use the variables to build a training model...
| ...
| # Create an op that applies the optimizer. This is what we usually
| # would use as a training op.
| opt_op = opt.minimize(my_loss, [var0, var1])
|
| # Create an ExponentialMovingAverage object
| ema = tf.train.ExponentialMovingAverage(decay=0.9999)
|
| with tf.control_dependencies([opt_op]):
| # Create the shadow variables, and add ops to maintain moving averages
| # of var0 and var1. This also creates an op that will update the moving
| # averages after each training step. This is what we will use in place
| # of the usual training op.
| training_op = ema.apply([var0, var1])
|
| ...train the model by running training_op...
| ```
6、batch_normal的例子
和上面不太一样的是,batch_normal中不太容易绑定到train_op(位于函数体外面),则强行将两个variable的输出过程化为节点,绑定给参数更新步骤
def batch_norm(x,beta,gamma,phase_train,scope='bn',decay=0.9,eps=1e-5):
with tf.variable_scope(scope):
# beta = tf.get_variable(name='beta', shape=[n_out], initializer=tf.constant_initializer(0.0), trainable=True)
# gamma = tf.get_variable(name='gamma', shape=[n_out],
# initializer=tf.random_normal_initializer(1.0, stddev), trainable=True)
batch_mean,batch_var = tf.nn.moments(x,[0,1,2],name='moments')
ema = tf.train.ExponentialMovingAverage(decay=decay)
def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean,batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean),tf.identity(batch_var)
# identity之后会把Variable转换为Tensor并入图中,
# 否则由于Variable是独立于Session的,不会被图控制control_dependencies限制
mean,var = tf.cond(phase_train,
mean_var_with_update,
lambda: (ema.average(batch_mean),ema.average(batch_var)))
normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, eps)
return normed
来源:http://www.cnblogs.com/hellcat/p/8583379.html
猜你喜欢
- 1. 概述Python中 asyncio 模块内置了对异步IO的支持,用于处理异步IO;是Python 3.4版本引入的标准库。asynci
- function getHTTPRequest() { var xhr = false; if (window.XMLHttpRequest
- 一丶什么是索引索引是存储引擎快速找到记录的一种数据结构。数据库中的数据可以理解成字典中的单词,而索引就是目录,显而易见这是一种空间换时间的做
- 由于工作需要,这两天在看GOOGLE MAP 的 API,需要在公司的网站上使用地图。今天把看过之后的一点使用方法,跟大家一起分享:演示地址
- 一个完整的数据挖掘模型,最后都要进行模型评估,对于二分类来说,AUC,ROC这两个指标用到最多,所以 利用sklearn里面相应的函数进行模
- 1. 将下载好的字体放到本地目录分别是两种字体放到项目的 assets 目录中2. 引入字体文件首先创建一个 styles 文件
- 空白双边距是一个极容易误解的CSS特性.它不是CSS的bug,但如果我们一旦误解,将会给你带来很多麻烦.先看如下demo代码:<!do
- 导语幼儿园升小学,小学升中学,中学升高中..........每个人都要经历的九年义务教育:伴随的都是作业、随堂考、以及每个科目的大大小小的考
- 建立一个数据库表维护规范在一个定期基础而非等到问题出现才实施数据库表的检查是一个好主意。应该考虑到建立一个预防性维护的时间表,以协助自动问题
- random() 方法返回随机生成的一个实数,它在[0,1)范围内。random()返回随机生成的一个实数,范围在[0,1)之间语
- import pyspherefrom pysphere import VIServerhost_ip = "200.200.17
- 一、c++调用Python将Python安装目录下的include和libs文件夹引入到项目中,将libs目录下的python37.lib复
- 大多数情况下,我们使用 webpack来打包单页应用程序,这个时候只需要配置一个入口,一个模板文件,但也不尽是如此,有时候也会碰到多页面的项
- 现如今,各个国家交流密切,通过翻译使我们打破了语言壁垒,而翻译在互联网上的存在也尤为普遍。python中执行翻译操作的包是translate
- Python列表和字典前面我们了解了 “大O表示法” 以及对不同的算法的评估,下面来讨论下 Python 两种内置数据类型有关的各种操作的大
- python的os module中有fork()函数用于生成子进程,生成的子进程是父进程的镜像,但是它们有各自的地址空间,子进程复制一份父进
- --返回由备份集内包含的数据库和日志文件列表组成的结果集。 --主要获得逻辑文件名 restore filelistonly from di
- GIT安装访问: https://git-scm.com/downloads ,进入git'下载页面,根据个人操作系统下载对应软件版
- 今天我们来探索python中大部分的异常报错首先异常是什么,异常白话解释就是不正常,程序里面一般是指程序员输入的格式不规范,或者需求的参数类
- 概述介绍触发器,就是一种特殊的存储过程。触发器和存储过程一样是一个能够完成特定功能、存储在数据库服务器上的SQL片段,但是触发器无需调用,当