Python深度学习之使用Pytorch搭建ShuffleNetv2
作者:I 发布时间:2023-10-10 06:19:09
标签:Python,Pytorch,ShuffleNetv2
一、model.py
1.1 Channel Shuffle
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
batch_size, num_channels, height, width = x.size()
channels_per_group = num_channels // groups
# reshape
# [batch_size, num_channels, height, width] -> [batch_size, groups, channels_per_group, height, width]
x = x.view(batch_size, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batch_size, -1, height, width)
return x
1.2 block
class InvertedResidual(nn.Module):
def __init__(self, input_c: int, output_c: int, stride: int):
super(InvertedResidual, self).__init__()
if stride not in [1, 2]:
raise ValueError("illegal stride value.")
self.stride = stride
assert output_c % 2 == 0
branch_features = output_c // 2
# 当stride为1时,input_channel应该是branch_features的两倍
# python中 '<<' 是位运算,可理解为计算×2的快速方法
assert (self.stride != 1) or (input_c == branch_features << 1)
if self.stride == 2:
self.branch1 = nn.Sequential(
self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
nn.BatchNorm2d(input_c),
nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True)
)
else:
self.branch1 = nn.Sequential()
self.branch2 = nn.Sequential(
nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
nn.BatchNorm2d(branch_features),
nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True)
)
@staticmethod
def depthwise_conv(input_c: int,
output_c: int,
kernel_s: int,
stride: int = 1,
padding: int = 0,
bias: bool = False) -> nn.Conv2d:
return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
stride=stride, padding=padding, bias=bias, groups=input_c)
def forward(self, x: Tensor) -> Tensor:
if self.stride == 1:
x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1)
else:
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
out = channel_shuffle(out, 2)
return out
1.3 shufflenet v2
class ShuffleNetV2(nn.Module):
def __init__(self,
stages_repeats: List[int],
stages_out_channels: List[int],
num_classes: int = 1000,
inverted_residual: Callable[..., nn.Module] = InvertedResidual):
super(ShuffleNetV2, self).__init__()
if len(stages_repeats) != 3:
raise ValueError("expected stages_repeats as list of 3 positive ints")
if len(stages_out_channels) != 5:
raise ValueError("expected stages_out_channels as list of 5 positive ints")
self._stage_out_channels = stages_out_channels
# input RGB image
input_channels = 3
output_channels = self._stage_out_channels[0]
self.conv1 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True)
)
input_channels = output_channels
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Static annotations for mypy
self.stage2: nn.Sequential
self.stage3: nn.Sequential
self.stage4: nn.Sequential
stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip(stage_names, stages_repeats,
self._stage_out_channels[1:]):
seq = [inverted_residual(input_channels, output_channels, 2)]
for i in range(repeats - 1):
seq.append(inverted_residual(output_channels, output_channels, 1))
setattr(self, name, nn.Sequential(*seq))
input_channels = output_channels
output_channels = self._stage_out_channels[-1]
self.conv5 = nn.Sequential(
nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(output_channels),
nn.ReLU(inplace=True)
)
self.fc = nn.Linear(output_channels, num_classes)
def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.maxpool(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = self.conv5(x)
x = x.mean([2, 3]) # global pool
x = self.fc(x)
return x
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
二、train.py
来源:https://blog.csdn.net/weixin_43154149/article/details/116267653
0
投稿
猜你喜欢
- asp+js做的一个dig程序中的投票(有的叫顶一下,踩一下),由于代码较长,只贴出核心部分:投票中的代码相关文章推荐:ajax +asp
- 发现一个非常强的CSS在线排版:CSS Text Wrapper只要你拖拽线条,你就可以得到你想要的文字版式CSS代码。可以让想让文本块呈现
- 前言Python快捷创建文件夹和文件详解 自己做文件时发现 简单的反复操作十分浪费时间,于是想到了 使用Python,这个分享给
- 本文实例为大家分享了TensorFlow实现简单线性回归的具体代码,供大家参考,具体内容如下简单的一元线性回归一元线性回归公式:其中x是特征
- 判断是否xx开始使用startswith示例代码:String = "12345 上山打老虎"if str(String
- 简介学习慕课课程,Flask前后端分离API后台接口的实现demo,前端可以接入小程序,暂时已经完成后台API基础架构,使用 postman
- 组合数据类型分类组合数据类型分为三类,第一类是集合类型,第二类是序列类型,第三类是映射类型集合类型集合类型是一个元素集合,元素之间没有排列顺
- 有朋友问,在数据库中如何查询数据所在的行,一般我们建议一个自增字段就可以了.但是有时却会删除数据,那么那个自增字段也不正确了先不管朋友们为什
- 第二次修改models.py以后再次python manage.py makemigrations提示如下You are trying to
- 前言:由于很多业务表因为历史原因或者性能原因,都使用了违反第一范式的设计模式。即同一个列中存储了多个属性值(具体结构见下表)。这种模式下,应
- 在开发django应用的过程中,使用开发者模式启动服务是特别方便的一件事,只需要 python manage.py runserver 就可
- Selenium 是一个用于Web应用程序测试的工具。Selenium测试直接运行在浏览器中,就像真正的用户在操作一样。支持的浏览器包括IE
- 1. 前言中文分词≠自然语言处理!HanlpHanLP是由一系列模型与算法组成的Java工具包,目标是普及自然语言处理在生产环境中的应用。H
- json 作为一种通用的编解码协议,可阅读性上比 thrift,protobuf 等协议要好一些,同时编码的 size 也会比 xml 这类
- 谈到比特币,我们都知道挖矿,有些人并不太明白挖矿的含义。这里的挖矿其实就是哈希的碰撞,举个简单例子:import hashlibx = 11
- 运行MySQL Server 5.0安装程序“setup.exe”,出现如下界面: 安装向导启动,按“Next”继续:
- 一、安装pip install pymysql二、连接数据库三种连接数据库的方式import pymysql# 方式一conn = pymy
- 一、前言提到 limit 优化,大多数 MySQL DBA 都不会陌生,能想到各种应对策略,比如延迟关联,书签式查询等等,之前我也写过一篇优
- 将数据库中的信息存储至XML文件中:save.asp<!-- #include file="adovbs
- 一、re.findall函数介绍它在re.py中有定义:def findall(pattern, string, flags=0): &nb