pytorch cnn 识别手写的字实现自建图片数据
作者:瓦力冫 发布时间:2023-04-18 02:39:22
标签:pytorch,cnn,识别手写
本文主要介绍了pytorch cnn 识别手写的字实现自建图片数据,分享给大家,具体如下:
# library
# standard library
import os
# third-party library
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
# torch.manual_seed(1) # reproducible
# Hyper Parameters
EPOCH = 1 # train the training data n times, to save time, we just train 1 epoch
BATCH_SIZE = 50
LR = 0.001 # learning rate
root = "./mnist/raw/"
def default_loader(path):
# return Image.open(path).convert('RGB')
return Image.open(path)
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
fh.close()
def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
img = Image.fromarray(np.array(img), mode='L')
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(self.imgs)
train_data = MyDataset(txt= root + 'train.txt', transform = torchvision.transforms.ToTensor())
train_loader = DataLoader(dataset = train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = MyDataset(txt= root + 'test.txt', transform = torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset = test_data, batch_size=BATCH_SIZE)
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential( # input shape (1, 28, 28)
nn.Conv2d(
in_channels=1, # input height
out_channels=16, # n_filters
kernel_size=5, # filter size
stride=1, # filter movement/step
padding=2, # if want same width and length of this image after con2d, padding=(kernel_size-1)/2 if stride=1
), # output shape (16, 28, 28)
nn.ReLU(), # activation
nn.MaxPool2d(kernel_size=2), # choose max value in 2x2 area, output shape (16, 14, 14)
)
self.conv2 = nn.Sequential( # input shape (16, 14, 14)
nn.Conv2d(16, 32, 5, 1, 2), # output shape (32, 14, 14)
nn.ReLU(), # activation
nn.MaxPool2d(2), # output shape (32, 7, 7)
)
self.out = nn.Linear(32 * 7 * 7, 10) # fully connected layer, output 10 classes
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
output = self.out(x)
return output, x # return x for visualization
cnn = CNN()
print(cnn) # net architecture
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss() # the target label is not one-hotted
# training and testing
for epoch in range(EPOCH):
for step, (x, y) in enumerate(train_loader): # gives batch data, normalize x when iterate train_loader
b_x = Variable(x) # batch x
b_y = Variable(y) # batch y
output = cnn(b_x)[0] # cnn output
loss = loss_func(output, b_y) # cross entropy loss
optimizer.zero_grad() # clear gradients for this training step
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients
if step % 50 == 0:
cnn.eval()
eval_loss = 0.
eval_acc = 0.
for i, (tx, ty) in enumerate(test_loader):
t_x = Variable(tx)
t_y = Variable(ty)
output = cnn(t_x)[0]
loss = loss_func(output, t_y)
eval_loss += loss.data[0]
pred = torch.max(output, 1)[1]
num_correct = (pred == t_y).sum()
eval_acc += float(num_correct.data[0])
acc_rate = eval_acc / float(len(test_data))
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(test_data)), acc_rate))
图片和label 见上一篇文章《pytorch 把MNIST数据集转换成图片和txt》
结果如下:
来源:http://www.waitingfy.com/archives/3549
0
投稿
猜你喜欢
- 图片显示pytorch 载入的数据集是元组tuple 形式,里面包括了数据及标签(train_data,label),其中的train_da
- asp之家注:有时候我们需要知道我们链接的远程图片是否正常,是否存在,当不存在时如果我们继续引用,就会在网页上留个大大的X,影响了页面美观。
- 原图代码 src = cv2.imread("28.png") gray_src = cv2.c
- reload() 简介作用:用于重新载入之前载入的模块语法格式:reload(module)参数:module为模块对象,必须已经被加载返回
- 一、实现代码1.sql-- phpMyAdmin SQL Dump-- version 4.5.1-- http://www.phpmyad
- 前言在设计爬虫项目的时候,首先要在脑内明确人工浏览页面获得图片时的步骤一般地,我们去网上批量打开壁纸的时候一般操作如下:1、打开壁纸网页2、
- 神奇创意相框! 是的,主要利用position的relative, absolute, z-index属性。结合Photo Frame(相框
- 该平台会集成UI自动化及api自动化,里面也会涉及到一些简单的HTML等前端,当然都是很基础的东西。在以后的博客里,我会一点点的尽量写详细,
- 本文实例为大家分享了python定时发送邮件的具体代码,供大家参考,具体内容如下全部代码如下:import timefrom datetim
- 由于可将 Microsoft? SQL Server? 2000 设置为包含一个或多个命名实例和一个默认实例(也可无),所以要用新命名规则来
- 在项目中,尤其是pc端的时候,我们在用户登录后会给前端返回一个标识,来判断用户是否登录,这个标识大多数都是用户的id  
- 上一篇:微软建议的ASP性能优化28条守则(7)技巧 22:尽可能使用 Server.Transfer 代替 Response.Redire
- Serilog是.net下的新兴的日志框架,本文这里简单的介绍一下它的用法。首先安装Nuget包:Install-Package Seril
- 循环语句是一种常用的控制结构,在 Go 语言中,除了 for 关键字以外,还有一个 range 关键
- 非常好的边框样式设置工具,使用该工具您可以很方便的为DIV设置简单的边框样式,如果放在DW中会更好。会制作DW插件的高手,请帮忙制作成DW插
- 周末在家,儿子闹着要玩游戏,让玩吧,不利于健康,不让玩吧,扛不住他折腾,于是想,不如一起搞个小游戏玩玩!之前给他编过猜数字 和 掷骰子 游戏
- 这篇文章主要介绍了python Opencv计算图像相似度过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价
- 本文实例为大家分享了Python密码强弱判断的具体代码,供大家参考,具体内容如下程序说明:通过获取用户输入,判断密码长度是否大于8,同时判断
- 也许你听说过Hibernate的大名,但可能一直不了解它,也许你一直渴望使用它进行开发,那么本文正是你所需要的!在本文中,我向大家重点介绍H
- 一、pexpect模块介绍Pexpect使Python成为控制其他应用程序的更好工具。可以理解为Linux下的expect的Python封装