Python深度学习之使用Pytorch搭建ShuffleNetv2
作者:I 发布时间:2023-10-10 06:19:09
标签:Python,Pytorch,ShuffleNetv2
一、model.py
1.1 Channel Shuffle
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
batch_size, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
# reshape
# [batch_size, num_channels, height, width] -> [batch_size, groups, channels_per_group, height, width]
x = x.view(batch_size, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batch_size, -1, height, width)
return x
1.2 block
class InvertedResidual(nn.Module):
def __init__(self, input_c: int, output_c: int, stride: int):
super(InvertedResidual, self).__init__()
if stride not in [1, 2]:
raise ValueError("illegal stride value.")
self.stride = stride
assert output_c % 2 == 0
branch_features = output_c // 2
# 当stride为1时,input_channel应该是branch_features的两倍
# python中 '<<' 是位运算,可理解为计算×2的快速方法
assert (self.stride != 1) or (input_c == branch_features << 1)
if self.stride == 2:
self.branch1 = nn.Sequential(
self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
nn.BatchNorm2d(input_c),
nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True)
)
else:
self.branch1 = nn.Sequential()
self.branch2 = nn.Sequential(
nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
nn.BatchNorm2d(branch_features),
nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True)
)
@staticmethod
def depthwise_conv(input_c: int,
output_c: int,
kernel_s: int,
stride: int = 1,
padding: int = 0,
bias: bool = False) -> nn.Conv2d:
return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
stride=stride, padding=padding, bias=bias, groups=input_c)
def forward(self, x: Tensor) -> Tensor:
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
1.3 shufflenet v2
class ShuffleNetV2(nn.Module):
def __init__(self,
stages_repeats: List[int],
stages_out_channels: List[int],
num_classes: int = 1000,
inverted_residual: Callable[..., nn.Module] = InvertedResidual):
super(ShuffleNetV2, self).__init__()
if len(stages_repeats) != 3:
raise ValueError("expected stages_repeats as list of 3 positive ints")
if len(stages_out_channels) != 5:
raise ValueError("expected stages_out_channels as list of 5 positive ints")
self._stage_out_channels = stages_out_channels
# input RGB image
input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True)
)
input_channels = output_channels
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Static annotations for mypy
self.stage2: nn.Sequential
self.stage3: nn.Sequential
self.stage4: nn.Sequential
stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip(stage_names, stages_repeats,
self._stage_out_channels[1:]):
seq = [inverted_residual(input_channels, output_channels, 2)]
for i in range(repeats - 1):
seq.append(inverted_residual(output_channels, output_channels, 1))
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels
output_channels = self._stage_out_channels[-1]
self.conv5 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True)
)
self.fc = nn.Linear(output_channels, num_classes)
def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.maxpool(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.conv5(x)
x = x.mean([2, 3]) # global pool
x = self.fc(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
二、train.py
来源:https://blog.csdn.net/weixin_43154149/article/details/116267653
0
投稿
猜你喜欢
- 本文实例讲述了Python决策树之基于信息增益的特征选择。分享给大家供大家参考,具体如下:基于信息增益的特征选取是一种广泛使用在决策树(de
- 广州4.18书友会主题的内容提纲自己参与撰写,同时还参与组织和主持。通过这次的深入参与,我发现胡晓同学能坚持下来多不容易,先赞下。由于天公不
- 在自然语言处理(NLP)领域,文本相似度计算是一个常见的任务。本文将介绍如何使用Python计算文本之间的相似度,涵盖了余弦相似度、Jacc
- pycharm出现no module named xlwt问题首先声明,我是初学者,今天按照书上步骤,创建Excel文件,当我的xlwt安装
- 说明之前下载来zip包的漫画,里面的图片都是两张一起的:但是某些漫画查看软件不支持自动分屏,看起来会比较不舒服,所以只能自己动手来切分。操作
- 本文实例讲述了Python网络编程之TCP套接字简单用法。分享给大家供大家参考,具体如下:上学期学的计算机网络,因为之前还未学习python
- 作为一门脚本语言,写脚本时执行系统命令可以说很常见了,python提供了相关的模块和方法。os模块提供了访问操作系统服务的功能,由于涉及到操
- 在使用Python处理数据时,经常需要对数据筛选。这是在对时间筛选时,判断两列时间是否相差一年,如果是,则返回符合条件的所有列。data原始
- 一、问题原因(如果不是第一次使用pycharm,我觉得可以跳过这一章)我是升级以后,在用pycharm打开以前的项目就出现报错了;很明显是环
- Q: 不知xml和html有什么区别?它们不同在哪? A: 关于XML和HTML区别请参考: http://www.w3c.org/Mark
- 读写文件是最常见的IO操作。Python内置了读写文件的函数,用法和C是兼容的。读写文件前,我们先必须了解一下,在磁盘上读写文件的功能都是由
- 在本章中,我们将重点介绍RSA密码加密的不同实现及其所涉及的功能.您可以引用或包含此python文件以实现RSA密码算法实现.加密算法模块&
- 本文实例讲述了Python列表list操作符。分享给大家供大家参考,具体如下:#coding=utf8''''
- 继续flask的学习之旅。今天介绍flask的登陆管理模块,还记得上一篇中的blog小项目么,登录是咱们自己写的验证代码,大概有以下几个步骤
- 大多数网站维护都采用“多人协作,共同管理”方式。某个人负责一个(或者多个)栏目,他只能对他负责的栏目进
- 一、字典转dataFrame1、字典转dataFrame比较简单,直接给出示例:import pandas as pddic = {'
- Access数据库,同时操作大量记录(9500条以上)时报错。错误提示:Microsoft JET Database Engine 错误 &
- 相关推荐:完整的sql中文参考手册(chm)下载 DB2 提供了关连式资料库的查询语言 sql (Structured Query
- 内容摘要:asp使用最多的就是ACCESS数据库和ms sql server数据库,本文列出了asp连接这两个数据库的方
- 本文实例讲述了python 函数的缺省参数使用注意事项。分享给大家供大家参考,具体如下:python的函数支持4种形式的参数:分别是必选参数