Pytorch实现的手写数字mnist识别功能完整示例
作者:nudt_qxx 发布时间:2022-10-15 23:38:22
标签:Pytorch,手写数字,mnist,识别
本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定义网络结构
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Sequential( #input_size=(1*28*28)
nn.Conv2d(1, 6, 5, 1, 2), #padding=2保证输入输出尺寸相同
nn.ReLU(), #input_size=(6*28*28)
nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
)
self.conv2 = nn.Sequential(
nn.Conv2d(6, 16, 5),
nn.ReLU(), #input_size=(16*10*10)
nn.MaxPool2d(2, 2) #output_size=(16*5*5)
)
self.fc1 = nn.Sequential(
nn.Linear(16 * 5 * 5, 120),
nn.ReLU()
)
self.fc2 = nn.Sequential(
nn.Linear(120, 84),
nn.ReLU()
)
self.fc3 = nn.Linear(84, 10)
# 定义前向传播过程,输入为x
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
# nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维
x = x.view(x.size()[0], -1)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
#使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路径
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加载路径
opt = parser.parse_args()
# 超参数设置
EPOCH = 8 #遍历数据集次数
BATCH_SIZE = 64 #批处理尺寸(batch_size)
LR = 0.001 #学习率
# 定义数据预处理方式
transform = transforms.ToTensor()
# 定义训练数据集
trainset = tv.datasets.MNIST(
root='./data/',
train=True,
download=True,
transform=transform)
# 定义训练批处理数据
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=BATCH_SIZE,
shuffle=True,
)
# 定义测试数据集
testset = tv.datasets.MNIST(
root='./data/',
train=False,
download=True,
transform=transform)
# 定义测试批处理数据
testloader = torch.utils.data.DataLoader(
testset,
batch_size=BATCH_SIZE,
shuffle=False,
)
# 定义损失函数loss function 和优化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,通常用于多分类问题上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 训练
if __name__ == "__main__":
for epoch in range(EPOCH):
sum_loss = 0.0
# 数据读取
for i, data in enumerate(trainloader):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
# 梯度清零
optimizer.zero_grad()
# forward + backward
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 每训练100个batch打印一次平均loss
sum_loss += loss.item()
if i % 100 == 99:
print('[%d, %d] loss: %.03f'
% (epoch + 1, i + 1, sum_loss / 100))
sum_loss = 0.0
# 每跑完一次epoch测试一下准确率
with torch.no_grad():
correct = 0
total = 0
for data in testloader:
images, labels = data
images, labels = images.to(device), labels.to(device)
outputs = net(images)
# 取得分最高的那个类
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum()
print('第%d个epoch的识别准确率为:%d%%' % (epoch + 1, (100 * correct / total)))
#torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))
希望本文所述对大家Python程序设计有所帮助。
来源:https://blog.csdn.net/xiangxianghehe/article/details/80719612
0
投稿
猜你喜欢
- 本文涉及:Windows操作系统,Python,PyQt5,Qt Designer,PyCharm一、自适应原理 &
- 1.准备代码# coding=utf-8class TestDebug: def __init__(self):
- 常用Mysql查询语句记录一、授权1.授权本地用户对所有数据库具有所有权限> grant all privileges on
- 如下所示:import arcpy... from arcpy import env... env.workspace="C:\\
- 在继承的使用上,我们最早接触的是父类和子类的继承。不过Flask框架中的继承要简单一些,只要有一个原文件,便可以对其进行继承和修改的操作了。
- 作者: Alan Pearce原文: Multi-Column Layouts Climb Out of the Box地址: http:/
- python是解释型语言,本文介绍了Python下利用turtle实现绘图功能的示例,本例所示为Python绘制一个树枝,具体实现代码如下:
- 做机器学习的一定对支持向量机(support vecto
- 1 引言这段时间在研究美团爬虫,用的是scrapy-redis分布式爬虫框架,奈何scrapy-redis与scrapy框架不同,默认只发送
- 如果在子类中需要父类的构造方法就需要显式地调用父类的构造方法,或者不重写父类的构造方法。子类不重写 __init__,实例化子类时,会自动调
- php中可以使用 mb_detect_encoding() 函数来判断字符串是什么编码的。当在php中使用mb_detect_encodin
- 每个写asp程序人必会的知识!在ASP编程中使用数组数组的定义Dim MyArrayMyArray = Array(1,5,123,12,9
- 如何在本地机器上创建缓存?用法到是很简单,只需先创建Stream对象的实例,然后开始写入数据即可: Dim str&n
- 1、封装的理解封装(Encapsulation):属性和方法的抽象属性的抽象:对类的属性(变量)进行定义、隔离和保护分为私有属性和公开属性:
- 可以在Mac OS X 10.2.x(“Jaguar”)和以上版本上Mac OS X使用二进制安装软
- 导语哈喽吖铁汁萌!今天这期就给大家介绍几个我用到的办公室自动化技巧,可以瞬速提高办公效率。有需要的可以往下滑了1、Word文档doc转doc
- 论坛有人问起如何获取读取CSS属性值,就写了下面这段兼容各浏览器的获取HTML元素的css属性值函数:function getSt
- 互联网上的每台计算机都有独一无二的编号,称为IP地址,每个合法的IP地址由“.”分开的4个数字组成,并且IP地址细分类型的话,可以分为“A”
- 一、opencv是什么?OpenCV是一个用于图像处理、分析、机器视觉方面的开源函数库.二、使用步骤1.引入库代码如下:import cv2
- 前段时间做视频时需要演示电脑端的操作,因此要用到屏幕录制,下载了个迅捷屏幕录制,但是没有vip录制的视频有水印且只能录制二分钟,于是鄙人想了