TensorFlow的权值更新方法
作者:朂嘼 发布时间:2022-12-24 21:41:08
一. MovingAverage权值滑动平均更新
1.1 示例代码:
def create_target_q_network(self,state_dim,action_dim,net):
state_input = tf.placeholder("float",[None,state_dim])
action_input = tf.placeholder("float",[None,action_dim])
ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
target_update = ema.apply(net)
target_net = [ema.average(x) for x in net]
layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])
return state_input,action_input,q_value_output,target_update
def update_target(self):
self.sess.run(self.target_update)
其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )
第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;
第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。
更新公式为:
shadow_variable = decay * shadow_variable + (1 - decay) * variable
在上述代码段中,target_net是shadow_variable,net是variable
1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)
var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。
函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。
1.3 tf.train.ExponentialMovingAverage.average(var)
用于获取var的滑动平均结果。
二. tf.train.Optimizer更新网络权值
2.1 tf.train.Optimizer
tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。
此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:
1.利用tf.train.Optimizer.compute_gradients计算梯度
2.对梯度进行自定义处理
3.利用tf.train.Optimizer.apply_gradients更新权值
tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)
返回一个(梯度,权值)的列表对。
tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)
返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)
2.2 其它
此外,tensorflow还提供了其它计算梯度的方法:
• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)
该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。
其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。
• tf.stop_gradient(input, name=None)
该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:
cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|
可以事先声明
y=tf.stop_gradient(r+Q next r+Qnext)
来源:https://blog.csdn.net/GH234505/article/details/54976696
猜你喜欢
- 介绍:细处着手,巧处用功。高手和菜鸟之间的差别就是:高手什么都知道,菜鸟知道一些。电脑小技巧收集最新奇招高招,让你轻松踏上高手之路。摘要:
- 1、获取秒级时间戳与毫秒级时间戳、微秒级时间戳import timeimport datetimet = time.time()print
- 一、摘要Python使用被称为异常 的特殊对象来管理程序执行期间发生的错误。每当发生让Python不知所措的错误时,它都会创建一个异常对象。
- 使用python的json模块序列化时间或者其他不支持的类型时会抛异常,例如下面的代码:# -*- coding: cp936 -*-fro
- 本文以实例演示5种验证码,并介绍生成验证码的函数。PHP生成验证码的原理:通过GD库,生成一张带验证码的图片,并将验证码保存在Session
- if(document.mylist.length != "undefined" ) {} 这个用法有误. 正确的是 i
- 相信很多人在使用正则表达式的时候都会遇到如下的语句:通过查阅正则表达式的API文档可以了解到正则表达式的语法知识:很多小伙伴就会产生疑问为什
- Dynaconf 是一个 Python 的第三方模块,旨在成为在 Python 中管理配置的最佳选择。它可以从各种来源读取设置,包括环境变量
- fsockopen函数能够运用,首先要开启php.ini中的allow_url_open=on;fsockopen是对socket客户端代码
- 首先看这下面的例子(鼠标移上去):<TABLE><TBODY><TR&g
- 本文实例为大家分享了python实现学生信息管理系统的具体代码,供大家参考,具体内容如下1.主要内容python种的.py文件如图所示第一个
- Python import的搜索路径import的搜索路径为:搜索「内置模块」(built-in module)搜索 sys.path 中的
- 被分割的字段一定是有限而且数量较少的,我们不可能在一个字符串中存储无限多个字符 这个字段所属的表与这个字段关联的表,一定是一对多的关系 比如
- 在编程过程中,多了解语言周边的一些知识,以及一些技巧,可以让你加速成为一个优秀的程序员。对于Python程序员,你需要注意一下本文所提到的这
- memcache 的工作就是在专门的机器的内存里维护一张巨大的hash表,来存储经常被读写的一些数组与文件,从而极大的提高网站的运行效率,减
- 使用shutil.move(src, dst),src为要移动的文件的路径,dst为目的路径,路径必须是绝对路径import osimpor
- Python3中print函数的换行最近看了看Python的应用,从入门级的九九乘法表开始,结果发现Python3.x和Python2.x真
- #!/usr/bin/env python# -*- coding: utf-8 -*-from tkinter import *impor
- 简介模板方法模式,是行为型的设计模式。定义一个操作中的算法的骨架,而将一些步骤延迟到子类当中,使得子类可以不改变一个算法的结构即可重新定义该
- Django自带强大的User系统,为我们提供用户认证、权限、组等一系列功能,可以快速建立一个完整的后台功能。但User模型并不能满足我们的