TensorFlow的权值更新方法
作者:朂嘼 发布时间:2022-12-24 21:41:08
一. MovingAverage权值滑动平均更新
1.1 示例代码:
def create_target_q_network(self,state_dim,action_dim,net):
state_input = tf.placeholder("float",[None,state_dim])
action_input = tf.placeholder("float",[None,action_dim])
ema = tf.train.ExponentialMovingAverage(decay=1-TAU)
target_update = ema.apply(net)
target_net = [ema.average(x) for x in net]
layer1 = tf.nn.relu(tf.matmul(state_input,target_net[0]) + target_net[1])
layer2 = tf.nn.relu(tf.matmul(layer1,target_net[2]) + tf.matmul(action_input,target_net[3]) + target_net[4])
q_value_output = tf.identity(tf.matmul(layer2,target_net[5]) + target_net[6])
return state_input,action_input,q_value_output,target_update
def update_target(self):
self.sess.run(self.target_update)
其中,TAU=0.001,net是原始网络(该示例代码来自DDPG算法,经过滑动更新后的target_net是目标网络 )
第一句 tf.train.ExponentialMovingAverage,创建一个权值滑动平均的实例;
第二句 apply创建所训练模型参数的一个复制品(shadow_variable),并对这个复制品增加一个保留权值滑动平均的op,函数average()或average_name()可以用来获取最终这个复制品(平滑后)的值的。
更新公式为:
shadow_variable = decay * shadow_variable + (1 - decay) * variable
在上述代码段中,target_net是shadow_variable,net是variable
1.2 tf.train.ExponentialMovingAverage.apply(var_list=None)
var_list必须是Variable或Tensor形式的列表。这个方法对var_list中所有元素创建一个复制,当其是Variable类型时,shadow_variable被初始化为variable的初值,当其是Tensor类型时,初始化为0,无偏。
函数返回一个进行权值平滑的op,因此更新目标网络时单独run这个函数就行。
1.3 tf.train.ExponentialMovingAverage.average(var)
用于获取var的滑动平均结果。
二. tf.train.Optimizer更新网络权值
2.1 tf.train.Optimizer
tf.train.Optimizer允许网络通过minimize()损失函数自动进行权值更新,此时tf.train.Optimizer.minimize()做了两件事:计算梯度,并把梯度自动更新到权值上。
此外,tensorflow也允许用户自己计算梯度,并做处理后应用给权值进行更新,此时分为以下三个步骤:
1.利用tf.train.Optimizer.compute_gradients计算梯度
2.对梯度进行自定义处理
3.利用tf.train.Optimizer.apply_gradients更新权值
tf.train.Optimizer.compute_gradients(loss, var_list=None, gate_gradients=1, aggregation_method=None, colocate_gradients_with_ops=False, grad_loss=None)
返回一个(梯度,权值)的列表对。
tf.train.Optimizer.apply_gradients(grads_and_vars, global_step=None, name=None)
返回一个更新权值的op,因此可以用它的返回值ret进行sess.run(ret)
2.2 其它
此外,tensorflow还提供了其它计算梯度的方法:
• tf.gradients(ys, xs, grad_ys=None, name='gradients', colocate_gradients_with_ops=False, gate_gradients=False, aggregation_method=None)
该函数计算ys在xs方向上的梯度,需要注意与train.compute_gradients所不同的地方是,该函数返回一组dydx dydx的列表,而不是梯度-权值对。
其中,gate_gradients是在ys方向上的初始梯度,个人理解可以看做是偏微分链式求导中所需要的。
• tf.stop_gradient(input, name=None)
该函数告知整个graph图中,对input不进行梯度计算,将其伪装成一个constant常量。比如,可以用在类似于DQN算法中的目标函数:
cost=|r+Q next −Q current | cost=|r+Qnext−Qcurrent|
可以事先声明
y=tf.stop_gradient(r+Q next r+Qnext)
来源:https://blog.csdn.net/GH234505/article/details/54976696
猜你喜欢
- SQL Server Sa用户相信大家都有一定的理解,下面就为您介绍SQL Server 2000身份验证模式的修改方法及SQL Serve
- 值类型和引用类型值类型:int、float、bool和string这些类型都属于值类型,使用这些类型的变量直接指向存在内存中的值,值类型的变
- 本文实例讲述了php+html5基于websocket实现聊天室的方法。分享给大家供大家参考。具体如下:html5的websocket 实现
- 代码如下: <% '屏蔽主流的下载工具 Dimxurl,xtool '获取浏览器AGENT xurl=lcase(Re
- 第一次碰到这个问题的时候,确实不知道该怎么办,后来请教了一个大神,加上自己的理解,才了解是什么意思,这个东西写python的会经常用到,而且
- 这篇论坛文章(赛迪网技术社区)主要介绍了一些特别有用但文档中没有介绍的sql server DBCC命令,详细内容请参考下文:以下是一些sq
- python3.7简单的爬虫,具体代码如下所示:#https://www.runoob.com/w3cnote/python-spider-
- 当子类继承父类后,需要调用父类的方法和属性时,需要调用父类的初始化函数。class A(object): def __init_
- 本文实例为大家分享了Python实现井字棋小游戏的具体代码,供大家参考,具体内容如下import osdef print_board(boa
- Mysqli是php5之后才有的功能,没有开启扩展的朋友可以打开您的php.ini的配置文件。 查找下面的语句:;extension=php
- 概述先来介绍一下xml格式的文件,从数据分析的角度去看xml格式的数据集,具有以下的优点开放性(能在任何平台上读取和处理数据,允许通过一些网
- 1.python解释器安装下载地址:https://www.python.org/打开官网,点击downloads,选择操作系统,以wind
- 大家可能经常会遇到这种情况:sql="select * from table"set rs=conn.execute(s
- A 定义数组有两种方式:DIM和REDIM。DIM定义的是固定个数、数据类型的数组;而REDIM则不同,它可以定义不同类型的数据,也可以定义
- 运行以下代码: Dim com As ADODB.Command Dim rst
- 摘要在这篇文章里,我将以反模式的角度来直接讨论Django的低级ORM查询方法的使用。作为一种替代方式,我们需要在包含业务逻辑的
- CSS+DIV是网站标准(或称“WEB标准”)中常用的术语之一,通常为了说明与HTML网页设计语言中的表格(table)定位方式的区别,因为
- 之前在一个web系统的设计中,和另一个设计师讨论,“保存”和“取消”按钮该怎么设计。我的观点是,保存是比取消更常用的按钮,也是用户的主要目的
- 首先安装pip install ruamel.yaml用于修改yaml文件#coding:utf-8from ruamel import y
- 用下列方法可以做到: main.htm<html><body><form action="