pytorch 如何自定义卷积核权值参数
作者:Mr.Jcak 发布时间:2021-10-30 19:10:22
标签:pytorch,卷积核,权值,参数
pytorch中构建卷积层一般使用nn.Conv2d方法,有些情况下我们需要自定义卷积核的权值weight,而nn.Conv2d中的卷积参数是不允许自定义的,此时可以使用torch.nn.functional.conv2d简称F.conv2d
torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)
F.conv2d可以自己输入且也必须要求自己输入卷积权值weight和偏置bias。因此,构建自己想要的卷积核参数,再输入F.conv2d即可。
下面是一个用F.conv2d构建卷积层的例子
这里为了网络模型需要写成了一个类:
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.weight = nn.Parameter(torch.randn(16, 1, 5, 5)) # 自定义的权值
self.bias = nn.Parameter(torch.randn(16)) # 自定义的偏置
def forward(self, x):
x = x.view(x.size(0), -1)
out = F.conv2d(x, self.weight, self.bias, stride=1, padding=0)
return out
值得注意的是,pytorch中各层需要训练的权重的数据类型设为nn.Parameter,而不是Tensor或者Variable。parameter的require_grad默认设置为true,而Varaible默认设置为False。
补充:pytorch中卷积参数的理解
kernel_size代表着卷积核,例如kernel_size=3或kernel_size=(3,7);
stride
:表明卷积核在像素级图像上行走的步长,如图2,步长为1;
padding
:为上下左右填充的大小,例如padding=0/1/(1,1)/(1,3),
padding=0 不填充;
padding=1/(1,1) 上下左右分别填充1个格;
padding=(1,3) 高(上下)填充2个格,宽(左右)填充6个格;
卷积代码
torch.nn.Conv2d(512,512,kernel_size=(3,7),stride=2,padding=1)
指定输出形状的上采样
def upsample_add(self,x,y):
_,_,H,W = y.size()
return F.interpolate(x, size=(H,W), mode='bilinear', align_corners=False) + y
反卷积上采样
output_shape_w=kernel_size_w+(output_w-1)(kernel_size_w-1)+2padding
self.upscore2 = nn.ConvTranspose2d(
512, 1, kernel_size=3, stride=2,padding=0, bias=False)
来源:https://blog.csdn.net/weixin_38314865/article/details/105941140


猜你喜欢
- import os import sys import string #以指定模式打开指定文件,获取文件句柄 def getFileIns(
- 代码如下:<% '--------定义部份------------------ Dim XH_P
- 本文实例讲述了Python时间的精准正则匹配方法。分享给大家供大家参考,具体如下:要用正则表达式精准匹配时间,其实并不容易方式一:>&
- 相信大家一定碰到过,打开某个网页,却显示一堆像乱码,如"бЇЯАзЪСЯ"、"�????????"?
- 一、mock.js的使用mock.js的使用步骤① 下载依赖 npm install mock -d(开发环境使用)② 引入到main.js
- 一、os常用方法1.获取当前路径 os.getcwd()# coding:utf-8import osif __name__ == '
- 一、交换变量x = 6y = 5x, y = y, xprint x>>> 5print y>>> 6二
- 前言刚刚看了EuroPython 2017一篇演讲,Why You Don't Need Design Patterns in Py
- 今天接到一个小需求,就是想在windows环境下,上传压缩文件到linux指定的目录位置并且解压出来,然后我想了一下,这个可以用python
- strconv包该包主要实现基本数据类型与其字符串表示的转换。常用函数为Atoi()、Itia()、parse系列、format系列、app
- 有时候,通过一个名称来标识一个路由显得更方便一些,特别是在链接一个路由,或者是执行一些跳转的时候。你可以在创建 Router 实例的时候,在
- 记得以前的Windows任务定时是可以正常使用的,今天试了下,发现不能正常使用了,任务计划总是挂起。接下来记录下Python爬虫定时任务的几
- librosa是处理音频库里的opencv,使用python脚本研究音频,先安装三方库librosa。如下通过清华镜像源安装librosa;
- 1.java连接Oracle数据库使用以下代码三个步骤:下载ojdbc.jar包并导入项目中。将下面的代码放在你觉得它应该在的地方。修改代码
- 问题最近,在用SSH框架完成一个实践项目时,碰到了一个莫名其妙的Bug困扰了我好久,最后终于解决,记录如下。问题:同学在测试系统的时候突然发
- 今天在写 mysql 遇到一个比较特殊的问题。 mysql 语句如下: update wms_cabinet_form set cabf_e
- 前言go 当中的并发编程是通过goroutine来实现的,利用channel(管道)可以在协程之间传递数据,实现协程的协调与同步。使用新建一
- 1.进入官网https://www.python.org/,点击Downloads下的Windows按钮,进入下载页面。2.如下图所示,点击
- Codeigniter支持缓存技术,以达到最快的速度。尽管CI已经相当高效了,但是网页中的动态内容、主机的内存CPU和数据库读取速度等因素直
- 本文实例讲述了python基于右递归解决八皇后问题的方法。分享给大家供大家参考。具体分析如下:凡是线性回溯都可以归结为右递归的形式,也即是二