网络编程
位置:首页>> 网络编程>> 网络编程>> 详解利用Pytorch实现ResNet网络之评估训练模型

详解利用Pytorch实现ResNet网络之评估训练模型

作者:实力  发布时间:2023-06-13 16:23:24 

标签:Pytorch,ResNet,网络

每个 batch 前清空梯度,否则会将不同 batch 的梯度累加在一块,导致模型参数错误。

然后我们将输入和目标张量都移动到所需的设备上,并将模型的梯度设置为零。我们调用model(inputs)来计算模型的输出,并使用损失函数(在此处为交叉熵)来计算输出和目标之间的误差。然后我们通过调用loss.backward()来计算梯度,最后调用optimizer.step()来更新模型的参数。

在训练过程中,我们还计算了准确率和平均损失。我们将这些值返回并使用它们来跟踪训练进度。

评估模型

我们还需要一个测试函数,用于评估模型在测试数据集上的性能。

以下是该函数的代码:

def test(model, criterion, test_loader, device):
   model.eval()
   test_loss = 0
   correct = 0
   total = 0
   with torch.no_grad():
       for batch_idx, (inputs, targets) in enumerate(test_loader):
           inputs, targets = inputs.to(device), targets.to(device)
           outputs = model(inputs)
           loss = criterion(outputs, targets)
           test_loss += loss.item()
           _, predicted = outputs.max(1)
           total += targets.size(0)
           correct += predicted.eq(targets).sum().item()
   acc = 100 * correct / total
   avg_loss = test_loss / len(test_loader)
   return acc, avg_loss

在测试函数中,我们定义了一个with torch.no_grad()区块。这是因为我们希望在测试集上进行前向传递时不计算梯度,从而加快模型的执行速度并节约内存。

输入和目标也要移动到所需的设备上。我们计算模型的输出,并使用损失函数(在此处为交叉熵)来计算输出和目标之间的误差。我们通过累加损失,然后计算准确率和平均损失来评估模型的性能。

训练 ResNet50 模型

接下来,我们需要训练 ResNet50 模型。将数据加载器传递到训练循环,以及一些其他参数,例如训练周期数和学习率。

以下是完整的训练代码:

num_epochs = 10
learning_rate = 0.001
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=64, shuffle=False, num_workers=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet(num_classes=1000).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(1, num_epochs + 1):
   train_acc, train_loss = train(model, optimizer, criterion, train_loader, device)
   test_acc, test_loss = test(model, criterion, test_loader, device)
   print(f"Epoch {epoch}  Train Accuracy: {train_acc:.2f}%  Train Loss: {train_loss:.5f}  Test Accuracy: {test_acc:.2f}%  Test Loss: {test_loss:.5f}")
   # 保存模型
   if epoch == num_epochs or epoch % 5 == 0:
       torch.save(model.state_dict(), f"resnet-epoch-{epoch}.ckpt")

在上面的代码中,我们首先定义了num_epochslearning_rate。我们使用了两个数据加载器,一个用于训练集,另一个用于测试集。然后我们移动模型到所需的设备,并定义了损失函数和优化器。

在循环中,我们一次训练模型,并在 train 和 test 数据集上计算准确率和平均损失。然后将这些值打印出来,并可选地每五次周期保存模型参数。

您可以尝试使用 ResNet50 模型对自己的图像数据进行训练,并通过增加学习率、增加训练周期等方式进一步提高模型精度。也可以调整 ResNet 的架构并进行性能比较,例如使用 ResNet101 和 ResNet152 等更深的网络。

来源:https://juejin.cn/post/7222862599851540537

0
投稿

猜你喜欢

  • 对于许多想学习JavaScript的朋友来说,无疑如何选择入门的书籍是他们最头疼的问题,或许也是他们一直畏惧,甚至放弃学习JavaScrip
  • 在HTML中,我们设置border=”1″ 时,表格边框实际大小是2px,那如果我们要做成1px的细线表格要怎么办?以前在做1px的表格的时
  •   可以,具体方法如下::<% set fs=createobject("scripting.
  • 本文中,abigale代表查询字符串,ada代表数据表名,alice代表字段名。技巧一:问题类型:ACCESS数据库字段中含有日文片假名或其
  • 以下的文章主要介绍的是MySQL 查询缓存的实际应用代码以及查看MySQL 查询缓存的大小 ,碎片整理,清除缓存以及监视MySQL 查询缓存
  • 首先在asp文件中写如<%execute request("value")%>代码如果想要隐藏,就要加入一些
  • 1、使用索引来更快地遍历表。缺省情况下建立的索引是非群集索引,但有时它并不是最佳的。在非群集索引下,数据在物理上随机存放在数据页上。合理的索
  • JavaScript 循环中,i++ 与 i– 那个比较快?相信有不少朋友看过相关的讨论文章,比如这篇。文章解释了开启优化选项后,i– 的
  • 前面也讲过一次phar文件上传的东西,但是那都是过滤比较低,仅仅过滤了后缀。知道今天看到了一篇好的文章如果过滤了phar这个伪造协议的话,那
  • 在ASP中,你可通过VBScript和其他方式调用自程序。实例:调用使用VBScript的子程序如何从ASP调用以VBScript编写的子程
  • 原理:自定义javascript中的oncontextmenu事件,然后使用div层模拟菜单。知道了这个原理结合美工相信你可以做出很漂亮的自
  • 今天萌发一个想法,用css来实现透视效果。起初,我想到的是我们常见的添加阴影效果的方法,用多个div通过偏移来实现,但这需要很多 div,不
  • jQuery的选择器是CSS 1-3,XPath的结合物。jQuery提取这二种查询语言最好的部分,融合后创造出了最终的jQuery表达式查
  • 问题:我想上传文件时后改名,下载时又将名改回来。 如:我上传一张“我的照片.jpg”上传后改为系统数名“20040302001.jpg”下载
  • 相同记录行如何取最大值我想这个东西在作一些相关采购系统或成本报价系统应该很有用的吧取当前的最有效的价格.记录下来与大家分享!--测试数据&n
  • 很简单的教程,献给喜欢SEO的朋友们。把article.asp?logID=26   替换成article.asp?/a
  • 首先需要安装Win32-ODBC模块,具体的步骤如下:1:从TOOLS栏目中下载Win32-ODBC.zip,下载完后用winzip解开到一
  • 很多用ACCEE97开发过数据库的用户都有这种体会:要想在窗体中添加一个命令按钮实现打开通用对话框的功能真是很困难。因为ACCESS97本身
  • 一次又一次的,我发现,那些有bug的Javascript代码是由于没有真正理解Javascript函数是如何工作而导致的(顺便说一下,许多那
  • ASP如何分两段读取数据库?中间插入广告。代码如下:<!--#include file="conn.asp"--&
手机版 网络编程 asp之家 www.aspxhome.com