网络编程
位置:首页>> 网络编程>> Python编程>> 详解PyTorch手写数字识别(MNIST数据集)

详解PyTorch手写数字识别(MNIST数据集)

作者:Steven·简谈  发布时间:2023-01-28 19:40:47 

标签:PyTorch,MNIST,手写数字识别

MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。

导入相关库


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下载并导入数据集

cv2 用于展示数据的图像

获取训练集和测试集


# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
               train=True,
               transform=transforms.ToTensor(),
               download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
              train=False,
              transform=transforms.ToTensor(),
              download=True)

root 用于指定数据集在下载之后的存放路径

transform 用于指定导入数据集需要对数据进行那种变化操作

train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分

download 为 True 表示数据集需要程序自动帮你下载

这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。

数据装载和预览


# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包

# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                     batch_size=batch_size,
                     shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                    batch_size=batch_size,
                    shuffle=True)

在装载完成后,可以选取其中一个批次的数据进行预览:


images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。

预览图片如下:

详解PyTorch手写数字识别(MNIST数据集)

并且打印出了图片相对应的数字:

详解PyTorch手写数字识别(MNIST数据集)

搭建神经网络


# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear

class LeNet(nn.Module):
 def __init__(self):
   super(LeNet, self).__init__()
   self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                 nn.MaxPool2d(2, 2))

self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                 nn.MaxPool2d(2, 2))

self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                nn.BatchNorm1d(120), nn.ReLU())

self.fc2 = nn.Sequential(
     nn.Linear(120, 84),
     nn.BatchNorm1d(84),
     nn.ReLU(),
     nn.Linear(84, 10))
   # 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9

def forward(self, x):
   x = self.conv1(x)
   x = self.conv2(x)
   x = x.view(x.size()[0], -1)
   x = self.fc1(x)
   x = self.fc2(x)
   x = self.fc3(x)
   return x

前向传播内容:

首先经过 self.conv1() 和 self.conv1() 进行卷积处理

然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)

最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类

训练模型


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
 net.parameters(),
 lr=LR,
)

epoch = 1
if __name__ == '__main__':
 for epoch in range(epoch):
   sum_loss = 0.0
   for i, data in enumerate(train_loader):
     inputs, labels = data
     inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
     optimizer.zero_grad() #将梯度归零
     outputs = net(inputs) #将数据传入网络进行前向运算
     loss = criterion(outputs, labels) #得到损失函数
     loss.backward() #反向传播
     optimizer.step() #通过梯度做一步参数更新

# print(loss)
     sum_loss += loss.item()
     if i % 100 == 99:
       print('[%d,%d] loss:%.03f' %
          (epoch + 1, i + 1, sum_loss / 100))
       sum_loss = 0.0

测试模型


net.eval() #将模型变换为测试模式
 correct = 0
 total = 0
 for data_test in test_loader:
   images, labels = data_test
   images, labels = Variable(images).cuda(), Variable(labels).cuda()
   output_test = net(images)
   _, predicted = torch.max(output_test, 1)
   total += labels.size(0)
   correct += (predicted == labels).sum()
 print("correct1: ", correct)
 print("Test acc: {0}".format(correct.item() /
                len(test_dataset)))

训练及测试的情况:

详解PyTorch手写数字识别(MNIST数据集)

98% 以上的成功率,效果还不错。

来源:https://blog.csdn.net/weixin_44613063/article/details/90815082

0
投稿

猜你喜欢

  • vue中代码的复用, 为我们提供了 mixnis. 模板的复用, 为我们提供了 插槽( slot )插槽的分类默认插槽具名插槽作用域插槽当我
  • 一、PIL的基本概念:PIL中所涉及的基本概念有如下几个:通道(bands)、模式(mode)、尺寸(size)、坐标系统(coordina
  • 问题你想将HTML或者XML实体如 &entity; 或 &#code; 替换为对应的文本。 再者,你需要转换文本 * 定的字
  • 一、前言在写业务代码时候,有许多场景需要重试某块业务逻辑,例如网络请求、购物下单等,希望发生异常的时候多重试几次。本文分享如何利用Pytho
  • 项目用run dev build 打包后,发现很多图片都不显示,在本地是没有问题的啊!找原因发现通过webpack+vuecli默认打包的c
  • 验证关键词是否为sql保留字的在线工具:<html>      <head><t
  • 虽然并非你编写的每个 Python 程序都要求一个严格的性能分析,但是让人放心的是,当问题发生的时候,Python 生态圈有各种各样的工具可
  • MYSQL中的分组和链接是在操作数据库和数据交互时最常用的两个在功能,把这两项处理好了,MYSQL的执行效率会非常高速。一、group by
  • 本机环境: Windows 10服务器环境: Windows Server 2012 R2背景:公司需要我开发一个简单的web应用。开发的时
  • 1. 引言在数据处理、机器学习等领域,我们经常需要对各式各样的数据进行处理,本文重点介绍三种非常简单的方法来检测数据集中的异常值。 
  • Harris 角点检测算法1. 角点角点是水平方向、垂直方向变化都很大的像素。角点检测算法的基本思想:    
  •   这主要是写给我自己的,防止以后入坑,耗费时间。本文主要谈的是怎样安装Python解释器和Python开发工具PyCharm。  本机系统
  • golang作为一热门的兼顾性能 效率的热门语言,相信很多人都知道,在编程语言排行榜上一直都是很亮眼,作为一门强类型语言,二进制位的操作肯定
  • 数据合并是数据处理过程中的必经环节,pandas作为数据分析的利器,提供了四种常用的数据合并方式,让我们看看如何使用这些方法吧!1.conc
  • Python读写word文档有现成的库可以处理。我这里采用 python-docx。可以用pip install python-docx安装
  • 我们或多或少都使用过各式各样的富文本编辑器,其中有一个很方便功能,复制一张图片然后粘贴进文本框,这张图片就被上传了,那么这个方便的功能是如何
  • 通过使用turtle绘画象棋棋盘,供大家参考,具体内容如下# 绘制象棋棋盘import turtlet = turtle.Pen()t.wi
  • Python GUI 库有很多,下面给大家罗列常用的几种 GUI 库。下面介绍的这些GUI框架,能满足大部分开发人员的需要,你可以根据自己的
  • 在自动化测试过程中,有时后会遇到元素定位方式没有问题,但是依旧抛出无法找到元素的异常的问题,通常情况下,如果元素定位没有问题,但还是无法找到
  • 定位篇UI 自动化很多时候的苦恼都是定位不到,其实说实话我到现在有时候也是莫名其妙的定位到或者定位不到。好在这个框架定位方式的上限非常以及特
手机版 网络编程 asp之家 www.aspxhome.com