Tensorflow 卷积的梯度反向传播过程
作者:LQ6H 发布时间:2021-02-05 09:15:29
标签:Tensorflow,卷积,梯度,反向传播
一. valid卷积的梯度
我们分两种不同的情况讨论valid卷积的梯度:第一种情况,在已知卷积核的情况下,对未知张量求导(即对张量中每一个变量求导);第二种情况,在已知张量的情况下,对未知卷积核求导(即对卷积核中每一个变量求导)
1.已知卷积核,对未知张量求导
我们用一个简单的例子理解valid卷积的梯度反向传播。假设有一个3x3的未知张量x,以及已知的2x2的卷积核K
Tensorflow提供函数tf.nn.conv2d_backprop_input实现了valid卷积中对未知变量的求导,以上示例对应的代码如下:
import tensorflow as tf
# 卷积核
kernel=tf.constant(
[
[[[3]],[[4]]],
[[[5]],[[6]]]
]
,tf.float32
)
# 某一函数针对sigma的导数
out=tf.constant(
[
[
[[-1],[1]],
[[2],[-2]]
]
]
,tf.float32
)
# 针对未知变量的导数的方向计算
inputValue=tf.nn.conv2d_backprop_input((1,3,3,1),kernel,out,[1,1,1,1],'VALID')
session=tf.Session()
print(session.run(inputValue))
[[[[ -3.]
[ -1.]
[ 4.]]
[[ 1.]
[ 1.]
[ -2.]]
[[ 10.]
[ 2.]
[-12.]]]]
2.已知输入张量,对未知卷积核求导
假设已知3行3列的张量x和未知的2行2列的卷积核K
Tensorflow提供函数tf.nn.conv2d_backprop_filter实现valid卷积对未知卷积核的求导,以上示例的代码如下:
import tensorflow as tf
# 输入张量
x=tf.constant(
[
[
[[1],[2],[3]],
[[4],[5],[6]],
[[7],[8],[9]]
]
]
,tf.float32
)
# 某一个函数F对sigma的导数
partial_sigma=tf.constant(
[
[
[[-1],[-2]],
[[-3],[-4]]
]
]
,tf.float32
)
# 某一个函数F对卷积核k的导数
partial_sigma_k=tf.nn.conv2d_backprop_filter(x,(2,2,1,1),partial_sigma,[1,1,1,1],'VALID')
session=tf.Session()
print(session.run(partial_sigma_k))
[[[[-37.]]
[[-47.]]]
[[[-67.]]
[[-77.]]]]
二. same卷积的梯度
1.已知卷积核,对输入张量求导
假设有3行3列的已知张量x,2行2列的未知卷积核K
import tensorflow as tf
# 卷积核
kernel=tf.constant(
[
[[[3]],[[4]]],
[[[5]],[[6]]]
]
,tf.float32
)
# 某一函数针对sigma的导数
partial_sigma=tf.constant(
[
[
[[-1],[1],[3]],
[[2],[-2],[-4]],
[[-3],[4],[1]]
]
]
,tf.float32
)
# 针对未知变量的导数的方向计算
partial_x=tf.nn.conv2d_backprop_input((1,3,3,1),kernel,partial_sigma,[1,1,1,1],'SAME')
session=tf.Session()
print(session.run(inputValue))
[[[[ -3.]
[ -1.]
[ 4.]]
[[ 1.]
[ 1.]
[ -2.]]
[[ 10.]
[ 2.]
[-12.]]]]
2.已知输入张量,对未知卷积核求导
假设已知3行3列的张量x和未知的2行2列的卷积核K
import tensorflow as tf
# 卷积核
x=tf.constant(
[
[
[[1],[2],[3]],
[[4],[5],[6]],
[[7],[8],[9]]
]
]
,tf.float32
)
# 某一函数针对sigma的导数
partial_sigma=tf.constant(
[
[
[[-1],[-2],[1]],
[[-3],[-4],[2]],
[[-2],[1],[3]]
]
]
,tf.float32
)
# 针对未知变量的导数的方向计算
partial_sigma_k=tf.nn.conv2d_backprop_filter(x,(2,2,1,1),partial_sigma,[1,1,1,1],'SAME')
session=tf.Session()
print(session.run(partial_sigma_k))
[[[[ -1.]]
[[-54.]]]
[[[-43.]]
[[-77.]]]]
来源:https://www.cnblogs.com/LQ6H/p/10343262.html


猜你喜欢
- import numpy as npimport pandas as pdfrom pandas_datareader import dat
- 功能:扫描当前目录下所有CSV文件并对其中文件进行统计,输出统计值到CSV文件pip install pandasimport pandas
- python 将字典写为json文件字典结构如下res = { "data":[]}temp
- 本文实例讲述了Python操作Oracle数据库的简单方法和封装类。分享给大家供大家参考,具体如下:最近工作有接触到Oracle,发现很多地
- 本文实例讲述了PHP对象克隆clone用法。分享给大家供大家参考,具体如下:浅克隆:只是克隆对象中的非对象非资源数据,即对象中属性存储的是对
- 本文实例为大家分享了python读取Excel实例的具体代码,供大家参考,具体内容如下1.操作步骤:(1)安装python官方Excel库-
- 本文实例讲述了wxpython中自定义事件的实现与使用方法。分享给大家供大家参考,具体如下:创建自定义事件的步骤:① 定义事件类,该事件类必
- 安装pillow(python的图形界面库)第一种方法在Dos界面输入pip install pillow(但是不知为何总是失败);搞了好几
- 简介模拟登录淘宝已经不是一件新鲜的事情了,过去我曾经使用get/post方式进行爬虫,同时也加入IP代理池进行跳过检验,但随着大型网站的升级
- 环境搭建下载安 * eego,bee1.开启gomod设置代理go env -w GO111MODULE=ongo env -w GOPROX
- 本文实例讲述了Python基于opencv实现的简单画板功能。分享给大家供大家参考,具体如下:import cv2import numpy
- 1、控制"纵打"、 横打”和“页面的边距。 (1)<script defer> function SetPr
- 环境:编辑工具:浏览器:安装xlrd安装DDT一 分析1 目录结构2 导入包二 代码import xlrdcl
- 代码如下:<% set rs=server.createobject("adodb.recordset&
- 描述给定一个序列(至少含有 1 个数),从该序列中寻找一个连续的子序列,使得子序列的和最大。 例如,给定序列 [-2,1,-3,4,-1,2
- 1.策略模式(Strategy): 定义了算法家族, 分别封装起来, 让它们之间可以互相替换. 比如Collections.sort(Lis
- 本文实例讲述了Python实现按照指定要求逆序输出一个数字的方法。分享给大家供大家参考,具体如下:问题是:输入一个数字,按照指定要求逆序输出
- 前言在javascript中,我们都知道使用var来声明变量。javascript是函数级作用域,函数内可以访问函数外的变量,函数外不能访问
- 以下是通过Excel 的VBA连接Oracle并操作Oracle相关数据的示例Excel 通过VBA连接数据库需要安装相应的Oracle客户
- HTTP上传的文件的原理HTTP协议的文件上传是通过HTTP POST请求实现的,使用multipart/form-data格式将待上传的文