Pytorch框架实现mnist手写库识别(与tensorflow对比)
作者:社会青年技术官 发布时间:2022-07-30 00:41:42
标签:Pytorch,mnist,手写,识别
前言最近在学习过程中需要用到pytorch框架,简单学习了一下,写了一个简单的案例,记录一下pytorch中搭建一个识别网络基础的东西。对应一位博主写的tensorflow的识别mnist数据集,将其改为pytorch框架,也可以详细看到两个框架大体的区别。
Tensorflow版本转载来源(CSDN博主「兔八哥1024」):https://www.jb51.net/article/191157.htm
Pytorch实战mnist手写数字识别
#需要导入的包
import torch
import torch.nn as nn#用于构建网络层
import torch.optim as optim#导入优化器
from torch.utils.data import DataLoader#加载数据集的迭代器
from torchvision import datasets, transforms#用于加载mnsit数据集
#下载数据集
train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1037,), (0.3081,))
]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1037,), (0.3081,))
]))
#构建网络(网络结构对应tensorflow的那一篇文章)
class Net(nn.Module):
def __init__(self, num_classes=10):
super(Net, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
nn.MaxPool2d(kernel_size=2,stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(3136, 7*7*64),
nn.Linear(3136, num_classes),
)
def forward(self,x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
net=Net()
net.cuda()#用GPU运行
#计算误差,使用adam优化器优化误差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
#训练过程
for epoch in range(1):
net.train() ##在进行训练时加上train(),测试时加上eval()
batch = 0
for batch_images, batch_labels in train_data:
average_loss = 0
train_acc = 0
##在pytorch0.4之后将Variable 与tensor进行合并,所以这里不需要进行Variable封装
if torch.cuda.is_available():
batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()
#前向传播
out = net(batch_images)
loss = criterion(out,batch_labels)
average_loss = loss
prediction = torch.max(out,1)[1]
# print(prediction)
train_correct = (prediction == batch_labels).sum()
##这里得到的train_correct是一个longtensor型,需要转换为float
train_acc = (train_correct.float()) / 128
optimizer.zero_grad() #清空梯度信息,否则在每次进行反向传播时都会累加
loss.backward() #loss反向传播
optimizer.step() ##梯度更新
batch+=1
print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
%(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))
# 在测试集上检验效果
net.eval() # 将模型改为预测模式
for idx,(im1, label1) in enumerate(test_data):
if torch.cuda.is_available():
im, label = im1.cuda(),label1.cuda()
out = net(im)
loss = criterion(out, label)
eval_loss = loss
pred = torch.max(out,1)[1]
num_correct = (pred == label).sum()
acc = (num_correct.float())/ 128
eval_acc = acc
print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
.format(idx,eval_loss , eval_acc))
运行结果:
来源:https://juejin.im/post/5f143a446fb9a07ec172ec88
0
投稿
猜你喜欢
- 最近,我喜欢上了XML编程,但又苦于它的美观程度又不够,找了许多书才搞定。 &n
- 项目说明 该电商项目类似于京东商城,主要模块有验证、用户、第三方登录、首页广告、商品、购物车、订单、支付以及后台管理系统。项目开发模式采用前
- 在研究ezSQL的时候就看到了mssql_connect()等一些php提供的连接MSSQL的函数,本以为php这个开源的风靡世界的编程语言
- 你还没用 jQuery 写过导航菜单? 相信看到这些出色的jQuery导航菜单后,一定会为此而后悔没早点把 jQuery 应用到自己的Web
- python记录程序运行时间的三种方法 &nb
- 个人网站如有会员注册模块+动网论坛的话,那网站要与动网论坛系统整合,实现不同Web系统之间的用户信息同步更新、登录等操作就不是件容易的事了,
- 包括安装时提示有挂起的操作、收缩数据库、压缩数据库、转移数据库给新用户以已存在用户权限、检查备份集、修复数据库等。 (一)挂起操作在安装S
- 先举个例子,分别以不指定编码、指定编码为 utf-8、指定编码为 utf-8-sig 三种方式来做比较,再将写入 csv 文件和 txt 文
- 我们先看一下JavaScript中关系运算符的类型转换规则:关系运算符(<、>、<=、>=) 试图将 express
- Laravel 中间件提供了一种方便的机制来过滤进入应用的 HTTP 请求。例如,Laravel 内置了一个中间件来验证用户的身份认证。如果
- 今天我要为大家介绍的是XPath,XPath是导航和查询XML文档的语言。我们从一个函数开始。UpdateXML()函数我们已经花了很多时间
- 准备开始学习Python,但是刚准备环境搭建时就遇到了下面的错误:仔细的看了看,说是缺少DLL。对于这个问题的解决办法:方法一:1. 在安装
- /* 小弟刚刚接触ORACLE存储过程,有一个问题向各位同行求教,小弟写了一个存储过程,其目的是接收一个参数作为表名,然后查询该表中的全部记
- 随着手机用户的不断增加,WAP站点如雨后春笋迅速的滋长开来,手机邮箱也不断的出现在人的眼前,笔者也曾经开发了一套手机邮箱的系统,但由于时间仓
- 数据库镜像是将数据库事务处理从一个数据库移动到不同环境中的另一个数据库中。镜像的拷贝是一个备用的拷贝,不能直接访问,它只用在错误恢复的情况下
- 首先,必须有错误继续进行的声明On Error Resume Next 然后尝试简历jmail实例: Dim JMail Set JMail
- PDO::inTransactionPDO::inTransaction — 检查是否在一个事务内(PHP 5 >= 5.3.3, B
- tkinter禁用(只读)下拉列表Comboboxtkinter将下拉列表框Combobox控件的状态设置为只读,也就是不可编辑状态:# 定
- 之前在《首都机场的点烟器》中分析了一个软件系统所处的状态并且列举了不同的状态所需要的展示给用户的各类信息,我们先简单回顾一下:要设计一个软件
- 一.基于纹理背景的图像分割该部分主要讲解基于图像纹理信息(颜色)、边界信息(反差)和背景信息的图像分割算法。在OpenCV中,GrabCut