PyTorch实现MNIST数据集手写数字识别详情
作者:长浔 发布时间:2021-08-03 17:30:36
标签:PyTorch,MNIST,数据集,数字,识别
前言:
本篇文章基于卷积神经网络CNN,使用PyTorch实现MNIST数据集手写数字识别。
一、PyTorch是什么?
PyTorch 是一个 Torch7 团队开源的 Python 优先的深度学习框架,提供两个高级功能:
强大的 GPU 加速 Tensor 计算(类似 numpy)
构建基于 tape 的自动升级系统上的深度神经网络
你可以重用你喜欢的 python 包,如 numpy、scipy 和 Cython ,在需要时扩展 PyTorch。
二、程序示例
下面案例可供运行参考
1.引入必要库
import torchvision
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
2.下载数据集
这里设置download=True,将会自动下载数据集,并存储在./data文件夹。
train_data = torchvision.datasets.MNIST(root="./data",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.MNIST(root="./data",train=False,transform=torchvision.transforms.ToTensor(),download=True)
3.加载数据集
batch_size=32表示每一个batch中包含32张手写数字图片,shuffle=True表示打乱测试集(data和target仍一一对应)
train_loader = DataLoader(train_data,batch_size=32,shuffle=True)
test_loader = DataLoader(test_data,batch_size=32,shuffle=False)
4.搭建CNN模型并实例化
class Net(torch.nn.Module):
def __init__(self):
super(Net,self).__init__()
self.con1 = torch.nn.Conv2d(1,10,kernel_size=5)
self.con2 = torch.nn.Conv2d(10,20,kernel_size=5)
self.pooling = torch.nn.MaxPool2d(2)
self.fc = torch.nn.Linear(320,10)
def forward(self,x):
batch_size = x.size(0)
x = F.relu(self.pooling(self.con1(x)))
x = F.relu(self.pooling(self.con2(x)))
x = x.view(batch_size,-1)
x = self.fc(x)
return x
#模型实例化
model = Net()
5.交叉熵损失函数损失函数及SGD算法优化器
lossfun = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(),lr=0.01,momentum=0.5)
6.训练函数
def train(epoch):
running_loss = 0.0
for i,(inputs,targets) in enumerate(train_loader,0):
# inputs,targets = inputs.to(device),targets.to(device)
opt.zero_grad()
outputs = model(inputs)
loss = lossfun(outputs,targets)
loss.backward()
opt.step()
running_loss += loss.item()
if i % 300 == 299:
print('[%d,%d] loss:%.3f' % (epoch+1,i+1,running_loss/300))
running_loss = 0.0
7.测试函数
def test():
total = 0
correct = 0
with torch.no_grad():
for (inputs,targets) in test_loader:
# inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
_,predicted = torch.max(outputs.data,dim=1)
total += targets.size(0)
correct += (predicted == targets).sum().item()
print(100*correct/total)
8.运行
if __name__ == '__main__':
for epoch in range(20):
train(epoch)
test()
三、总结
来源:https://blog.csdn.net/qq_41664447/article/details/126698428


猜你喜欢
- 其实 Oracle数据库的分页还是比较容易理解的。此文以oracle数据库中的SCOTT用户的EMP表为例,用PL/SQL Develope
- Python中的数据可视化matplotlib 是python最著名的绘图库,它提供了一整套和matlab相似的命令API,十分适合交互式地
- 搞前端应该对语义化并不陌生,每天都在说语义化,可什么是语义化,语义化究竟能给我们带来什么好处?参加web标准交流会的时候我向各位同学提出了我
- 对于一个Dict:test_dict = {1:5, 2:4, 3:3, 4:2, 5:1}想要求key值大于等于3的所有项:print({
- 在平时的需求开发中涉及到将多列值合并为一列值的操作,通过查阅相关资料特此记录以下方法,方便日后学习复盘 import pandas
- 一个asp读取数据库中数据到数组的类,仅供参考!DbPath = "test.mdb"’数据库位置&
- 在SQL查询中,关键词Like可提供模糊查询功能,它通常与通配符一起使用。1 Like条件适用数据库字段类型 &nbs
- 导语昨天下班,回家吃完饭就直接躺了,无聊的时候大家都会干什么呢?当然是刷刷刷——抖音啦,嗯哼,然后返现了抖音上一款特效——「变身漫画」,简直
- 本文实例讲述了Python Django框架单元测试之文件上传测试。分享给大家供大家参考,具体如下:Submitting files is
- 目录Python里的dict和set的效率有多高?字典中的散列表1.散列值和相等性散列表算法dict的实现及其导致的结果1.键必须死可散列的
- python selenium 获取接口数据。selenium没有直接提供查询的函数,但是可以通过webdriver提供的API查询,使用的
- 最近在学习python著名的绘图包matplotlib时发现,有时候图例等设置无法正常显示中文,于是就想把这个问题解决了。PS:本文仅针对W
- 本文实例讲述了Python中XlsxWriter模块用法。分享给大家供大家参考,具体如下:XlsxWriter,可以生成excel文件(xl
- Python类的动态修改的实例方法相信很多朋友在编程的时候都会想修改一下已经写好的程序行为代码,而最常见的方式就是通过子类来重写父类的一些不
- 背景sy项目通过MQ接受业务系统的业务数据,通过运行开发者开发的python脚本执行业务系统与财务系统数据的一致性校验。sy系统需要每天运行
- python编程中常用的12种基础知识总结:正则表达式替换,遍历目录方法,列表按列排序、去重,字典排序,字典、列表、字符串互转,时间对象操作
- 本文实例讲述了Flask框架使用DBUtils模块连接数据库的操作方法。分享给大家供大家参考,具体如下:Flask连接数据库数据库连接池:D
- 因为刚学vue然后自己自习了一下axios,然后想写一个简单的查询后台数据<tr v-for=" user in uList
- 假设有如下目录结构:-- dir0| file1.py| file2.py| dir3| file3.py| dir4| file4.pyd
- 一、深复制与浅复制列表是Python中自带的一种数据结构,在使用列表时,拷贝操作不可避免,下面简单讨论一下列表的深复制(拷贝)与浅复制首先看