详解PyTorch手写数字识别(MNIST数据集)
作者:Steven·简谈 发布时间:2023-01-28 19:40:47
MNIST 手写数字识别是一个比较简单的入门项目,相当于深度学习中的 Hello World,可以让我们快速了解构建神经网络的大致过程。虽然网上的案例比较多,但还是要自己实现一遍。代码采用 PyTorch 1.0 编写并运行。
导入相关库
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2
torchvision 用于下载并导入数据集
cv2 用于展示数据的图像
获取训练集和测试集
# 下载训练集
train_dataset = datasets.MNIST(root='./num/',
train=True,
transform=transforms.ToTensor(),
download=True)
# 下载测试集
test_dataset = datasets.MNIST(root='./num/',
train=False,
transform=transforms.ToTensor(),
download=True)
root 用于指定数据集在下载之后的存放路径
transform 用于指定导入数据集需要对数据进行那种变化操作
train是指定在数据集下载完成后需要载入的那部分数据,设置为 True 则说明载入的是该数据集的训练集部分,设置为 False 则说明载入的是该数据集的测试集部分
download 为 True 表示数据集需要程序自动帮你下载
这样设置并运行后,就会在指定路径中下载 MNIST 数据集,之后就可以使用了。
数据装载和预览
# dataset 参数用于指定我们载入的数据集名称
# batch_size参数设置了每个包中的图片数据个数
# 在装载的过程会将数据随机打乱顺序并进打包
# 装载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
# 装载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=True)
在装载完成后,可以选取其中一个批次的数据进行预览:
images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)
img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)
在以上代码中使用了 iter 和 next 来获取取一个批次的图片数据和其对应的图片标签,然后使用 torchvision.utils 中的 make_grid 类方法将一个批次的图片构造成网格模式。
预览图片如下:
并且打印出了图片相对应的数字:
搭建神经网络
# 卷积层使用 torch.nn.Conv2d
# 激活层使用 torch.nn.ReLU
# 池化层使用 torch.nn.MaxPool2d
# 全连接层使用 torch.nn.Linear
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
nn.MaxPool2d(2, 2))
self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
nn.MaxPool2d(2, 2))
self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
nn.BatchNorm1d(120), nn.ReLU())
self.fc2 = nn.Sequential(
nn.Linear(120, 84),
nn.BatchNorm1d(84),
nn.ReLU(),
nn.Linear(84, 10))
# 最后的结果一定要变为 10,因为数字的选项是 0 ~ 9
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size()[0], -1)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
前向传播内容:
首先经过 self.conv1() 和 self.conv1() 进行卷积处理
然后进行 x = x.view(x.size()[0], -1),对参数实现扁平化(便于后面全连接层输入)
最后通过 self.fc1() 和 self.fc2() 定义的全连接层进行最后的分类
训练模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001
net = LeNet().to(device)
# 损失函数使用交叉熵
criterion = nn.CrossEntropyLoss()
# 优化函数使用 Adam 自适应优化算法
optimizer = optim.Adam(
net.parameters(),
lr=LR,
)
epoch = 1
if __name__ == '__main__':
for epoch in range(epoch):
sum_loss = 0.0
for i, data in enumerate(train_loader):
inputs, labels = data
inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
optimizer.zero_grad() #将梯度归零
outputs = net(inputs) #将数据传入网络进行前向运算
loss = criterion(outputs, labels) #得到损失函数
loss.backward() #反向传播
optimizer.step() #通过梯度做一步参数更新
# print(loss)
sum_loss += loss.item()
if i % 100 == 99:
print('[%d,%d] loss:%.03f' %
(epoch + 1, i + 1, sum_loss / 100))
sum_loss = 0.0
测试模型
net.eval() #将模型变换为测试模式
correct = 0
total = 0
for data_test in test_loader:
images, labels = data_test
images, labels = Variable(images).cuda(), Variable(labels).cuda()
output_test = net(images)
_, predicted = torch.max(output_test, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print("correct1: ", correct)
print("Test acc: {0}".format(correct.item() /
len(test_dataset)))
训练及测试的情况:
98% 以上的成功率,效果还不错。
来源:https://blog.csdn.net/weixin_44613063/article/details/90815082
猜你喜欢
- 一、Pycharm安装Django框架二、新建Django项目1、manage.py是个管理角色,拥有的功能包括:(1)创建app: pyt
- 字符画,一种由字母、标点、汉字或其他字符组成的图画。简单的字符画是利用字符的形状代替图画的线条来构成简单的人物、事物等形象,它一般由人工制作
- 然后给脚本文件运行权限,方法(1)chmod +x ./*.py方法(2)chmod 755 ./*.py (777也无所谓啦)这个命令不去
- 前言许多 Web 应用依赖大量的 I/O (输入/输出) 操作,比如从网站上下载图片、视频等内容;进行网络聊天或者针对后台数据库进行多次查询
- 上周 RealWorld CTF 2018 web 题 bookhub 有个未授权访问的漏洞,比较有意思,赛后看了一下公开的 WriteUp
- 题记JS中的this指向一直是个让初学者头疼的问题。今天,我们就一起来瞅瞅this倒地是咋回事,详细说说this指向原则,从此不再为了thi
- 如下所示:import numpy as npimport pandas as pdfrom pandas import Series,Da
- 前言在浏览博客时,偶然看到了用python将汉字转为拼音的第三方包,但是在实现的过程中发现一些参数已经更新,现在将两种方法记录一下。xpin
- 1) 用正式表达式 regexp "[u0391-uFFE5]"2) 用length和char_lengthdrop t
- 下列语句部分是Mssql语句,不可以在access中使用。SQL分类:DDL—数据定义语言(CREATE,ALTER,DROP,DECLAR
- python判断图片主色调,单个颜色:#!/usr/bin/env python# -*- coding: utf-8 -*-import
- 一、实际场景及解决思路实际场景:比如某个班的数学成绩以字典格式存储为:student_dict = {'xiaoliang'
- 1. 抓取街拍图片街拍图片网址2. 分析街拍图片结构keyword: 街拍pd: atlasdvpf: pcaid: 4916page_nu
- 比如在学习list、tuple、dict、str、os、sys等模组的时候,利用Python的自带文档可以很快速的全面的学到那些处理的函数。
- 这是我为了学习tkinter用python 写的一个下载m3u8视频的小程序,程序使用了多线程下载,下载后自动合并成一个视频文件,方便播放。
- 最近在使用Tensorflow 实现DNN网络时,遇到一些问题。目前网上关于Tensorflow的资料还比较少,现把问题和解决方法写出来,仅
- 简介pip 是 Python 的包安装程序。其实,pip 就是 Python 标准库(The Python Standard Library
- 如果你正在负责一个基于SQL Server的项目,或者你刚刚接触SQL Server,你都有可能要面临一些数据库性能的问题,这篇文章会为你提
- 问题一:安装模块时出现报错 Microsoft Visual C++ 14.0 is required,也下载安装了运行库依然还是
- 使用python脚本实现查询火车票信息的效果图如下:实现的代码:# coding: utf-8"""命令行火车