Python实战小项目之Mnist手写数字识别
作者:GSAU-深蓝工作室 发布时间:2023-01-20 23:24:56
标签:Python,Mnist,手写数字识别,实战
程序流程分析图:
传播过程:
代码展示:
创建环境
使用<pip install+包名>来下载torch,torchvision包
准备数据集
设置一次训练所选取的样本数Batch_Sized的值为512,训练此时Epochs的值为8
BATCH_SIZE = 512
EPOCHS = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
下载数据集
Normalize()数字归一化,转换使用的值0.1307和0.3081是MNIST数据集的全局平均值和标准偏差,这里我们将它们作为给定值。model
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True,
transform=transforms.Compose([.
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=BATCH_SIZE, shuffle=True)
下载测试集
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=BATCH_SIZE, shuffle=True)
绘制图像
我们可以使用matplotlib来绘制其中的一些图像
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets)
print(example_data.shape)
print(example_data)
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title("Ground Truth: {}".format(example_targets[i]))
plt.xticks([])
plt.yticks([])
plt.show()
搭建神经网络
这里我们构建全连接神经网络,我们使用三个全连接(或线性)层进行前向传播。
class linearNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
def forward(self, x):
x = x.view(-1, 784)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
x = F.log_softmax(x, dim=1)
return x
训练模型
首先,我们需要使用optimizer.zero_grad()手动将梯度设置为零,因为PyTorch在默认情况下会累积梯度。然后,我们生成网络的输出(前向传递),并计算输出与真值标签之间的负对数概率损失。现在,我们收集一组新的梯度,并使用optimizer.step()将其传播回每个网络参数。
def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if (batch_idx) % 30 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
测试模型
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加
pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
将训练次数进行循环
if __name__ == '__main__':
model = linearNet()
optimizer = optim.Adam(model.parameters())
for epoch in range(1, EPOCHS + 1):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
保存训练模型
torch.save(model, 'MNIST.pth')
运行结果展示:
分享人:苏云云
来源:https://blog.csdn.net/weixin_40604528/article/details/120848106


猜你喜欢
- 1 概述1.1 贪心算法贪心算法总是作出在当前看来最好的选择。也就是说贪心算法并不从整体最优考虑,它所作出的选择只是在某种意义上的局部最优选
- 理论介绍分词是自然语言处理的一个基本工作,中文分词和英文不同,字词之间没有空格。中文分词是文本挖掘的基础,对于输入的一段中文,成功的进行中文
- 一、简介Flask是一个轻量级的基于Python的web框架。本文适合有一定HTML、Python、网络基础的同学阅读。这份文档中的代码使用
- 本文主要给大家介绍的是关于利用python模拟实现POST请求提交图片的方法,分享出来供大家参考学习,下面来一看看详细的介绍:使用reque
- 运行多进程 每个子进程的内存空间是互相隔离的 进程之间数据不能共享的互斥锁但是进程之间都是运行在一个操作系统上,进程之间数据不共享,但是共享
- 一个对AJAX的封装//url就是请求的地址//successFunc就是一个请求返回成功之后的一个function,有一个参数,参数就是服
- Etag是URL的Entity Tag,用于标示URL对象是否改变,区分不同语言和Session等等。具体内部含义是使服务器控制的,就像Co
- 这篇文章主要介绍了python scrapy重复执行实现代码详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值
- 由Oralce8.1开始,Oracle增加了一个新的特性就是Stored Outlines,或者称为Plan Stability(计划稳定性
- 概 述 现在有不少介绍利用ASP实现动态分页的文章,方法大同小异,就是每次利用ADO返回原始数据满足条件记录集中的指定
- 本文主要给大家介绍的是关于python爬取散文网文章的相关内容,分享出来供大家参考学习,下面一起来看看详细的介绍:效果图如下:配置pytho
- <?php /** +------------------------------------------------ * 通用的树型
- 《python基础教程》书中的第四个练习,新闻聚合。现在很少见的一类应用,至少我从来没有用过,又叫做Usenet。这个程序的主要功能是用来从
- 核心代码是 getCookie()部分,控制弹框的显示隐藏则在 created()中。<template> <div v-
- 上周用了一周的时间学习了Python和Scrapy,实现了从0到1完整的网页爬虫实现。研究的时候很痛苦,但是很享受,做技术的嘛。首先,安装P
- 相信玩过爬虫的朋友都知道selenium,一个自动化测试的神器工具。写个Python自动化脚本解放双手基本上是常规的操作了,爬虫爬不了的,就
- 本文参加新星计划人工智能(Pytorch)赛道:https://bbs.csdn.net/topics/613989052一、强大的 hub
- 参考官网地址:Windows端:https://tensorflow.google.cn/install/source_windowsCPU
- 1.加载数据库,数据库的配置不能写死在seting.py文件中,下面的方式是读取另外一个文件,配置数据库:config = '
- 一、写在开头哈喽兄弟们之前经常编写Python脚本来进行数据处理、数据传输和模型训练。随着数据量和数据复杂性的增加,运行脚本可能需要一些时间