F.conv2d pytorch卷积计算方式
作者:wanghua609 发布时间:2021-06-26 07:38:41
F.conv2d pytorch卷积计算
Pytorch里一般小写的都是函数式的接口,相应的大写的是类式接口。
函数式的更加low-level一些,如果不需要做特别复杂的配置只需要用类式接口就够了。
可以这样理解
nn.Conved是2D卷积层,而F.conv2d是2D卷积操作。
import torch
from torch.nn import functional as F
"""手动定义卷积核(weight)和偏置"""
w = torch.rand(16, 3, 5, 5) # 16种3通道的5乘5卷积核
b = torch.rand(16) # 和卷积核种类数保持一致(不同通道共用一个bias)
"""定义输入样本"""
x = torch.randn(1, 3, 28, 28) # 1张3通道的28乘28的图像
"""2D卷积得到输出"""
out = F.conv2d(x, w, b, stride=1, padding=1) # 步长为1,外加1圈padding,即上下左右各补了1圈的0,
print(out.shape)
out = F.conv2d(x, w, b, stride=2, padding=2) # 步长为2,外加2圈padding
print(out.shape)
out = F.conv2d(x, w) # 步长为1,默认不padding, 不够的舍弃,所以对于28*28的图片来说,算完之后变成了24*24
print(out.shape)
在DSSINet发现又用到了空洞卷积dilated convolution
mu1 = F.conv2d(img1, window , padding=padd, dilation=dilation, groups=channel)
Dilated/Atrous convolution或者是convolution with holes从字面上就很好理解,是在标准的convolution map里注入空洞,以此来增加感受野reception field。
相比原来的正常卷积,空洞卷积多了一个超参数dilation rate,指的是kernel的间隔数量(正常的卷积是dilation rate=1)
正常图像的卷积为
空洞卷积为
现在我们再来看下卷积本身,并了解他背后的设计直觉,以下主要探讨空洞卷积在语义分割(semantic segmentation)的应用。
卷积的主要问题
1、up-sampling/pooling layer(e.g. bilinear interpolation) is deterministic(not learnable)
2、内部数据结构丢失,空间层级化信息丢失。
3、小物体信息无法重建(假设有4个pooling layer,则任何小于2^4=16 pixel的物体信息将理论上无法重建)
在这样问题的存在下,语义分割问题一直处于瓶颈期无法再明显提高精度,而dilated convolution 的设计就良好的避免了这些问题。
对于dilated convolution,我们已经可以发现他的优点,即内部数据结构的保留和避免使用down_sampling这样的特性。但是完全基于dilated convolution的结构如何设计则是一个新的问题。
pytorch中空洞卷积分为两类,一类是正常图像的卷积,另一类是池化时候。
空洞卷积的目的是为了在扩大感受野的同时,不降低图片分辨率和不引入额外参数及计算量(一般在CNN中扩大感受野都需要使用S》1的conv或者pooling,导致分辨率降低,不利于segmentation,如果使用大卷积核,确实可以达到增大感受野,但是会引入额外的参数及计算量)。
F.Conv2d和nn.Conv2d
import torch
import torch.nn.functional as F
# 小括号里面有几个[]就代表是几维数据
input = torch.tensor([[1,2,0,3,1],
[0,1,2,3,1],
[1,2,1,0,0],
[5,2,3,1,1],
[2,1,0,1,1]])
kernel = torch.tensor([[1,2,1],
[0,1,0],
[2,1,0]])
input = torch.reshape(input,(1,1,5,5))
kernel = torch.reshape(kernel,(1,1,3,3))
# stride代表的是步长的意思,即每次卷积核向左或者向下移动多少步进行相乘
# 因为conv2d的input和weight对应的tensor是[batch,channel,h,w],所以上述才将它们进行reshape
output = F.conv2d(input,kernel,stride=1)
print(output)
output = F.conv2d(input,kernel,stride=2)
print(output)
# padding代表的是向上下左右填充的行列数,里面数字填写0
output3 = F.conv2d(input,kernel,stride=1,padding=1)
print(output3)
import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
from torch.nn import Conv2d
from torch.utils.tensorboard import SummaryWriter
dataset = torchvision.datasets.CIFAR10('./torchvision_dataset', train=False, download=False,
transform=torchvision.transforms.ToTensor())
# 准备好数据集就放在dataloader中进行加载
dataloader = DataLoader(dataset, batch_size=64)
# 开始定义一个卷积类
class Zkl(nn.Module):
def __init__(self):
super(Zkl, self).__init__()
self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
def forward(self,x):
x = self.conv1(x)
return x
writer = SummaryWriter("nn_conv2d")
zkl = Zkl()
# print(zkl)
step = 0
for data in dataloader:
imgs,target = data
output = zkl(imgs)
#print(imgs.shape)
#print(output.shape)
writer.add_images('nn_conv2d_input',imgs,step)
#因为输出是6个通道,tensorboard无法解析,所以需要reshape三个通道
output = torch.reshape(output,(-1,3,30,30))
writer.add_images('nn_conv2d_output',output,step)
step+=1
writer.close()
来源:https://blog.csdn.net/weixin_38145317/article/details/104923015


猜你喜欢
- 1.linux下启动mysql的命令:mysqladmin start/ect/init.d/mysql start (前面为mysql的安
- Golang交叉编译平台的二进制文件熟悉golang的人都知道,golang交叉编译很简单的,只要设置几个环境变量就可以了# mac上编译l
- 这三种情况下所得到的server.MapPath是一致的,这就导致上传之后写入数据库的图片地址和实际图片存储地址不一致,因此,我们需要自定义
- 前言一年一度的虐狗节终于过去了,朋友圈各种晒,晒自拍,晒娃,晒美食,秀恩爱的。程序员在晒什么,程序员在加班。但是礼物还是少不了的,送什么好?
- CREATE DATABASE `ct` DEFAULT CHARACTER SET utf8 COLLATE utf8_general_c
- xml文件:country.xml<data><country name="shdi2hajk">
- Python 中的 timeit 模块可以用来测试一段代码的执行耗时,如一个变量赋值语句的执行时间,一个函数的运行时间等。timeit 模块
- 需要在程序中使用二维数组,网上找到一种这样的用法: #创建一个宽度为3,高度为4的数组#[[0,0,0], # [0,0,0],#
- 下列语句部分是Mssql语句,不可以在access中使用。 SQL分类: DDL—数据定义语言(CREATE,ALTER,DROP,DECL
- 一.什么是事务在MySQL中的事务(Transaction)是由存储引擎实现的,在MySQL中,只有InnoDB存储引擎才支持事务。事务处理
- 本文实例讲述了Python实现基本数据结构中栈的操作。分享给大家供大家参考,具体如下:#! /usr/bin/env python#codi
- 通过使用zabbix 日志监控 我发现一个问题 例如oracle的日志有报错的情况 ,通常不会去手动清理 这样的话当第二次有日志写进来的时候
- 1、什么是触发器 触发器对表进行插入、更新、删除的时候会自动执行的特殊存储过程。触发器一般用在check
- 一、目录权限设置很重要:可以有效防范黑客上传木马文件. 如果通过 chmod 644 * -R 的话,php文件就没有权限访问了。 如果通过
- Javascript 选择器(selector engine)似乎从 jQuery 流行以来就大行其道,改变了原有 Javascript 选
- 第一种方法import pandas as pdfrom collections import Counterdata = '参赛信
- SQL Server中的集合运算包括UNION(合并),EXCEPT(差集)和INTERSECT(相交)三种。集合运算的基本使用1.UNIO
- 功能是从客户端向服务发送一个字符串, 服务器收到后将字符串重新发送给客户端,同时,在连接建立之后,服务器可以向客户端发送任意多的字符串客户端
- 一、动机最近打算折腾vn.py,但只有py27版本的,因为一向习惯使用最新稳定版的,所以不得不装py27的环境,不得不说 Python的全局
- Vue开发环境跨域访问其他服务器或者本机其他端口,需要配置项目中config/index.js文件,修改如下module.exports =