Tensorflow的梯度异步更新示例
作者:supe_king 发布时间:2022-10-04 15:08:23
标签:Tensorflow,梯度,异步,更新
背景:
先说一下应用吧,一般我们进行网络训练时,都有一个batchsize设置,也就是一个batch一个batch的更新梯度,能有这个batch的前提是这个batch中所有的图片的大小一致,这样才能组成一个placeholder。那么若一个网络对图片的输入没有要求,任意尺寸的都可以,但是我们又想一个batch一个batch的更新梯度怎么办呢?
操作如下:
先计算梯度:
# 模型部分
Optimizer = tf.train.GradientDescentOptimizer(1)
gradient = Optimizer.compute_gradients(loss) # 每次计算所有变量的梯度
grads_holder = [(tf.placeholder(tf.float32, shape=g.get_shape()), v) for (g, v) in gradient]# 将每次计算的梯度保存
optm = Optimizer.apply_gradients(grads_holder) # 进行梯度更新
# 初始化部分
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
# 实际训练部分
grads = [] # 定义一个空的列表用于存储每次计算的梯度
for i in range(batchsize): # batchsize设置在这里
x_i = ... # 输入
y_real = ... # 标签
grad_i = sess.run(gradient, feed_dict={inputs: x_i, outputs: y_real}) #梯度计算
grads.append(grad_i) # 梯度存储
# 定义一个空的字典用于存储,batchsize中所有梯度的和
grads_sum = {}
# 将网络中每个需要更新梯度的变量都遍历一遍
for i in range(len(grads_holder)):
k = grads_holder[i][0] # 得到该变量名
# 将该变量名下的所有梯度求和,这里也可以求平均,求平均只需要除以batchsize
grads_sum[k] = sum([g[i][0] for g in grads])
# 完成梯度更新
sess.run(optm,feed_dict=grads_sum)
来源:https://blog.csdn.net/supe_king/article/details/78017429


猜你喜欢
- 在网站开发的时候经常要用chr(),但本人比较懒没时间记那么多。于是到用到的时候就查,这样麻烦。现在将它写出来方便以后用到查,也方便大家!c
- python的pdb调试命令的命令整理及实例一、命令整理pdb调试命令完整命令简写命令描述argsa打印当前函数的参数breakb设置断点c
- 本文实例为大家分享了vue实现百度搜索功能的具体代码,供大家参考,具体内容如下最终效果:Baidusearch.vue所有代码:<te
- 使用Django意味着后台框架的几乎所有内容都会和Django产生互动,排除功能全部手撸的情况.Django 后台admin有大量的属性和方
- 来炫耀一下,谁看得懂我写的加密算法写了一整天了,这个代码用于ajax提交,要求就是加密后内容不能变得过长,加密解密需要效率高,至于安全性,被
- 矩阵增加行np.row_stack() 与 np.column_stack()import numpy as npa = np.array(
- 前言:一般处理数据使用的是pandas和numpy库,但是填充单元格颜色需要在excel中,使用的是openpyxl库,所以不能直接达到我们
- 配置文件如下,下面对配置文件进行一一解释"""Django settings for film1_manage
- 有关换行的问题首先提一个问题,如下。python程序代码如下:print("I'm Bob. What's you
- 这篇文章主要介绍了简单了解python装饰器原理及使用方法,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要
- Numpy中提供了concatenate,append, stack类(包括hsatck、vstack、dstack、row_stack、c
- 本文实例讲述了Python基于回溯法子集树模板解决数字组合问题。分享给大家供大家参考,具体如下:问题找出从自然数1、2、3、...、n中任取
- 绘制八个子图import matplotlib.pyplot as pltfig = plt.figure()shape=['.
- 起由:前一阵子想要刷一刷国二Python的题库,千方百计找到题库之后,打开一个个word文档,发现一题一题阅读很麻烦,而且答案就在题目的下面
- python join 和 split方法简单的说是:join用来连接字符串,split恰好相反,拆分字符串的。.join()join将 容
- 1 前言在 Java 和 js 中,lambda箭头函数是十分常见的操作,这种表达方式在使用时非常的简便。在python的语法中也有应用场景
- 前言:NumPy 是 Python 语言的一个扩充程序库,支持大量高维度数组与矩阵运算,此外也针对数组运算提供大量的数学函数库。同时NumP
- 书接上文用Python搓一个太阳系你们要的3D太阳系3体人真的存在吗太长不看版最小势能点在由两个大质量物体构成的重力系统中,有一些特殊的区域
- 对url进行编码在服务器端我们可以使用asp中的server.urlencode,很方便实现。如:<% ss="asp之家欢
- MySQL5.0版本的安装图解教程是给新手学习的,当前mysql5.0.96是最新的稳定版本。mysql 下载地址 https://www.