pytorch cnn 识别手写的字实现自建图片数据
作者:瓦力冫 发布时间:2023-04-18 02:39:22
标签:pytorch,cnn,识别手写
本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下:
# library
# standard library
import os
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1) # reproducible
# Hyper Parameters
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001 # learning rate
root = "./mnist/raw/"
def default_loader(path):
# return Image.open(path).convert('RGB')
return Image.open(path)
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
fh.close()
def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
img = Image.fromarray(np.array(img), mode='L')
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(self.imgs)
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( # input shape (1, 28, 28)
nn.Conv2d(
in_channels=1, # input height
out_channels=16, # n_filters
kernel_size=5, # filter size
stride=1, # filter movement/step
padding=2, # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
), # output shape (16, 28, 28)
nn.ReLU(), # activation
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
)
self.conv2 = nn.Sequential( # input shape (16, 14, 14)
nn.Conv2d(16, 32, 5, 1, 2), # output shape (32, 14, 14)
nn.ReLU(), # activation
nn.MaxPool2d(2), # output shape (32, 7, 7)
)
self.out = nn.Linear(32 * 7 * 7, 10) # fully connected layer, output 10 classes
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
output = self.out(x)
return output, x # return x for visualization
cnn = CNN()
print(cnn) # net architecture
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
# training and testing
for epoch in range(EPOCH):
for step, (x, y) in enumerate(train_loader): # gives batch data, normalize x when iterate train_loader
b_x = Variable(x) # batch x
b_y = Variable(y) # batch y
output = cnn(b_x)[0] # cnn output
loss = loss_func(output, b_y) # cross entropy loss
optimizer.zero_grad() # clear gradients for this training step
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients
if step % 50 == 0:
cnn.eval()
eval_loss = 0.
eval_acc = 0.
for i, (tx, ty) in enumerate(test_loader):
t_x = Variable(tx)
t_y = Variable(ty)
output = cnn(t_x)[0]
loss = loss_func(output, t_y)
eval_loss += loss.data[0]
pred = torch.max(output, 1)[1]
num_correct = (pred == t_y).sum()
eval_acc += float(num_correct.data[0])
acc_rate = eval_acc / float(len(test_data))
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))
图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》
结果如下:
来源:http://www.waitingfy.com/archives/3549
0
投稿
猜你喜欢
- python开启debug模式的代码如下所示:import requests session = requests.session()imp
- 学习了vue.js一段时间,拿它来做2个小组件,练习一下。我这边是用webpack进行打包,也算熟悉一下它的运用。源码放在文末的 githu
- 通常,当一个页面有太多信息要显示,而一页塞又不下所有信。为了请求速度、美观以及其他的各种理由,分页就会被我们请过来。让我们的用户可以选择是否
- 收集所有外部链接的网站爬虫程序流程图下例是爬取本站python绘制条形图方法代码详解的实例,大家可以参考下。完整代码:#! /usr/bin
- Vue.js是当下很火的一个JavaScript MVVM库,它是以数据驱动和组件化的思想构建的。相比于Angular.js,Vue.js提
- 1、查找字符串除了使用index()方法在字符串中查找指定元素,还可以使用find()方法在一个较长的字符串中查找子串。如果找到子串,返回子
- 不得不说python的自制包的相关工具真是多且混乱,什么setuptools,什么distutils,什么wheel,什么egg!!怎么有这
- 什么是Inception ResnetV2Inception ResnetV2是Inception ResnetV1的一个加强版,两者的结构
- select * from _test a left join _test b on a.id=b.id where a.level=
- map是key-value数据结构,又称为字段或者关联数组。类似其他编程语言的集合一、基本语法var 变量名 map[keyty
- 在开始部分,请看官非常非常耐心地阅读下面几个枯燥的术语解释,本来这不符合本教程的风格,但是,请看官谅解,因为列位将来一定要阅读枯燥的东西的。
- 链表一种物理存储单元上非连续、非顺序的存储结构,数据元素的逻辑顺序是通过链表中的指针链接次序实现的。链表由一系列结点(链表中每一个元素称为结
- 本篇我们将以分析历史股价为例,介绍怎样从文件中载入数据,以及怎样使用NumPy的基本数学和统计分析函数、学习读写文件的方法,并尝试函数式编程
- 废话不多说,直接上代码<?php // 暂不支持断点续传 // $url = 'http://www.
- 例子:def re_escape(fn): def arg_escaped(this, *args):&
- Python编程语言的优点非常多,它的编程特色主要体现在可扩充性方面。那么,在接下来的这篇文章中,我们将会为大家详细介绍一下有关Python
- 本文实例讲述了JavaScript函数参数使用带参数名的方式赋值传入的方法。分享给大家供大家参考。具体分析如下:这里其实就是在给函数传递参数
- 我就废话不多说啦,还是直接看代码吧!try: print(a)except Exception as e: prin
- 面向过程的程序设计把计算机程序视为一系列的命令集合,即一组函数的顺序执行。为了简化程序设计,面向过程把函数继续切分为子函数,即把大块函数通过
- 用for循环实现1~n求和的方法def main(): sum = 0 n = int(input('n=&