PyTorch基础之torch.nn.Conv2d中自定义权重问题
作者:gy笨瓜 发布时间:2023-10-13 05:16:31
torch.nn.Conv2d中自定义权重
torch.nn.Conv2d函数调用后会自动初始化weight和bias,本文主要涉及
如何自定义weight和bias为需要的数均分布类型:
torch.nn.Conv2d.weight.data以及torch.nn.Conv2d.bias.data为torch.tensor类型,因此只要对这两个属性进行操作即可。
【sample】
以input_channels = 2, output_channels = 1 为例
In [1]: import torch
In [2]: import torch.nn as nn
In [3]: conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=3)
# 此时weight以及bias已由nn.Conv2d初始化
conv.weight, conv.bias
Out[4]:
(Parameter containing:
tensor([[[[-0.0335, 0.0855, -0.0708],
[-0.1672, 0.0902, -0.0077],
[-0.0838, -0.1539, -0.0933]],
[[-0.0496, 0.1807, -0.1477],
[ 0.0397, 0.1963, 0.0932],
[-0.2018, -0.0436, 0.1971]]]], requires_grad=True),
Parameter containing:
tensor([-0.1963], requires_grad=True))
# 手动设定
# conv.weight.data 以及 conv.bias.data属性为torch.tensor
# 因此只要获取conv.weight.data以及conv.bias.data属性,后续调用torch.tensor的不同方法即可进行修改
# 例如:全部修改为0
In [5]: conv.weight.data.zero_(), conv.bias.data.zero_()
In [6]: conv.weight, conv.bias
Out[6]:
(Parameter containing:
tensor([[[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]], requires_grad=True),
Parameter containing:
tensor([0.], requires_grad=True))
torch.nn.Conv2d()用法讲解
本文是深度学习框架 pytorch 的API : torch.nn.Conv2d() 函数的用法。介绍了 torch.nn.Conv2d() 各个参数的含义和用法,学会使用 pytorch 创建 卷积神经网络。
用法
Conv2d(in_channels, out_channels, kernel_size, stride=1,padding=0, dilation=1, groups=1,bias=True, padding_mode=‘zeros')
参数
in_channels
:输入的通道数目 【必选】out_channels
:输出的通道数目 【必选】kernel_size
:卷积核的大小,类型为int 或者元组,当卷积是方形的时候,只需要一个整数边长即可,卷积不是方形,要输入一个元组表示 高和宽。【必选】stride
:卷积每次滑动的步长为多少,默认是 1 【可选】padding
:设置在所有边界增加 值为 0 的边距的大小(也就是在feature map 外围增加几圈 0 ),例如当 padding =1 的时候,如果原来大小为 3 × 3 ,那么之后的大小为 5 × 5 。即在外围加了一圈 0 。【可选】dilation
:控制卷积核之间的间距(什么玩意?请看例子)【可选】
如果我们设置的dilation=0的话,效果如图:(蓝色为输入,绿色为输出,卷积核为3 × 3)
如果设置的是dilation=1,那么效果如图:(蓝色为输入,绿色为输出,卷积核仍为 3 × 3 。)
但是这里卷积核点与输入之间距离为1的值相乘来得到输出。
groups
:控制输入和输出之间的连接。(不常用)【可选】
举例来说:
比如 groups 为1,那么所有的输入都会连接到所有输出
当 groups 为 2的时候,相当于将输入分为两组,并排放置两层,每层看到一半的输入通道并产生一半的输出通道,并且两者都是串联在一起的。这也是参数字面的意思:“组” 的含义。
需要注意的是,in_channels 和 out_channels 必须都可以整除 groups,否则会报错(因为要分成这么多组啊,除不开你让人家程序怎么办?)
bias
: 是否将一个 学习到的 bias 增加输出中,默认是 True 。【可选】padding_mode
: 字符串类型,接收的字符串只有 “zeros” 和 “circular”。【可选】
注意:参数 kernel_size,stride,padding,dilation 都可以是一个整数或者是一个元组,一个值的情况将会同时作用于高和宽 两个维度,两个值的元组情况代表分别作用于 高 和 宽 维度。
相关形状
示例
入门学习者请不要过度关注某一些细节,建立一个简单的卷积层使用这个 API 其实很简单,大部分参数保持默认值就好,下面是简单的一个示例,创建一个简单的卷积神经网络:
class CNN(nn.Module):
def __init__(self,in_channels:int,out_channels:int):
"""
创建一个卷积神经网络
网络只有两层
:param in_channels: 输入通道数量
:param out_channels: 输出通道数量
"""
super(CNN).__init__()
self.conv1=nn.Conv2d(in_channels,10,3,stride=1,padding=1)
self.pool1=nn.MaxPool2d(kernel_size=2,stride=1)
self.conv2=nn.Conv2d(10,out_channels,3,stride=1,padding=1)
self.pool2=nn.MaxPool2d(kernel_size=2,stride=1)
def forward(self,x):
"""
前向传播函数
:param x: 输入,tensor 类型
:return: 返回结果
"""
out=self.conv1(x)
out=self.pool1(out)
out=self.conv2(out)
out=self.pool2(out)
return out
来源:https://blog.csdn.net/u012633319/article/details/109271370


猜你喜欢
- 表的创建CREATE TABLE `lee` (`id` int(10) NOT NULL AUTO_INCREMENT, `name` c
- 下面是我们插入到这个tuangou表的数据: id web city type 1 拉手网 北京 餐饮美食 2 拉手网 上海 休闲娱乐 3
- 首先这是VGG的结构图,VGG11则是红色框里的结构,共分五个block,如红框中的VGG11第一个block就是一个conv3-64卷积层
- 在Windows下使用VSCode编译运行,都出现中文乱码的问题,今天我就遇见了这种情况,上网搜了半天也没有找到正确的解决方法,现将我把我的
- 2020年11月22日最新分享pycharm激活码,这次分享的pycharm激活码适用pycharm2020最新版及pycharm2019、
- django orm 有个defer方法,指定模型排除的字段。如下返回的Queryset, 排除‘username', 'i
- 我就废话不多说了,直接上代码!from enum import Enumclass Values(): values={'
- 使用windows API使用PIL中的ImageGrab模块下面对两者的特点和用法进行详细解释。一、Python调用windows API
- 本文实例讲述了python简单实现基于SSL的 IRC bot。分享给大家供大家参考。具体如下:#!/usr/bin/python# -*-
- 利用numpy、matplotlib、sympy绘制sigmoid、tanh、ReLU、leaky ReLU、softMax函数起因:深度学
- MFCC梅尔倒谱系数(Mel-scaleFrequency Cepstral Coefficients,简称MFCC)。MFCC通常有以下之
- 最简单的数组合并我们只要使用array_merge即可array_merge()将两个或多个数组的单元合并起来,一个数组中的值附加在前一个数
- append()方法追加传递obj到现有的列表。语法以下是append()方法的语法:list.append(obj)参数&nb
- datetime日期时间类,主要熟悉API,时区的概念与语言无关。from datetime import datetime as dtdt
- 扩展名在写Python程序时我们常见的扩展名是py, pyc,其实还有其他几种扩展名。下面是几种扩展名的用法。pypy就是最基本的源码扩展名
- 本文测试环境:CentOS 7 64-bit Minimal MySQL 5.7配置 yum 源在 https://dev.mysql.co
- 一、临时表实现分步处理1.概述当需要的结果需要经过多次处理后才能最终得到我们需要的结果时,就可以使用临时表,这里临时表就起到了一个中间处理的
- 很多网站在注册时除了需要用户填写用户名与密码之外,还会要求用户输入邮箱,而且是属于那种不填写就不能完成注册的强制型的。碰到这种情况的时候,一
- 函数重载的替代方法-伪重载,下面看一个具体的实例代码。<? php//函数重载的替代方法-伪重载////确实,在PHP中没有函数重载这
- MySQL使用环境变量TMPDIR的值作为保存临时文件的目录的路径名。如果未设置TMPDIR,MySQL将使用系统的默认值,通常为/tmp、