详解利用Pytorch实现ResNet网络之评估训练模型
作者:实力 发布时间:2023-06-13 16:23:24
每个 batch 前清空梯度,否则会将不同 batch 的梯度累加在一块,导致模型参数错误。
然后我们将输入和目标张量都移动到所需的设备上,并将模型的梯度设置为零。我们调用model(inputs)
来计算模型的输出,并使用损失函数(在此处为交叉熵)来计算输出和目标之间的误差。然后我们通过调用loss.backward()
来计算梯度,最后调用optimizer.step()
来更新模型的参数。
在训练过程中,我们还计算了准确率和平均损失。我们将这些值返回并使用它们来跟踪训练进度。
评估模型
我们还需要一个测试函数,用于评估模型在测试数据集上的性能。
以下是该函数的代码:
def test(model, criterion, test_loader, device):
model.eval()
test_loss = 0
correct = 0
total = 0
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(test_loader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
acc = 100 * correct / total
avg_loss = test_loss / len(test_loader)
return acc, avg_loss
在测试函数中,我们定义了一个with torch.no_grad()
区块。这是因为我们希望在测试集上进行前向传递时不计算梯度,从而加快模型的执行速度并节约内存。
输入和目标也要移动到所需的设备上。我们计算模型的输出,并使用损失函数(在此处为交叉熵)来计算输出和目标之间的误差。我们通过累加损失,然后计算准确率和平均损失来评估模型的性能。
训练 ResNet50 模型
接下来,我们需要训练 ResNet50 模型。将数据加载器传递到训练循环,以及一些其他参数,例如训练周期数和学习率。
以下是完整的训练代码:
num_epochs = 10
learning_rate = 0.001
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet(num_classes=1000).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(1, num_epochs + 1):
train_acc, train_loss = train(model, optimizer, criterion, train_loader, device)
test_acc, test_loss = test(model, criterion, test_loader, device)
print(f"Epoch {epoch} Train Accuracy: {train_acc:.2f}% Train Loss: {train_loss:.5f} Test Accuracy: {test_acc:.2f}% Test Loss: {test_loss:.5f}")
# 保存模型
if epoch == num_epochs or epoch % 5 == 0:
torch.save(model.state_dict(), f"resnet-epoch-{epoch}.ckpt")
在上面的代码中,我们首先定义了num_epochs
和learning_rate
。我们使用了两个数据加载器,一个用于训练集,另一个用于测试集。然后我们移动模型到所需的设备,并定义了损失函数和优化器。
在循环中,我们一次训练模型,并在 train 和 test 数据集上计算准确率和平均损失。然后将这些值打印出来,并可选地每五次周期保存模型参数。
您可以尝试使用 ResNet50 模型对自己的图像数据进行训练,并通过增加学习率、增加训练周期等方式进一步提高模型精度。也可以调整 ResNet 的架构并进行性能比较,例如使用 ResNet101 和 ResNet152 等更深的网络。
来源:https://juejin.cn/post/7222862599851540537


猜你喜欢
- 一、整数python2中整形可以分为一般整形和长整形,但是在python3中,两者以及合二为一了,只有整形。python中的整形是具有无限精
- 思想:4个数字的排列,加上3个运算符的排列,使用后缀表达式的表现如下:情形一:1,2,3,4,+,-,* => 24*24*4情形二:
- 关于Tensor的数据类型说明1. 32位浮点型:torch.FloatTensora=torch.Tensor( [[2,3],[4,8]
- 所需库的安装很多人问Pytorch要怎么可视化,于是决定搞一篇。tensorboardX==2.0tensorflow==1.13.2由于t
- 前言首先,我们开发的项目会有多个版本.其次,我们的项目版本会随着更新越来越多,我们不可能因出了新版本就不维护旧版本了.那么,我们就需要对版本
- 本文实例讲述了Python 面向对象静态方法、类方法、属性方法知识点。分享给大家供大家参考,具体如下:(1)静态方法--》-@staticm
- 某天,在需要抓取某个网页信息的时候,需要在header中增加一些信息,于是搜索了一下,如何在golang发起的http请求中设置header
- 事务日志(Transaction logs)是数据库结构中非常重要但又经常被忽略的部分。由于它并不像数据库中的schema那样活跃,因此很少
- 使用Django的ORM操作的时候,想要获取本条,上一条,下一条。初步的想法是写3个ORM,3个ORM如下:本条:models.Obj.ob
- 理论介绍分词是自然语言处理的一个基本工作,中文分词和英文不同,字词之间没有空格。中文分词是文本挖掘的基础,对于输入的一段中文,成功的进行中文
- CPU-bound(计算密集型) 和I/O bound(I/O密集型)计算密集型任务(CPU-bound) 的特点是要进行大量的计算,占据着
- 字典求和edge_weights = defaultdict(lambda: defaultdict(float))for idx,node
- 拆包是指将一个结构中的数据拆分为多个单独变量中。以元组为例:>>> a = ('windows', 10,
- 和单选框一样,许多新手在用 Javascript 验证表单(form)中多选框(checkbox)的值时,都会遇到问题,原因是 checkb
- 1.sort()方法sort()是列表的方法,修改原列表使得它按照大小排序,没有返回值,返回NoneIn [90]: x = [4, 6,
- 本文实例讲述了Python实现连接MySql数据库及增删改查操作。分享给大家供大家参考,具体如下:在本文中介绍 Python3 使用PyMy
- 本文实例讲述了GO语言实现简单的目录复制功能。分享给大家供大家参考。具体实现方法如下:创建一个独立的 goroutine 遍历文件,主进程负
- 前言大家都知道PHP 的页面静态化有多种实现方式,比如使用输出缓冲(output buffering),该种方式是把数据缓存在 PHP 的缓
- 导读只需要添加几行代码,就可以得到更快速,更省显存的PyTorch模型。你知道吗,在1986年Geoffrey Hinton就在Nature
- 通过本文给大家介绍Python3控制路由器——使用requests重启极路由.py的相关知识,代码写了相应的注释,以后再写成可以方便调用的模