PyTorch中反卷积的用法详解
作者:月牙眼的楼下小黑 发布时间:2022-09-21 18:12:34
pytorch中的 2D 卷积层 和 2D 反卷积层 函数分别如下:
class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=True)
class torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, bias=True)
我不禁有疑问:
问题1: 两个函数的参数为什么几乎一致呢?
问题2: 反卷积层中的 output_padding是什么意思呢?
问题3: 反卷积层如何计算input和output的形状关系呢?
看了中文文档后,我得不出答案,看了英文文档,才弄明白了。花费了一个下午的时间去研究这个问题,值得用此文纪录一下。
我们知道,在卷积层中,输入输出的形状关系为:
o = [ (i + 2p - k)/s ] +1 (1)
其中:
O : 为 output size
i: 为 input size
p: 为 padding size
k: 为kernel size
s: 为 stride size
[] 为下取整运算
(1) 当 S=1 时
若 s等于1,则公式(1)中的取整符号消失,o 与 i 为 一一对应 的关系。 我们有结论:
如果卷积层函数和反卷积层函数的 kernel_size, padding size参数相同(且 stride= 1),设反卷基层的输入输出形状为 i' 和 o', 卷积层的输入输出形状i和o, 则它们为 交叉对应 的关系,即:
i = o'
o = i'
为回答问题3, 我们将上述关系代入公式中,即:
i' = o' + 2p - k +1
已知 i', 即可推出 o':
o' = i' - 2p + k - 1 (2)
摘两个例子:
(2) 当 S>1 时
若 S>1 , 则公式(1)中的取整符号不能消去,o 与 i 为 多对1 的关系。 效仿 S=1时的情形, 我们有结论:
如果卷积层函数和反卷积层函数的 kernel_size, padding size参数相同(且 stride>1),设反卷基层的输入输出形状为 i' 和 o', 卷积层的输入输出形状i和o,
i' = [ (o' + 2p - k)/s ] +1
已知 i', 我们可以得出 s 个 o' 解:
o'(0) = ( i' - 1) x s + k - 2p
o'(1) = o'(1) + 1
o'(2) = o'(1) + 2
...
o'(s-1) = o'(1) + s-1
即:
o'(n) =o'(1) + n = ( i' - 1) x s + k - 2p + n,
n = {0, 1, 2...s-1}
为了确定唯一的 o' 解, 我们用反卷积层函数中的ouput padding参数指定公式中的 n 值。这样,我们就回答了问题(2)。
摘一个简单的例子:
(3) 实验验证
给出一小段测试代码,改变各个参数值,运行比较来验证上面得出的结论,have fun~.
from torch import nn
from torch.nn import init
from torch.autograd import Variable
dconv = nn.ConvTranspose2d(in_channels=1, out_channels= 1, kernel_size=2, stride=2, padding=1,output_padding=0, bias= False)
init.constant(dconv.weight, 1)
print(dconv.weight)
input = Variable(torch.ones(1, 1, 2, 2))
print(input)
print(dconv(input))
来源:https://www.jianshu.com/p/01577e86e506
猜你喜欢
- 解决方案1.安装django-cors-headerspip install django-cors-headers2.配置settings
- 原始数据TS PERIOD REQUEST STEPPED VALUE STATUS SECONDS20-DEC-16 00:00:00.0
- 从过往MySQL数据库生产环境的维护工作中,总结的一些小经验和知识,未必有多深奥,但是对我们消除隐患,确保MySQL数据库生产环境四个9的作
- 认识pip众所周知,pip可以对python的第三方库进行安装、更新、卸载等操作,十分方便。pip的全称:package installer
- 前言我们先说一下思路:先对目标网站发送请求,获取html源码,然后对源码里面的所以图片链接进行筛选,然后再次对图片链接发送请求,然后保存。思
- 声明定位元素:position属性值设置除默认值static以外的元素,包括relative,absolute,fixed。平台:win/I
- 使用python自带的itertools模块调用其product函数传入我们想组合生成的字符数据便会源源不断的生成组合而且不会重复repea
- 通常在多个不等式的时候,需要分着写,比如x = 1if x>0 and x<3: print(True)但是在Python中居然
- 在使用出colab进行模型训练时,发现colab的python版本更新为了3.7.11,而我的代码要在python3.6下才行配置好环境,于
- 1、通过复制数据构造张量1.1 torch.tensor()torch.tensor([[0.1, 1.2], [2.2, 3.1], [4
- ord是unicode ordinal的缩写,即编号chr是character的缩写,即字符ord和chr是互相对应转换的.但是由于chr局
- 成天都要与样式打交道的朋友,相信对CSS选择符(CSS Selectors)都不会陌生。不过对于刚接触或者还不是很熟悉css的朋友来说,能够
- 本文实例讲述了Python模块的定义,模块的导入,__name__用法。分享给大家供大家参考,具体如下:相关内容:什么是模块模块的导入同级目
- 序列是Python中最基本的数据结构。序列中的每个元素都分配一个数字 - 它的位置,或索引,第一个索引是0,第二个索引是1,依此类推。Pyt
- 在python中利用numpy array进行数据处理,经常需要找出符合某些要求的数据位置,有时候还需要对这些位置重新赋值。这里总结了几种找
- 本文实例为大家分享了python比特币初始配置的具体代码,供大家参考,具体内容如下# -*- coding: utf-8 -*- "
- 最近和一程序员合作项目。弄的我头都大了~埋怨我的CSS命名看不懂~得按照他的来。结果我打开他的页面,看了看,从头第一个开始就是content
- 项目场景:Python版本:3.8因公司业务需求,须开发一套局域网内视频会议软件,此次采用Python实现此功能。程序编写完并在编译器实现此
- 如下所示:import urllib.requestimport urllib.parseurl = 'https://weibo.
- 本文实例讲述了php+html5基于websocket实现聊天室的方法。分享给大家供大家参考。具体如下:html5的websocket 实现