Pytorch平均池化nn.AvgPool2d()使用方法实例
作者:Cassiel_cx 发布时间:2023-09-30 02:49:35
【pytorch官方文档】:https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html?highlight=avgpool2d#torch.nn.AvgPool2d
torch.nn.AvgPool2d()
作用
在由多通道组成的输入特征中进行2D平均池化计算
函数
torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)
参数
Args:
kernel_size: 滑窗(池化核)大小
stride: 滑窗的移动步长, 默认值为kernel_size
padding: 在输入信号两侧的隐式零填充数量
ceil_mode: 决定计算输出的形状时是向上取整还是向下取整, 默认为False(向下取整)
count_include_pad: 在平均池化计算中是否包含零填充, 默认为True(包含零填充)
divisor_override: 如果指定了, 它将被作为平均池化计算中的除数, 否则将使用池化区域的大小作为平均池化计算的除数
公式
代码实例
假设输入特征为S,输出特征为D
情况一
ceil_mode=False, count_include_pad=True(计算时包含零填充)
import torch
import torch.nn as nn
import numpy as np
# 生成一个形状为1*1*3*3的张量
x1 = np.array([
[1,2,3],
[4,5,6],
[7,8,9]
])
x1 = torch.from_numpy(x1).float()
x1 = x1.unsqueeze(0).unsqueeze(0)
# 实例化二维平均池化
avgpool1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=True)
y1 = avgpool1(x1)
print(y1)
# 打印结果
'''
tensor([[[[1.3333, 1.7778],
[2.6667, 3.1111]]]])
'''
计算过程:
输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,
D[1,1] = (0+0+0+0+1+2+0+4+5) / 9 = 1.3333,
D[1,2] = (0+0+0+2+3+0+5+6+0) / 9 = 1.7778,
D[2,1] = (0+4+5+0+7+8+0+0+0) / 9 = 2.6667,
D[2,2] = (5+6+0+8+9+0+0+0+0) / 9 = 3.1111.
情况二
ceil_mode=False, count_include_pad=False(计算时不包含零填充)
avgpool2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=False)
y2 = avgpool2(x1)
print(y2)
# 打印结果
'''
tensor([[[[3., 4.],
[6., 7.]]]])
'''
计算过程:
输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,
D[1,1] = (1+2+4+5) / 4 = 3,
D[1,2] = (2+3+5+6) / 4 = 4,
D[2,1] = (4+5+7+8) / 4 = 6,
D[2,2] = (5+6+8+9) / 4 = 7.
情况三
ceil_mode=False, count_include_pad=False, divisor_override=2(将计算平均池化时的除数指定为2)
avgpool3 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False, count_include_pad=False, divisor_override=2)
y3 = avgpool3(x1)
print(y3)
# 打印结果
'''
tensor([[[[ 6., 8.],
[12., 14.]]]])
'''
计算过程:
输出形状= floor[(3 - 3 + 2) / 2] + 1 = 2,
D[1,1] = (1+2+4+5) / 2 = 6,
D[1,2] = (2+3+5+6) / 2 = 8,
D[2,1] = (4+5+7+8) / 2 = 12,
D[2,2] = (5+6+8+9) / 2 = 14.
情况四
ceil_mode=True, count_include_pad=True, divisor_override=None(在计算输出的形状时向上取整)
x2 = np.array([
[1,2,3,4],
[5,6,7,8],
[9,10,11,12],
[13,14,15,16]
])
x2 = torch.from_numpy(x2).reshape(1,1,4,4).float()
avgpool4 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)
y4 = avgpool4(x2)
print(y4)
# 打印结果
'''
tensor([[[[ 1.5556, 3.3333, 2.0000],
[ 6.3333, 11.0000, 6.0000],
[ 4.5000, 7.5000, 4.0000]]]])
'''
计算过程:
输出形状 = ceil[(4 - 3 + 2) / 2] + 1 = 3,
D[1,1] = (0+0+0+0+1+2+0+5+6) / 9 = 1.5556,
D[1,2] = (0+0+0+2+3+4+6+7+8) / 9 = 3.3333,
D[1,3] = (0+0+4+0+8+0) / 6 = 2,
D[2,1] = (0+5+6+0+9+10+0+13+14) / 9 = 6.3333,
D[2,2] = (6+7+8+10+11+12+14+15+16) / 9 = 11,
D[2,3] = (8+0+12+0+16+0) / 6 = 6,
D[3,1] = (0+13+14+0+0+0) / 6 = 4.5,
D[3,2] = (14+15+16+0+0+0) / 6 = 7.5,
D[3,3] = (16+0+0+0) / 4 = 4.
来源:https://blog.csdn.net/qq_38964360/article/details/129148451


猜你喜欢
- 我就废话不多说了,大家还是直接看代码吧!import kerasimport numpy as npfrom keras.applicati
- 因为需要一个html形式的数据统计界面,所以做了一个基于pyecharts包的可视化程序,当然matplotlib还是常用的数据可视化包,只
- 感谢人类方方面面的创新,今天Web开发已经不需要在如何设计网站上面浪费时间了。框架和库帮助web开发者得以专注于真正的开发工作上。下面的这些
- 所见即所得的文本编辑器目前在网上流传的已经有很多了,并且都比较优秀,就我个人而言,用过的有以下几个: ·
- 1、计算器功能介绍可以实现数据的加(+),减(-),乘(*),除(/),取余运算(%),以及实现数据的删除(Del)和清空功能(C)。2、计
- 本文实例讲述了Go语言正则表达式。分享给大家供大家参考,具体如下:package mainimport "bytes"i
- chr()函数与ord()函数解析chr()函数用一个范围在 range(256)内的(就是0~255)整数作参数,返回一个对应的字符。返回
- 本文实例讲述了Go语言使用HTTP包创建WEB服务器的方法。分享给大家供大家参考,具体如下:在Golang中写一个http web服务器大致
- 首先来看实例代码:# -*- coding:utf-8 -*-import requestsimport datetimeimport ti
- 在web运行中很重要的一个功能就是加载静态文件,在django中可能已经给我们设置好了,我们只要直接把模板文件放在templates就好了,
- 有时候,依赖 vue 响应方式来更新数据是不够的,相反,我们需要手动重新渲染组件来更新数据。或者,我们可能只想抛开当前的
- OK,首先写一个python socket的server段,对开放三个端口:10000,10001,10002.krondo的例子中是每个s
- 本文实例为大家分享了python实现记事本功能的具体代码,供大家参考,具体内容如下1. 案例介绍tkinter 是 Python下面向 tk
- 可编辑下拉框-HTML <div style="position:relative;"> <selec
- 听歌识曲,顾名思义,用设备“听”歌曲,然后它要告诉你这是首什么歌。而且十之八九它还得把这首歌给你播放出来。这样的功能在QQ音乐等应用上早就出
- 一个网站空间,但是却可以实现多个域名的访问的一段ASP代码:<%if Request.ServerVariables("SE
- 简单的学习下利用socket来建立客户端和服务端之间的连接并且发送数据1. 客户端socketClient.py代码import socke
- 本文介绍用python实现的搜索本地文本文件内容的小程序。从而学习Python I/O方面的知识。代码如下:import os#根据文件扩展
- 本文实例为大家分享了python实现定时发送邮件到指定邮箱的具体代码,供大家参考,具体内容如下整个链路:传感器采集端采集数据,边缘端上传数据
- 本文实例为大家分享了python实现简易学生信息管理系统的具体代码,供大家参考,具体内容如下一、系统功能1.录入学生信息2.查找学生信息3.