kNN算法python实现和简单数字识别的方法
作者:shichen2014 发布时间:2023-09-05 21:44:36
本文实例讲述了kNN算法python实现和简单数字识别的方法。分享给大家供大家参考。具体如下:
kNN算法算法优缺点:
优点:精度高、对异常值不敏感、无输入数据假定
缺点:时间复杂度和空间复杂度都很高
适用数据范围:数值型和标称型
算法的思路:
KNN算法(全称K最近邻算法),算法的思想很简单,简单的说就是物以类聚,也就是说我们从一堆已知的训练集中找出k个与目标最靠近的,然后看他们中最多的分类是哪个,就以这个为依据分类。
函数解析:
库函数:
tile()
如tile(A,n)就是将A重复n次
a = np.array([0, 1, 2])
np.tile(a, 2)
array([0, 1, 2, 0, 1, 2])
np.tile(a, (2, 2))
array([[0, 1, 2, 0, 1, 2],[0, 1, 2, 0, 1, 2]])
np.tile(a, (2, 1, 2))
array([[[0, 1, 2, 0, 1, 2]],[[0, 1, 2, 0, 1, 2]]])
b = np.array([[1, 2], [3, 4]])
np.tile(b, 2)
array([[1, 2, 1, 2],[3, 4, 3, 4]])
np.tile(b, (2, 1))
array([[1, 2],[3, 4],[1, 2],[3, 4]])`
自己实现的函数
createDataSet()生成测试数组
kNNclassify(inputX, dataSet, labels, k)分类函数
inputX 输入的参数
dataSet 训练集
labels 训练集的标号
k 最近邻的数目
#coding=utf-8
from numpy import *
import operator
def createDataSet():
group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
labels = ['A','A','B','B']
return group,labels
#inputX表示输入向量(也就是我们要判断它属于哪一类的)
#dataSet表示训练样本
#label表示训练样本的标签
#k是最近邻的参数,选最近k个
def kNNclassify(inputX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]#计算有几个训练数据
#开始计算欧几里得距离
diffMat = tile(inputX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
distances = sqDistances ** 0.5
#欧几里得距离计算完毕
sortedDistance = distances.argsort()
classCount = {}
for i in xrange(k):
voteLabel = labels[sortedDistance[i]]
classCount[voteLabel] = classCount.get(voteLabel,0) + 1
res = max(classCount)
return res
def main():
group,labels = createDataSet()
t = kNNclassify([0,0],group,labels,3)
print t
if __name__=='__main__':
main()
kNN应用实例
手写识别系统的实现
数据集:
两个数据集:training和test。分类的标号在文件名中。像素32*32的。数据大概这个样子:
方法:
kNN的使用,不过这个距离算起来比较复杂(1024个特征),主要是要处理如何读取数据这个问题的,比较方面直接调用就可以了。
速度:
速度还是比较慢的,这里数据集是:training 2000+,test 900+(i5的CPU)
k=3的时候要32s+
#coding=utf-8
from numpy import *
import operator
import os
import time
def createDataSet():
group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
labels = ['A','A','B','B']
return group,labels
#inputX表示输入向量(也就是我们要判断它属于哪一类的)
#dataSet表示训练样本
#label表示训练样本的标签
#k是最近邻的参数,选最近k个
def kNNclassify(inputX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]#计算有几个训练数据
#开始计算欧几里得距离
diffMat = tile(inputX, (dataSetSize,1)) - dataSet
#diffMat = inputX.repeat(dataSetSize, aixs=1) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)#矩阵每一行向量相加
distances = sqDistances ** 0.5
#欧几里得距离计算完毕
sortedDistance = distances.argsort()
classCount = {}
for i in xrange(k):
voteLabel = labels[sortedDistance[i]]
classCount[voteLabel] = classCount.get(voteLabel,0) + 1
res = max(classCount)
return res
def img2vec(filename):
returnVec = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVec[0,32*i+j] = int(lineStr[j])
return returnVec
def handwritingClassTest(trainingFloder,testFloder,K):
hwLabels = []
trainingFileList = os.listdir(trainingFloder)
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
fileName = trainingFileList[i]
fileStr = fileName.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vec(trainingFloder+'/'+fileName)
testFileList = os.listdir(testFloder)
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileName = testFileList[i]
fileStr = fileName.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vec(testFloder+'/'+fileName)
classifierResult = kNNclassify(vectorUnderTest, trainingMat, hwLabels, K)
#print classifierResult,' ',classNumStr
if classifierResult != classNumStr:
errorCount +=1
print 'tatal error ',errorCount
print 'error rate',errorCount/mTest
def main():
t1 = time.clock()
handwritingClassTest('trainingDigits','testDigits',3)
t2 = time.clock()
print 'execute ',t2-t1
if __name__=='__main__':
main()
希望本文所述对大家的Python程序设计有所帮助。


猜你喜欢
- 一、下载镜像docker Hub官网URL:https://hub.docker.com/_/mysql/下载最新版本:docker pul
- 本文实例讲述了Python操作MongoDB数据库的方法。分享给大家供大家参考,具体如下:>>> import pymon
- 前言本章介绍pandas中的缺失数据,主要内容有:pandas中对np.nan的操作: 统计 、 删除 、 填充 、 插值 pan
- 1.网页背景色的设置 犯错机率:很大普遍性:较广犯错可能性:懒/不知道约2年前我曾发现21cn上出现过一次没有设置背景色的情况,当时我用Em
- 1.导入matplotlib.pylab和numpy包import matplotlib.pylab as pltimport numpy
- 数据库开发数据库应用,选择一个好的数据库是非常重要的。下面从一些方面比较了SQL Server与Oracle、DB2三种数据库,为你选择数据
- 在登陆界面中,通常,最重要的部分为登陆的Form表。一个非常棒的提升体验的做法是,在载入页面时自动聚焦到第一个提供用户输入的表单框,让用户不
- 什么是字符串格式化,为什么需要这样做?我们有时候刷抖音/B站看到封面很好看,但是进入直播发现,不过如此!想必主播通过某种方式把输出转换为读者
- ...mapstate和...mapgetters的区别…mapstate当一个组件需要获取多个状态时候,将这些状态都
- 本文实例讲述了pymongo为mongodb数据库添加索引的方法。分享给大家供大家参考。具体实现方法如下:from pymongo impo
- 引子编程世界里只存在两种基本元素,一个是数据,一个是代码。编程世界就是在数据和代码千丝万缕的纠缠中呈现出无限的生机和活力。数据天生就是文静的
- 没有,用case when 来代替就行了. 例如,下面的语句显示中文年月 select getdate() as 日期,case month
- 1.背景在python运行一些,计算复杂度比较高的函数时,服务器端单核CPU的情况比较耗时,因此需要多CPU使用多进程加快速度2.函数要求笔
- T-SQL 标识符在T-SQL语言中,对SQLServer数据库及其数据对象(比如表、索引、视图、存储过程、触发器等)需要以名称来进行命名并
- 今天在使用PyTorch中Dataset遇到了一个问题。先看代码class psDataset(Dataset): def __
- 抢票是并发执行多个进程可以访问同一个文件多个进程共享同一文件,我们可以把文件当数据库,用多个进程模拟多个人执行抢票任务db.tx
- 本文主要记录了在Nodejs开发过程中遇到过的由数组特性引起的问题及解决方式,以及对数组的灵活应用。本文代码测试结果均基于node v6.9
- 关于元组,上一讲中涉及到了这个名词。本讲完整地讲述它。先看一个例子:>>>#变量引用str>>> s =
- 本文实例为大家分享了JS作用域链的相关内容,供大家参考,具体内容如下1、所有全局变量和函数都是作为window对象的属性和方法创建的。2、在
- MySQL 默认有个root用户,但是这个用户权限太大,一般只在管理数据库时候才用。如果在项目中要连接 MySQL 数据库,则建议新建一个权