网络编程
位置:首页>> 网络编程>> Python编程>> pytorch SENet实现案例

pytorch SENet实现案例

作者:小伟db  发布时间:2021-03-27 05:14:23 

标签:pytorch,SENet

我就废话不多说了,大家还是直接看代码吧~


from torch import nn

class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
 super(SELayer, self).__init__()

//返回1X1大小的特征图,通道数不变
 self.avg_pool = nn.AdaptiveAvgPool2d(1)
 self.fc = nn.Sequential(
  nn.Linear(channel, channel // reduction, bias=False),
  nn.ReLU(inplace=True),
  nn.Linear(channel // reduction, channel, bias=False),
  nn.Sigmoid()
 )

def forward(self, x):
 b, c, _, _ = x.size()

//全局平均池化,batch和channel和原来一样保持不变
 y = self.avg_pool(x).view(b, c)

//全连接层+池化
 y = self.fc(y).view(b, c, 1, 1)

//和原特征图相乘
 return x * y.expand_as(x)

补充知识:pytorch 实现 SE Block

论文模块图

pytorch SENet实现案例

代码


import torch.nn as nn
class SE_Block(nn.Module):
def __init__(self, ch_in, reduction=16):
 super(SE_Block, self).__init__()
 self.avg_pool = nn.AdaptiveAvgPool2d(1)# 全局自适应池化
 self.fc = nn.Sequential(
  nn.Linear(ch_in, ch_in // reduction, bias=False),
  nn.ReLU(inplace=True),
  nn.Linear(ch_in // reduction, ch_in, bias=False),
  nn.Sigmoid()
 )

def forward(self, x):
 b, c, _, _ = x.size()
 y = self.avg_pool(x).view(b, c)
 y = self.fc(y).view(b, c, 1, 1)
 return x * y.expand_as(x)

现在还有许多关于SE的变形,但大都大同小异

来源:https://blog.csdn.net/qq_35985044/article/details/90142431

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com