使用Python处理KNN分类算法的实现代码
作者:阿里机器智能 发布时间:2023-11-03 07:03:07
简介: 我们在这世上,选择什么就成为什么,人生的丰富多彩,得靠自己成就。你此刻的付出,决定了你未来成为什么样的人,当你改变不了世界,你还可以改变自己。
KNN分类算法的介绍
KNN分类算法(K-Nearest-Neighbors Classification),又叫K近邻算法,是一个概念极其简单,而分类效果又很优秀的分类算法。
他的核心思想就是,要确定测试样本属于哪一类,就寻找所有训练样本中与该测试样本“距离”最近的前K个样本,然后看这K个样本大部分属于哪一类,那么就认为这个测试样本也属于哪一类。简单的说就是让最相似的K个样本来投票决定。
这里所说的距离,一般最常用的就是多维空间的欧式距离。这里的维度指特征维度,即样本有几个特征就属于几维。
KNN示意图如下所示。(图片来源:百度百科)
上图中要确定测试样本绿色属于蓝色还是红色。
显然,当K=3时,将以1:2的投票结果分类于红色;而K=5时,将以3:2的投票结果分类于蓝色。
KNN算法简单有效,但没有优化的暴力法效率容易达到瓶颈。如样本个数为N,特征维度为D的时候,该算法时间复杂度呈O(DN)增长。
所以通常KNN的实现会把训练数据构建成K-D Tree(K-dimensional tree),构建过程很快,甚至不用计算D维欧氏距离,而搜索速度高达O(D*log(N))。
不过当D维度过高,会产生所谓的”维度灾难“,最终效率会降低到与暴力法一样。
因此通常D>20以后,最好使用更高效率的Ball-Tree,其时间复杂度为O(D*log(N))。
人们经过长期的实践发现KNN算法虽然简单,但能处理大规模的数据分类,尤其适用于样本分类边界不规则的情况。最重要的是该算法是很多高级机器学习算法的基础。
当然,KNN算法也存在一切问题。比如如果训练数据大部分都属于某一类,投票算法就有很大问题了。这时候就需要考虑设计每个投票者票的权重了。
测试数据
测试数据的格式仍然和前面使用的身高体重数据一致。不过数据稍微增加了一些
1.5 40 thin
1.5 50 fat
1.5 60 fat
1.6 40 thin
1.6 50 thin
1.6 60 fat
1.6 70 fat
1.7 50 thin
1.7 60 thin
1.7 70 fat
1.7 80 fat
1.8 60 thin
1.8 70 thin
1.8 80 fat
1.8 90 fat
1.9 80 thin
1.9 90 fat
Python代码实现
scikit-learn提供了优秀的KNN算法支持。
import numpy as np
from sklearn import neighbors
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.cross_validation import train_test_split
import matplotlib.pyplot as plt
''' 数据读入 '''
data = []
labels = []
with open("data\\1.txt") as ifile:
for line in ifile:
tokens = line.strip().split(' ')
data.append([float(tk) for tk in tokens[:-1]])
labels.append(tokens[-1])
x = np.array(data)
labels = np.array(labels)
y = np.zeros(labels.shape)
''' 标签转换为0/1 '''
y[labels=='fat']=1
''' 拆分训练数据与测试数据 '''
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size = 0.2)
''' 创建网格以方便绘制 '''
h = .01
x_min, x_max = x[:, 0].min() - 0.1, x[:, 0].max() + 0.1
y_min, y_max = x[:, 1].min() - 1, x[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
np.arange(y_min, y_max, h))
''' 训练KNN分类器 '''
clf = neighbors.KNeighborsClassifier(algorithm='kd_tree')
clf.fit(x_train, y_train)
'''测试结果的打印'''
answer = clf.predict(x)
print(x)
print(answer)
print(y)
print(np.mean( answer == y))
'''准确率与召回率'''
precision, recall, thresholds = precision_recall_curve(y_train, clf.predict(x_train))
answer = clf.predict_proba(x)[:,1]
print(classification_report(y, answer, target_names = ['thin', 'fat']))
''' 将整个测试空间的分类结果用不同颜色区分开'''
answer = clf.predict_proba(np.c_[xx.ravel(), yy.ravel()])[:,1]
z = answer.reshape(xx.shape)
plt.contourf(xx, yy, z, cmap=plt.cm.Paired, alpha=0.8)
''' 绘制训练样本 '''
plt.scatter(x_train[:, 0], x_train[:, 1], c=y_train, cmap=plt.cm.Paired)
plt.xlabel(u'身高')
plt.ylabel(u'体重')
plt.show()
结果分析
输出结果:
[ 0. 0. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 1.]
[ 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 0. 1. 1. 0. 1.]
准确率=0.94, score=0.94
precision recall f1-score support
thin 0.89 1.00 0.94 8
fat 1.00 0.89 0.94 9
avg / total 0.95 0.94 0.94 17
KNN分类器在众多分类算法中属于最简单的之一,需要注意的地方不多。有这几点要说明:
1、KNeighborsClassifier可以设置3种算法:‘brute',‘kd_tree',‘ball_tree'。如果不知道用哪个好,设置‘auto'让KNeighborsClassifier自己根据输入去决定。
2、注意统计准确率时,分类器的score返回的是计算正确的比例,而不是R2。R2一般应用于回归问题。
3、本例先根据样本中身高体重的最大最小值,生成了一个密集网格(步长h=0.01),然后将网格中的每一个点都当成测试样本去测试,最后使用contourf函数,使用不同的颜色标注出了胖、廋两类。
容易看到,本例的分类边界,属于相对复杂,但却又与距离呈现明显规则的锯齿形。
这种边界线性函数是难以处理的。而KNN算法处理此类边界问题具有天生的优势。我们在后续的系列中会看到,这个数据集达到准确率=0.94算是很优秀的结果了。
来源:https://zhuanlan.zhihu.com/p/561759692
猜你喜欢
- 有时在项目中会遇到通过在页面中采用iframe的方式include其它页面,这时就会考虑不要因出现滚动条而影响页面效果,但include页面
- 这个仿msn的右下角popup提示窗口效果很久以前收集的,现在整理出来给大家分享,需要的朋友可以拿去用,特点,提示窗口内容和js代码分离容易
- 前言Golang语言有诸多优点:静态编译、协程、堪比c语言的高性能。但是也有一些令人发指的地方 —— 经常被人调侃 五行代码,三行错误处理
- 本文实例讲述了python实现的分析并统计nginx日志数据功能。分享给大家供大家参考,具体如下:利用python脚本分析nginx日志内容
- 很多时候关心的是优化SELECT 查询,因为它们是最常用的查询,而且确定怎样优化它们并不总是直截了当。相对来说,将数据装入数据库是直截了当的
- 物质世界客观存在,而人的“视觉成像”是对当前世界的“唯心”重建。这种重建基于个人“经验”、“感知”和“集体意识”。最初科学家认为人类通过视觉
- 代码如下:--销售冠军 --问题:在公司中,老板走进来,要一张每个地区销量前3名的销售额与销售员的报表 --- create t
- 前言责任链模式(Chain of Responsibility Pattern)是什么?责任链模式是一种行为型模式,它允许多个对象将请求沿着
- 在当前的Web设计中,jQuery被越来越多地应用在Web开发中,之所以jQuery收到如此程度的欢迎,除了其本身具备的优秀易读易操作的代码
- 一、表单验证form1、创建一个新的表单:<form id="id是唯一的,不可重复" name=“可重复”,me
- Gmail 作为一个经典的 Web 2.0 应用,在带来革命性的邮件管理体验的同时,以其完整、快速的 AJAX 操作方式,深受用户的推崇和技
- 这个代表显示宽度整数列的显示宽度与mysql需要用多少个字符来显示该列数值,与该整数需要的存储空间的大小都没有关系
- 在工控应用上,返回的数据经常会以二进制的形成存储,而这些二进制数据又是以每4个bit表示一个十六进制的数据内容。解析的时候,往往是一个字节(
- 看看上一篇《javascript设计模式交流(一)Singleton Pattern》本文将讨论Prototype Pattern的js实现
- 下面我们以论坛排行榜举例说明:<% @ LANGUAGE="VBSCRIPT" %&
- 1、使用mysqldump工具将MySql数据库备份mysqldump -u root -p -c --default-character-
- 以下为测试例子。 1.首先创建两张临时表并录入测试数据: 代码如下:create table #temptest1 ( id i
- 说明Python标准库为我们提供了threading和multiprocessing模块编写相应的多线程/多进程代码。从Python3.2开
- sql server中变量要先申明后赋值:局部变量用一个@标识,全局变量用两个@(常用的全局变量一般都是已经定义好的);申明局部变量语法:d
- Python 中的 timeit 模块可以用来测试一段代码的执行耗时,如一个变量赋值语句的执行时间,一个函数的运行时间等。timeit 模块