pytorch 利用lstm做mnist手写数字识别分类的实例
作者:xckkcxxck 发布时间:2023-01-31 03:15:38
标签:pytorch,lstm,mnist,手写,识别
代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。
# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
import sys
sys.path.append('..')
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
#定义数据
data_tf = tfs.Compose([
tfs.ToTensor(),
tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
#定义模型
class rnn_classify(nn.Module):
def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
super(rnn_classify, self).__init__()
self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
def forward(self, x):
#x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
out = self.classifier(out) #得到分类结果
return out
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
#定义训练过程
def get_acc(output, label):
total = output.shape[0]
_, pred_label = output.max(1)
num_correct = (pred_label == label).sum().item()
return num_correct / total
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
if torch.cuda.is_available():
net = net.cuda()
prev_time = datetime.datetime.now()
for epoch in range(num_epochs):
train_loss = 0
train_acc = 0
net = net.train()
for im, label in train_data:
if torch.cuda.is_available():
im = Variable(im.cuda()) # (bs, 3, h, w)
label = Variable(label.cuda()) # (bs, h, w)
else:
im = Variable(im)
label = Variable(label)
# forward
output = net(im)
loss = criterion(output, label)
# backward
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.item()
train_acc += get_acc(output, label)
cur_time = datetime.datetime.now()
h, remainder = divmod((cur_time - prev_time).seconds, 3600)
m, s = divmod(remainder, 60)
time_str = "Time %02d:%02d:%02d" % (h, m, s)
if valid_data is not None:
valid_loss = 0
valid_acc = 0
net = net.eval()
for im, label in valid_data:
if torch.cuda.is_available():
im = Variable(im.cuda())
label = Variable(label.cuda())
else:
im = Variable(im)
label = Variable(label)
output = net(im)
loss = criterion(output, label)
valid_loss += loss.item()
valid_acc += get_acc(output, label)
epoch_str = (
"Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
% (epoch, train_loss / len(train_data),
train_acc / len(train_data), valid_loss / len(valid_data),
valid_acc / len(valid_data)))
else:
epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
(epoch, train_loss / len(train_data),
train_acc / len(train_data)))
prev_time = cur_time
print(epoch_str + time_str)
train(net, train_data, test_data, 10, optimizer, criterion)
来源:https://blog.csdn.net/xckkcxxck/article/details/82978942
0
投稿
猜你喜欢
- torch.nn.Conv2d中自定义权重torch.nn.Conv2d函数调用后会自动初始化weight和bias,本文主要涉及如何自定义
- Asp中Server.ScriptTimeOut属性需要注意的一点Server.ScriptTimeout 这个属性给定Asp脚
- 问题你的包中包含代码需要去读取的数据文件。你需要尽可能地用最便捷的方式来做这件事。解决方案假设你的包中的文件组织成如下:mypackage/
- 内容概要:print() 是一个常用函数。那么,您是否注意过,print() 会在显示当前语句后换行。如果遇到需要连续显示、不换行的情况,比
- 最近的答题赢钱很火爆,我也参与了几次,有些题目确实很难答,但是10秒钟的时间根本不够百度的,所以写了个辅助挂,这样可以出现题目时自动百度,这
- 1、requests 的常见用法requests 除了 url 之外,还有 params, data 和 files 三个参数,用于和服务器
- 核心提示:VB读取MP3文件帧的信息比特率,采样频率,播放时间Private Sub Command1_Click()On Error Go
- Python中的闭包前几天又有人留言,关于其中一个闭包和re.sub的使用不太清楚。我在脚本之家搜索了下,发现没有写过闭包相关的东西,所以决
- 大家好,今天才发现很多学习Flask的小伙伴都有这么一个问题,清理缓存好麻烦啊,今天就教大家怎么解决。大家在使用Flask静态文件的时候,每
- 很多组织机构慢慢的在不同的服务器和地点部署SQL Server数据库——为各种应用和目的&m
- 如何用ASP发送带附件的邮件?请问如何用CDONTS组件发送带附件的邮件? 见下列代码:<%&nb
- Python 内置的 itertools 模块包含了一系列用来产生不同类型迭代器的函数或类,这些函数的返回都是一个迭代器,我们可以通过 fo
- explain显示了mysql如何使用索引来处理select语句以及连接表.可以帮助选择更好的索引和写出更优化的查询语句.使用方法:在sel
- 一.Sobel算子Sobel算子是一种用于边缘检测的离散微分算子,它结合了高斯平滑和微分求导。该算子用于计算图像明暗程度近似值,根据图像边缘
- max(iterable, *[, key, default])max(arg1, arg2, *args[, key])函数功能为取传入的
- what's the math 模块Python math 模块提供了许多对浮点数的数学运算函数。需要注意的是,这些函数一般是对平台
- 很久没有写文章,最近一直在忙于找工作和找房子。哎,现在终于安定下来了,哎,又叹息一下,是因为我把去淘宝面试的机会也推掉了,本来以为要卷铺盖回
- 1. tensorflow模型文件打包成PB文件import tensorflow as tffrom tensorflow.python.
- 前言自动化测试中我们存放数据无非是使用文件或者数据库,那么文件可以是csv,xlsx,xml,甚至是txt文件,通常excel文件往往是我们
- 定义和用法fopen() 函数打开文件或者 URL。如果打开失败,本函数返回 FALSE。语法fopen(filename,mode,inc