TensorFlow实现指数衰减学习率的方法
作者:stepondust 发布时间:2021-02-20 13:13:55
在TensorFlow中,tf.train.exponential_decay函数实现了指数衰减学习率,通过这个函数,可以先使用较大的学习率来快速得到一个比较优的解,然后随着迭代的继续逐步减小学习率,使得模型在训练后期更加稳定。
tf.train.exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase, name)函数会指数级地减小学习率,它实现了以下代码的功能:
#tf.train.exponential_decay函数可以通过设置staircase参数选择不同的学习率衰减方式
#staircase参数为False(默认)时,选择连续衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step / decay_steps)
#staircase参数为True时,选择阶梯状衰减学习率:
decayed_learning_rate = learning_rate * math.pow(decay_rate, global_step // decay_steps)
①decayed_leaming_rate为每一轮优化时使用的学习率;
②leaming_rate为事先设定的初始学习率;
③decay_rate为衰减系数;
④global_step为当前训练的轮数;
⑤decay_steps为衰减速度,通常代表了完整的使用一遍训练数据所需要的迭代轮数,这个迭代轮数也就是总训练样本数除以每一个batch中的训练样本数,比如训练数据集的大小为128,每一个batch中样例的个数为8,那么decay_steps就为16。
当staircase参数设置为True,使用阶梯状衰减学习率时,代码的含义是每完整地过完一遍训练数据即每训练decay_steps轮,学习率就减小一次,这可以使得训练数据集中的所有数据对模型训练有相等的作用;当staircase参数设置为False,使用连续的衰减学习率时,不同的训练数据有不同的学习率,而当学习率减小时,对应的训练数据对模型训练结果的影响也就小了。
接下来看一看tf.train.exponential_decay函数应用的两种形态(省略部分代码):
①第一种形态,global_step作为变量被优化,在这种形态下,global_step是变量,在minimize函数中传入global_step将自动更新global_step参数(global_step每轮迭代自动加一),从而使得学习率也得到相应更新:
import tensorflow as tf
.
.
.
#设置学习率
global_step = tf.Variable(tf.constant(0))
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy, global_step=global_step)
.
.
.
#创建会话
with tf.Session() as sess:
.
.
.
for i in range(STEPS):
.
.
.
#通过选取的样本训练神经网络并更新参数
sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end]})
.
.
.
②第二种形态,global_step作为占位被feed,在这种形态下,global_step是占位,在调用sess.run(train_step)时使用当前迭代的轮数i进行feed:
import tensorflow as tf
.
.
.
#设置学习率
global_step = tf.placeholder(tf.float32, shape=())
learning_rate = tf.train.exponential_decay(0.01, global_step, 16, 0.96, staircase=True)
#定义反向传播算法的优化方法
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
.
.
.
#创建会话
with tf.Session() as sess:
.
.
.
for i in range(STEPS):
.
.
.
#通过选取的样本训练神经网络并更新参数
sess.run(train_step, feed_dict={x:X[start:end], y_:Y[start:end], global_step:i})
.
.
.
总结
以上所述是小编给大家介绍的TensorFlow实现指数衰减学习率的方法,希望对大家有所帮助!
来源:https://blog.csdn.net/qq_44009891/article/details/104171369


猜你喜欢
- 运行环境:IIS脚本语言:VBScript数据库:Access/SQL Server数据库语言:SQL1.概要:不论是在论坛,还是新闻系统,
- 要求:#出租车计费*************************************************************
- 【原文地址】New "Orcas" Language Feature: Extension Methods【原文发表日期
- 同由其他技术驱动的应用一样,在相同的Web服务器上运行Django应用也是可行的。 最简单直接的办法就是利用Apaches配置文件httpd
- 本文介绍TSV文件类型及其应用,同时介绍Golang语句读取TSV文件并转为struct的实现过程。认识TSV文件也许你之前不了解TSV文件
- 目录分区机制SELECT 查询INSERT 操作DELETE 操作UPDATE 操作分区的类型MySQL 的分区的实现方式是对数据表进行一层
- 本文实例讲述了Python实现批量修改文件名的方法。分享给大家供大家参考。具体如下:下载了评书《贺龙传奇》,文件名中却都含有xxx有声下载,
- 所以呢,在引用js文档的时候,要设置被引用的文档是什么编码的。 如:一个utf-8的页面引用一个gb2312的js文档,那么就要这么写 &l
- 本文实例讲述了Laravel5中实现模糊匹配加多条件查询功能的方法。分享给大家供大家参考,具体如下:方法1. ORM模式public fun
- code:f = open('yesterday','r',encoding='utf-8'
- (1)设计一个算法,确定两个矩形是否相交(即有重叠区域) (2)如果两个矩形相交,设计一个算法,求出相交的区域矩形 (1) 对于这个问题,一
- 我们在数据处理,往往不小心,pandas会“主动”加上行和列的名称,我现在就遇到了这个问题。这个是pandas中to_csv生成的数据各种拼
- 1.什么是Store?上一篇文章说了,Vuex就是提供一个仓库,Store仓库里面放了很多对象。其中state就是数据源存放地,对应于与一般
- Python版本 实现了比之前的xxftp更多更完善的功能 1、继续支持多用户 2、继续支持虚拟目录 3、增加支持用户根目录以及映射虚拟目录
- 这篇文章主要介绍了Python序列化与反序列化pickle用法实例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价
- Python 装饰器深入探讨在 Python 中,装饰器提供了一种简洁的方式,用来修改或增强函数和类的行为。装饰器在语法上表现为一个前置于函
- 多线程概述多线程使得程序内部可以分出多个线程来做多件事情,充分利用CPU空闲时间,提升处理效率。python提供了两个模块来实现多线程thr
- 我们假设TPCoins的发起人最初向已知客户 Dinesh 发出500个TPCoins.为此,他首先创建一个Dinesh
- 最近在做一个手机站,要求点击分享可以直接打开微信分享出去。而不是jiathis,share分享这种的点击出来二维码。在网上看了很多,都说AP
- 本文实例讲述了Django框架ORM数据库操作。分享给大家供大家参考,具体如下:测试数据:BookInfo表PeopleInfo表一.增加1