Python实战小项目之Mnist手写数字识别
作者:GSAU-深蓝工作室 发布时间:2023-01-20 23:24:56
标签:Python,Mnist,手写数字识别,实战
程序流程分析图:
传播过程:
代码展示:
创建环境
使用<pip install+包名>来下载torch,torchvision包
准备数据集
设置一次训练所选取的样本数Batch_Sized的值为512,训练此时Epochs的值为8
BATCH_SIZE = 512
EPOCHS = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
下载数据集
Normalize()数字归一化,转换使用的值0.1307和0.3081是MNIST数据集的全局平均值和标准偏差,这里我们将它们作为给定值。model
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True,
transform=transforms.Compose([.
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=BATCH_SIZE, shuffle=True)
下载测试集
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=BATCH_SIZE, shuffle=True)
绘制图像
我们可以使用matplotlib来绘制其中的一些图像
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets)
print(example_data.shape)
print(example_data)
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
plt.subplot(2,3,i+1)
plt.tight_layout()
plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
plt.title("Ground Truth: {}".format(example_targets[i]))
plt.xticks([])
plt.yticks([])
plt.show()
搭建神经网络
这里我们构建全连接神经网络,我们使用三个全连接(或线性)层进行前向传播。
class linearNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(784, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)
def forward(self, x):
x = x.view(-1, 784)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
x = self.fc3(x)
x = F.log_softmax(x, dim=1)
return x
训练模型
首先,我们需要使用optimizer.zero_grad()手动将梯度设置为零,因为PyTorch在默认情况下会累积梯度。然后,我们生成网络的输出(前向传递),并计算输出与真值标签之间的负对数概率损失。现在,我们收集一组新的梯度,并使用optimizer.step()将其传播回每个网络参数。
def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if (batch_idx) % 30 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
测试模型
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # 将一批的损失相加
pred = output.max(1, keepdim=True)[1] # 找到概率最大的下标
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
将训练次数进行循环
if __name__ == '__main__':
model = linearNet()
optimizer = optim.Adam(model.parameters())
for epoch in range(1, EPOCHS + 1):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
保存训练模型
torch.save(model, 'MNIST.pth')
运行结果展示:
分享人:苏云云
来源:https://blog.csdn.net/weixin_40604528/article/details/120848106
0
投稿
猜你喜欢
- 一、说明 numpy.ufunc是什么函数?答
- 代码如下:--创建测试表 DECLARE @Users TABLE ( ID INT IDENTITY(1,1), UserIn
- 看看怎样抓到你:<%Dim objCMFUDim strModifiedSet objCMFU 
- A.动态页面第一步:创建转向控制页面,创建网站默认的首页文件(通常为"index.asp"或"default.
- 介绍我们可以通过控制HeaderStyle, RowStyle, AlternatingRowStyle和其他一些属性来改变GridView
- 一、vim python自动补全插件:pydiction 可以实现下面python代码的自动补全:1.简单python关键词补全 2.pyt
- 本文实例为大家分享了python绘制汉诺塔的具体代码,供大家参考,具体内容如下源码:import turtleclass Stack: &n
- Python中pass的作用空语句 do nothing保证格式完整保证语义完整以if语句为例,在c或c++/java中:if(true);
- PHP attributes() 函数实例返回 XML 的 body 元素的属性和值:<?php $note=<<<
- 前言大家好,我们今天来爬取c站的热搜榜,把其文章名称,链接和作者获取下来,我们保存到本地,我们通过测试,发现其实很简单,我们只要简单获取数据
- 导语哈喽!我是木木子,今天又想我了嘛?之前不是出过一期Python美颜相机嘛?不知道你们还记得不?这一期的话话题还是围绕上期关于颜值方面来走
- 在PyTorch中,torch.Tensor类是存储和变换数据的重要工具,相比于Numpy,Tensor提供GPU计算和自动求梯度等更多功能
- python excel文件(.xls文件)如何处理xlrd 用于读取文件,xlwt 用于写入文件,xlutils 是两个工具包的桥梁,也就
- 运行环境Python 2.7操作实例1.原始文本格式:空格分隔的txt,例如2016-03-22 00:06:24.4463094 中文测试
- 上周在去杭州betacafe的路上,有幸和绿人网梁宁和饭统网李耀东、千鸟一道,在出租车上聊起了地理和历史,其中有一个共同的观点是说,人们对事
- 缓存是基于Application实现的CacheState类,建议实例化时用名Cache程序代码<% Class Cache
- 这周心血来潮,翻看了现在比较流行的几个JS脚本框架的底层代码,虽然是走马观花,但也受益良多,感叹先人们的伟大……感叹是为了缓解严肃的气氛并引
- Google Chrome 的发布,使我们更加的注重基于 WebKit 核心的浏览器的表现情况,但我们很多时候“不小心”就会出现问题。考虑下
- jinjia和vue.js默认的模板转义符都是{{}}目前的解决办法是修改vue.js的转义符,将原来的{{}}替换为其他标签,我改为{[]
- 如果需要在数据库中存储图片或视频类的数据,我们可以配置MEDIA.下面的示例将以上传一张图片的形式来说明MEDIA的配置及用法.第一步 se