PyTorch实现手写数字的识别入门小白教程
作者:B.Bz 发布时间:2021-02-04 19:58:59
标签:PyTorch,手写,数字,识别
手写数字识别(小白入门)
今早刚刚上了节实验课,关于逻辑回归,所以手有点刺挠就想发个博客,作为刚刚入门的小白,看到代码运行成功就有点小激动,这个实验没啥含金量,所以路过的大牛不要停留,我怕你们吐槽哈哈。
实验结果:
1.数据预处理
其实呢,原理很简单,就是使用多变量逻辑回归,将训练28*28图片的灰度值转换成一维矩阵,这就变成了求784个特征向量1个标签的逻辑回归问题。代码如下:
#数据预处理
trainData = np.loadtxt(open('digits_training.csv', 'r'), delimiter=",",skiprows=1)#装载数据
MTrain, NTrain = np.shape(trainData) #行列数
print("训练集:",MTrain,NTrain)
xTrain = trainData[:,1:NTrain]
xTrain_col_avg = np.mean(xTrain, axis=0) #对各列求均值
xTrain =(xTrain- xTrain_col_avg)/255 #归一化
yTrain = trainData[:,0]
2.训练模型
对于数学差的一批的我来说,学习算法真的是太太太扎心了,好在具体算法封装在了sklearn库中。简单两行代码即可完成。具体参数的含义随随便便一搜到处都是,我就不班门弄斧了,每次看见算法除了头晕啥感觉没有。
model = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=500)
model.fit(xTrain, yTrain)
3.测试模型,保存
接下来测试一下模型,准确率能达到百分之90,也不算太高,训练数据集本来也不是很多。
为了方便,所以把模型保存下来,不至于运行一次就得训练一次。
#测试模型
testData = np.loadtxt(open('digits_testing.csv', 'r'), delimiter=",",skiprows=1)
MTest,NTest = np.shape(testData)
print("测试集:",MTest,NTest)
xTest = testData[:,1:NTest]
xTest = (xTest-xTrain_col_avg) /255 # 使用训练数据的列均值进行处理
yTest = testData[:,0]
yPredict = model.predict(xTest)
errors = np.count_nonzero(yTest - yPredict) #返回非零项个数
print("预测完毕。错误:", errors, "条")
print("测试数据正确率:", (MTest - errors) / MTest)
'''================================='''
#保存模型
# 创建文件目录
dirs = 'testModel'
if not os.path.exists(dirs):
os.makedirs(dirs)
joblib.dump(model, dirs+'/model.pkl')
print("模型已保存")
https://download.csdn.net/download/qq_45874897/12427896 需要的可以自行下载
4.调用模型
既然模型训练好了,就来放几张图片调用模型试一下看看怎么样
导入要测试的图片,然后更改大小为28*28,将图片二值化减小误差。
为了让结果看起来有逼格,所以最后把图片和识别数字同实显示出来。
import cv2
import numpy as np
from sklearn.externals import joblib
map=cv2.imread(r"C:\Users\lenovo\Desktop\[DX6@[C$%@2RS0R2KPE[W@V.png")
GrayImage = cv2.cvtColor(map, cv2.COLOR_BGR2GRAY)
ret,thresh2=cv2.threshold(GrayImage,127,255,cv2.THRESH_BINARY_INV)
Image=cv2.resize(thresh2,(28,28))
img_array = np.asarray(Image)
z=img_array.reshape(1,-1)
'''================================================'''
model = joblib.load('testModel'+'/model.pkl')
yPredict = model.predict(z)
print(yPredict)
y=str(yPredict)
cv2.putText(map,y, (10,20), cv2.FONT_HERSHEY_SIMPLEX,0.7,(0,0,255), 2, cv2.LINE_AA)
cv2.imshow("map",map)
cv2.waitKey(0)
5.完整代码
test1.py
import numpy as np
from sklearn.linear_model import LogisticRegression
import os
from sklearn.externals import joblib
#数据预处理
trainData = np.loadtxt(open('digits_training.csv', 'r'), delimiter=",",skiprows=1)#装载数据
MTrain, NTrain = np.shape(trainData) #行列数
print("训练集:",MTrain,NTrain)
xTrain = trainData[:,1:NTrain]
xTrain_col_avg = np.mean(xTrain, axis=0) #对各列求均值
xTrain =(xTrain- xTrain_col_avg)/255 #归一化
yTrain = trainData[:,0]
'''================================='''
#训练模型
model = LogisticRegression(solver='lbfgs', multi_class='multinomial', max_iter=500)
model.fit(xTrain, yTrain)
print("训练完毕")
'''================================='''
#测试模型
testData = np.loadtxt(open('digits_testing.csv', 'r'), delimiter=",",skiprows=1)
MTest,NTest = np.shape(testData)
print("测试集:",MTest,NTest)
xTest = testData[:,1:NTest]
xTest = (xTest-xTrain_col_avg) /255 # 使用训练数据的列均值进行处理
yTest = testData[:,0]
yPredict = model.predict(xTest)
errors = np.count_nonzero(yTest - yPredict) #返回非零项个数
print("预测完毕。错误:", errors, "条")
print("测试数据正确率:", (MTest - errors) / MTest)
'''================================='''
#保存模型
# 创建文件目录
dirs = 'testModel'
if not os.path.exists(dirs):
os.makedirs(dirs)
joblib.dump(model, dirs+'/model.pkl')
print("模型已保存")
运行结果
test2.py
import cv2
import numpy as np
from sklearn.externals import joblib
map=cv2.imread(r"C:\Users\lenovo\Desktop\[DX6@[C$%@2RS0R2KPE[W@V.png")
GrayImage = cv2.cvtColor(map, cv2.COLOR_BGR2GRAY)
ret,thresh2=cv2.threshold(GrayImage,127,255,cv2.THRESH_BINARY_INV)
Image=cv2.resize(thresh2,(28,28))
img_array = np.asarray(Image)
z=img_array.reshape(1,-1)
'''================================================'''
model = joblib.load('testModel'+'/model.pkl')
yPredict = model.predict(z)
print(yPredict)
y=str(yPredict)
cv2.putText(map,y, (10,20), cv2.FONT_HERSHEY_SIMPLEX,0.7,(0,0,255), 2, cv2.LINE_AA)
cv2.imshow("map",map)
cv2.waitKey(0)
提供几张样本用来测试:
实验中还有很多地方需要优化,比如数据集太少,泛化能力太差,用样本的数据测试正确率挺高,但是用我自己手写的字正确率就太低了,可能我字写的太丑,哎,还是自己太菜了,以后得多学学算法了。
来源:https://blog.csdn.net/bjsyc123456/article/details/125042613
0
投稿
猜你喜欢
- 从某个页面表单中取出信息是ASP编程中常见的问题。但是,遍历通过表单传递的记录会花去多长时间呢?这取决于数据库的大小。简单的GUI界面都可能
- 3. 品味“决定”艺术作品的好坏,设计的好坏则来自主观意见我们在鉴赏艺术作品时,用看法来表达当时的感觉,而你的品味则会左右你的看法。以一个有
- SVG是XML来描述二维图形的语言。SVG可以构造3种类型的图形对象:矢量图形、位图图象和文字。图形对象可被组化、样式化、变形和重组,包括图
- 如下所示:import re# 过滤不了\\ \ 中文()还有————r1 = u'[a-zA-Z0-9'!"#$
- 一、SQLAlchemy简介1.1、SQLAlchemy是什么?sqlalchemy是一个python语言实现的的针对关系型数据库的orm库
- 下面是BeforeInitialBind事件过程:<SCRIPT language=vbscript event=
- 好了,下面我们看看如何在服务器上生成.m3u文件并下传到客户端的:<%dim choose,path,mydb,myset,
- 模糊数据库指能够处理模糊数据的数据库。一般的数据库都是以二直逻辑和精确的数据工具为基础的,不能表示许多模糊不清的事情。随着模糊数学理论体系的
- 根据一般做法的话,导出部分字段时没有办法生成格式化XML文件,所以导入时就没有办法格式化导入数据。 我想到两点,1.手工修改格式化XML文件
- 数据共享是数据库最基本的特征之一。但是数据共享虽然为员工带来了便利,但也产生了一些负面作用。例如因用户并发存取而导致的对数据一致性的破坏、由
- 如下所示:# coding:utf-8import shapefilew = shapefile.Writer()w.autoBalance
- string模块可以追溯到早期版本的Python。以前在本模块中实现的许多功能已经转移到str物品。这个string模块保留了几个有用的常量
- 啥也不说了,眼泪哗哗的 –来自怨念深重的不灵狗。【运行环境】1、在ubuntu下使用pip安装flask-mongoengine;2、pip
- 下载编译器protoc两种方式:1、使用google官方protoc下载地址:https://github.com/google/proto
- 当我们导入的模型含有自定义层或者自定义函数时,需要使用custom_objects来指定目标层或目标函数。例如:我的一个模型含有自定义层“S
- 目前防采集的方法有很多种,先介绍一下常见防采集策略方法和它的弊端及采集对策: 一、判断一个IP在一定时间内对本站页面的访问次数,如果明显超过
- 前言 本篇章主要介绍二叉树的应用之一------二叉排序树,包括二叉排序树的定义、查找、插入、构造、删除及查找效率分析。1. 二叉排序树的
- 一、安装selenium打开命令控制符输入:pip install -U selenium火狐浏览器安装firebug:www.firebu
- Python数据类型分为值类型和引用类型, 下面我们看下它们的区别:值类型:对象本身不允许修改,数值的修改实际上是让变量指向了一个新的对象包
- 前言:record类型,这是一种新引用类型,而不是类或结构。record与类不同,区别在于record类型使用基于值的相等性。例如:publ