利用Pytorch实现ResNet网络构建及模型训练
作者:实力 发布时间:2022-02-24 19:57:59
标签:Pytorch,ResNet,构建网络,模型训练
构建网络
ResNet由一系列堆叠的残差块组成,其主要作用是通过无限制地增加网络深度,从而使其更加强大。在建立ResNet模型之前,让我们先定义4个层,每个层由多个残差块组成。这些层的目的是降低空间尺寸,同时增加通道数量。
以ResNet50为例,我们可以使用以下代码来定义ResNet网络:
class ResNet(nn.Module):
def __init__(self, num_classes=1000):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace
(续)
即模型需要在输入层加入一些 normalization 和激活层。
```python
import torch.nn.init as init
class Flatten(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return x.view(x.size(0), -1)
class ResNet(nn.Module):
def __init__(self, num_classes=1000):
super().__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.layer1 = nn.Sequential(
ResidualBlock(64, 256, stride=1),
*[ResidualBlock(256, 256) for _ in range(1, 3)]
)
self.layer2 = nn.Sequential(
ResidualBlock(256, 512, stride=2),
*[ResidualBlock(512, 512) for _ in range(1, 4)]
)
self.layer3 = nn.Sequential(
ResidualBlock(512, 1024, stride=2),
*[ResidualBlock(1024, 1024) for _ in range(1, 6)]
)
self.layer4 = nn.Sequential(
ResidualBlock(1024, 2048, stride=2),
*[ResidualBlock(2048, 2048) for _ in range(1, 3)]
)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.flatten = Flatten()
self.fc = nn.Linear(2048, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = self.flatten(x)
x = self.fc(x)
return x
改进点如下:
我们使用
nn.Sequential
组件,将多个残差块组合成一个功能块(layer)。这样可以方便地修改网络深度,并将其与其他层分离九更容易上手,例如迁移学习中重新训练顶部分类器时。我们在ResNet的输出层添加了标准化和激活函数。它们有助于提高模型的收敛速度并改善性能。
对于
nn.Conv2d
和批标准化层等神经网络组件,我们使用了PyTorch中的内置初始化函数。它们会自动为我们设置好每层的参数。我们还添加了一个
Flatten
层,将4维输出展平为2维张量,以便通过接下来的全连接层进行分类。
训练模型
我们现在已经实现了ResNet50模型,接下来我们将解释如何训练和测试该模型。
首先我们需要定义损失函数和优化器。在这里,我们使用交叉熵损失函数,以及Adam优化器。
import torch.optim as optim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet(num_classes=1000).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
在使用PyTorch进行训练时,我们通常会创建一个循环,为每个批次的输入数据计算损失并对模型参数进行更新。以下是该循环的代码:
def train(model, optimizer, criterion, train_loader, device):
model.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
acc = 100 * correct / total
avg_loss = train_loss / len(train_loader)
return acc, avg_loss
在上面的训练循环中,我们首先通过model.train()
代表进入训练模式。然后使用optimizer.zero_grad()
清除
来源:https://juejin.cn/post/7222862599851475001


猜你喜欢
- 1、遇到的问题:numpy版本im_data = dataset.ReadAsArray(0,0,im_width,im_height)#获
- 如下所示:import numpy as npnp.set_printoptions(threshold='nan')来源:
- 其实也算不上教程,也就是自己没事儿的时候做点东西然后发上来大家交流交流,希望大家不吝赐教^!^因为刚看过亚东的教程和这个有点相似,所以就自己
- 本文实例讲述了Python双精度浮点数运算并分行显示操作。分享给大家供大家参考,具体如下:#coding=utf8def doubleTyp
- 前言有的时候我们在查看数据库数据时,会看到乱码。实际上,无论何种数据库只要出现乱码问题,这大多是由于数据库字符集设定的问题。下面我们就介绍一
- 一、一站式解决 1. 问题分析定位# 找到MySQL的配置文件,复制mysql的数据目录vim /etc/my.cnf# 进入ms
- 接着前面Django入门使用示例今天我们来看看Django是如何加载静态html的?我们首先来看一看什么是静态HTML,什么是动态的HTML
- 一、相关知识点讲解1.1 需要使用的相关库import numpy as npimport pand
- 数据库是应用开发中必须要掌握的技巧,通常在数据库开发过程中,会有两种不同的方式:直接使用SQL语句,这种方式下,直接编写SQL,简单直观,但
- 这里还以前面的微博为例,我们知道拖动刷新的内容由Ajax加载,而且页面的URL没有变化,那么应该到哪里去查看这些Ajax请求呢?1. 查看请
- 如果您还不太了解XML技术,您可以先看看此文:XML的语法、结构以及相关的一些技术 及 XML DOM介绍和例子XML中 CDATA的作用:
- 前言在Python中定义函数,可以用必选参数、默认参数、可变参数和关键字参数,这4种参数都可以一起使用,或者只用其中某些,但是请注意,参数定
- 引用计数Python默认的垃圾收集机制是“引用计数”,每个对象维护了一个ob_ref字段。它的优点是机制简单,当新的引用指向该对象时,引用计
- 不是炒冷饭,我添加了很多新的功能哦演示地址: xwinhtcdemo.htmCSS: global.cssHTC: xwin.htc特点:1
- 今天运行程序时报了SqlSession [org.apache.ibatis.session.defaults.DefaultSqlSess
- 异常是指因为程序出现了错误而在正常控制流以外采取的行动,其分为两个阶段,第一阶段是引发异常的错误,当系统检测到错误并且意识到异常条件,解释器
- 前言在学习操作系统的时候,我们应该都学习过临界区、互斥锁这些概念,用于在并发环境下保证状态的正确性。比如在秒杀时,100 个用户同时抢 10
- 需求:实现ajax请求,在界面上任意地方点击,可以成功传参。创建项目如下所示:settings.py文件的设置,这次我们除了要注册app和设
- 【译者的话】 网页上的小广告(banner)已经成为一种宣传推广的重要形式,但这些小广告除了版面细小外,图象的表现还受到象素较低等其它因素影
- 对于想深入理解 Python 的朋友,很有必要认真看看。喜欢本文点赞支持,欢迎收藏学习。1. eval函数函数的作用:计算指定表达式的值。也