python实现KNN分类算法
作者:王念晨 发布时间:2023-03-01 07:53:36
一、KNN算法简介
邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。
kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。 kNN方法在类别决策时,只与极少量的相邻样本有关。由于kNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,kNN方法较其他方法更为适合。
二、算法过程
1.读取数据集
2.处理数据集数据 清洗,采用留出法hold-out拆分数据集:训练集、测试集
3.实现KNN算法类:
1)遍历训练数据集,离差平方和计算各点之间的距离
2)对各点的距离数组进行排序,根据输入的k值取对应的k个点
3)k个点中,统计每个点出现的次数,权重为距离的导数,得到最大的值,该值的索引就是我们计算出的判定类别
三、代码实现及数据分析
import numpy as np
import pandas as pd
# 读取鸢尾花数据集,header参数来指定标题的行。默认为0。如果没有标题,则使用None。
data = pd.read_csv("你的目录/Iris.csv",header=0)
# 显示前n行记录。默认n的值为5。
#data.head()
# 显示末尾的n行记录。默认n的值为5。
#data.tail()
# 随机抽取样本。默认抽取一条,我们可以通过参数进行指定抽取样本的数量。
# data.sample(10)
# 将类别文本映射成为数值类型
data["Species"] = data["Species"].map({"Iris-virginica": 0, "Iris-setosa": 1, "Iris-versicolor": 2})
# 删除不需要的Id列。
data.drop("Id", axis=1, inplace=True )
data.drop_duplicates(inplace=True)
## 查看各个类别的鸢尾花具有多少条记录。
data["Species"].value_counts()
分析:首先读取数据集,如下图
最后一列为数据集的分类名称,但是在程序中,我们更倾向于使用如0、1、2数字来表示分类,所以对数据集进行处理,处理后的数据集如下:
然后采用留出法对数据集进行拆分,一部分用作训练,一部分用作测试,如下图:
#构建训练集与测试集,用于对模型进行训练与测试。
# 提取出每个类比的鸢尾花数据
t0 = data[data["Species"] == 0]
t1 = data[data["Species"] == 1]
t2 = data[data["Species"] == 2]
# 对每个类别数据进行洗牌 random_state 每次以相同的方式洗牌 保证训练集与测试集数据取样方式相同
t0 = t0.sample(len(t0), random_state=0)
t1 = t1.sample(len(t1), random_state=0)
t2 = t2.sample(len(t2), random_state=0)
# 构建训练集与测试集。
train_X = pd.concat([t0.iloc[:40, :-1], t1.iloc[:40, :-1], t2.iloc[:40, :-1]] , axis=0)#截取前40行,除最后列外的列,因为最后一列是y
train_y = pd.concat([t0.iloc[:40, -1], t1.iloc[:40, -1], t2.iloc[:40, -1]], axis=0)
test_X = pd.concat([t0.iloc[40:, :-1], t1.iloc[40:, :-1], t2.iloc[40:, :-1]], axis=0)
test_y = pd.concat([t0.iloc[40:, -1], t1.iloc[40:, -1], t2.iloc[40:, -1]], axis=0)
实现KNN算法类:
#定义KNN类,用于分类,类中定义两个预测方法,分为考虑权重不考虑权重两种情况
class KNN:
''' 使用Python语言实现K近邻算法。(实现分类) '''
def __init__(self, k):
'''初始化方法
Parameters
-----
k:int 邻居的个数
'''
self.k = k
def fit(self,X,y):
'''训练方法
Parameters
----
X : 类数组类型,形状为:[样本数量, 特征数量]
待训练的样本特征(属性)
y : 类数组类型,形状为: [样本数量]
每个样本的目标值(标签)。
'''
#将X转换成ndarray数组
self.X = np.asarray(X)
self.y = np.asarray(y)
def predict(self,X):
"""根据参数传递的样本,对样本数据进行预测。
Parameters
-----
X : 类数组类型,形状为:[样本数量, 特征数量]
待训练的样本特征(属性)
Returns
-----
result : 数组类型
预测的结果。
"""
X = np.asarray(X)
result = []
# 对ndarray数组进行遍历,每次取数组中的一行。
for x in X:
# 对于测试集中的每一个样本,依次与训练集中的所有样本求距离。
dis = np.sqrt(np.sum((x - self.X) ** 2, axis=1))
## 返回数组排序后,每个元素在原数组(排序之前的数组)中的索引。
index = dis.argsort()
# 进行截断,只取前k个元素。【取距离最近的k个元素的索引】
index = index[:self.k]
# 返回数组中每个元素出现的次数。元素必须是非负的整数。【使用weights考虑权重,权重为距离的倒数。】
count = np.bincount(self.y[index], weights= 1 / dis[index])
# 返回ndarray数组中,值最大的元素对应的索引。该索引就是我们判定的类别。
# 最大元素索引,就是出现次数最多的元素。
result.append(count.argmax())
return np.asarray(result)
#创建KNN对象,进行训练与测试。
knn = KNN(k=3)
#进行训练
knn.fit(train_X,train_y)
#进行测试
result = knn.predict(test_X)
# display(result)
# display(test_y)
display(np.sum(result == test_y))
display(np.sum(result == test_y)/ len(result))
得出计算结果:
26
0.9629629629629629
得出该模型计算的结果中,有26条记录与测试集相等,准确率为96%
接下来绘制散点图:
#导入可视化所必须的库。
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams["font.family"] = "SimHei"
mpl.rcParams["axes.unicode_minus"] = False
#绘制散点图。为了能够更方便的进行可视化,这里只选择了两个维度(分别是花萼长度与花瓣长度)。
# {"Iris-virginica": 0, "Iris-setosa": 1, "Iris-versicolor": 2})
# 设置画布的大小
plt.figure(figsize=(10, 10))
# 绘制训练集数据
plt.scatter(x=t0["SepalLengthCm"][:40], y=t0["PetalLengthCm"][:40], color="r", label="Iris-virginica")
plt.scatter(x=t1["SepalLengthCm"][:40], y=t1["PetalLengthCm"][:40], color="g", label="Iris-setosa")
plt.scatter(x=t2["SepalLengthCm"][:40], y=t2["PetalLengthCm"][:40], color="b", label="Iris-versicolor")
# 绘制测试集数据
right = test_X[result == test_y]
wrong = test_X[result != test_y]
plt.scatter(x=right["SepalLengthCm"], y=right["PetalLengthCm"], color="c", marker="x", label="right")
plt.scatter(x=wrong["SepalLengthCm"], y=wrong["PetalLengthCm"], color="m", marker=">", label="wrong")
plt.xlabel("花萼长度")
plt.ylabel("花瓣长度")
plt.title("KNN分类结果显示")
plt.legend(loc="best")
plt.show()
程序运行结果如下:
四、思考与优化
①尝试去改变邻居的数量。
②在考虑权重的情况下,修改邻居的数量。
③对比查看结果上的差异。
来源:https://blog.csdn.net/tanak/article/details/84380362


猜你喜欢
- 本文实例讲述了ES6正则表达式和字符串正则方法。分享给大家供大家参考,具体如下:RegExp构造函数在ES5中,RegExp构造函数的参数有
- 简洁的隐藏垂直菜单在hover时将内容展开。这样的效果在JS里有很多个版本,但这个可以说是绝无仅有的CSS版本。此菜单可以在IE5.5,IE
- 聚合函数作用于一组数据,对那组数据返回一个值count :统计结果记录多少条数,max:统计最大值min:统计最小值sum:计算求和avg:
- 菜鸟版代码如下: 理解这段代码就基本上掌握了 function f_s() { var obj = document.getElementB
- 使用pyserial进行串口传输一、安装pyserial以及基本用法在cmd下输入命令pip install pyserial注:升级pip
- 1.join函数的语法及用法(1)语法:'sep'.join(sep_object)参数说明sep:分割符,可为&l
- 特征降维0维 标量1维 向量2维 矩阵概念降维是指在某些限定条件下,降低随机变量(特征)个数,得到一组“不相关&
- 1.表达式操作符Table 1 算术操作符操作符 语法 含义+ a + b 相加 - a - b 相减 - - a
- 【问】使用FCKeditor添加文章时,在文章最后多了逗号。【答】此情况发生在asp环境中。在asp里对于 提交的表单信息中如果有相同nam
- Python中对于数组和列表进行切片操作是很频繁的,当然对于切片的操作可供我们直接使用的函数也是很遍历了,我们今天主要简单总结一下常用集中索
- (一)前言众所周知,Navicat是我们常用的连接MYSQL工具,非常方便好用。其实日常中,我们也常常会遇到运行时间很长甚至几乎跑不完卡死的
- 1.如何通过地址栏参数来得到模块名称和控制器名称(即使在有路由和开了重写模块的情况下)2.tp是如何实现前置,后置方 * 能模块,和如何执行带
- 轮播图的根本其实就是缓动函数的封装,如果说轮播图是一辆跑动的汽车,那么缓动函数就是它的发动机,今天本文章就带大家由简入繁,封装属于自己的缓动
- 今天偶尔在知乎上看到某大佬用Python写的ATM系统案例,然后观摩了下他的实现思路和源码,感觉受益颇多,于是就根据自己的思路和目前掌握的P
- 因为我们现在的前端框架做性能优化,为了找到各个组件及框架的具体解析耗时,需要在框架中嵌入一个耗时测试工具,性能测试跟不同的计算机硬件配置有很
- 最近在工作遇到一个难题。我所在的测试组有一套PC软件前端自动化工程,在进行自动化测试时,需要在一台古老的xp机器上运行,但这台古老的xp机器
- 1. 扩展Tensor维度相信刚接触Pytorch的宝宝们,会遇到这样一个问题,输入的数据维度和实验需要维度不一致,输入的可能是2维数据或3
- 首先,这次讲解的tansforms功能,通俗地讲,类似于在计算机视觉流程里的图像预处理部分的数据增强。transforms的原理:说明:图片
- 1.环境设置1.1gradio安装需要安装 gradio,安装办法就是 pip install gradio2.ffmpeg安装再次需要加入
- 周一 至 周日 时间格式化转化(Y --- 年 M --- 月 D--- 天)