Python实现聚类K-means算法详解
作者:Castria 发布时间:2023-04-22 07:48:36
K-means(K均值)算法是最简单的一种聚类算法,它期望最小化平方误差
注:为避免运行时间过长,通常设置一个最大运行轮数或最小调整幅度阈值,若到达最大轮数或调整幅度小于阈值,则停止运行。
下面我们用python来实现一下K-means算法:我们先尝试手动实现这个算法,再用sklearn
库中的KMeans
类来实现。数据我们采用《机器学习》的西瓜数据(P202表9.1):
# 下面的内容保存在 melons.txt 中
# 第一列为西瓜的密度;第二列为西瓜的含糖率。我们要把这30个西瓜分为3类
0.697 0.460
0.774 0.376
0.634 0.264
0.608 0.318
0.556 0.215
0.403 0.237
0.481 0.149
0.437 0.211
0.666 0.091
0.243 0.267
0.245 0.057
0.343 0.099
0.639 0.161
0.657 0.198
0.360 0.370
0.593 0.042
0.719 0.103
0.359 0.188
0.339 0.241
0.282 0.257
0.748 0.232
0.714 0.346
0.483 0.312
0.478 0.437
0.525 0.369
0.751 0.489
0.532 0.472
0.473 0.376
0.725 0.445
0.446 0.459
手动实现
我们用到的库有matplotlib
和numpy
,如果没有需要先用pip安装一下。
import random
import numpy as np
import matplotlib.pyplot as plt
下面定义一些数据:
k = 3 # 要分的簇数
rnd = 0 # 轮次,用于控制迭代次数(见上文)
ROUND_LIMIT = 100 # 轮次的上限
THRESHOLD = 1e-10 # 单轮改变距离的阈值,若改变幅度小于该阈值,算法终止
melons = [] # 西瓜的列表
clusters = [] # 簇的列表,clusters[i]表示第i簇包含的西瓜
从melons.txt读取数据,保存在列表中:
f = open('melons.txt', 'r')
for line in f:
# 把字符串转化为numpy中的float64类型
melons.append(np.array(line.split(' '), dtype = np.string_).astype(np.float64))
从 m m m个数据中随机挑选出 k k k个,对应上面算法的第 1 1 1行:
# random的sample函数从列表中随机挑选出k个样本(不重复)。我们在这里把这些样本作为均值向量
mean_vectors = random.sample(melons, k)
下面是算法的主要部分。
# 这个while对应上面算法的2-17行
while True:
rnd += 1 # 轮次增加
change = 0 # 把改变幅度重置为0
# 清空对簇的划分,对应上面算法的第3行
clusters = []
for i in range(k):
clusters.append([])
# 这个for对应上面算法的4-8行
for melon in melons:
'''
argmin 函数找出容器中最小的下标,在这里这个目标容器是
list(map(lambda vec: np.linalg.norm(melon - vec, ord = 2), mean_vectors)),
它表示melon与mean_vectors中所有向量的距离列表。
(numpy.linalg.norm计算向量的范数,ord = 2即欧几里得范数,或模长)
'''
c = np.argmin(
list(map( lambda vec: np.linalg.norm(melon - vec, ord = 2), mean_vectors))
)
clusters[c].append(melon)
# 这个for对应上面算法的9-16行
for i in range(k):
# 求每个簇的新均值向量
new_vector = np.zeros((1,2))
for melon in clusters[i]:
new_vector += melon
new_vector /= len(clusters[i])
# 累加改变幅度并更新均值向量
change += np.linalg.norm(mean_vectors[i] - new_vector, ord = 2)
mean_vectors[i] = new_vector
# 若超过设定的轮次或者变化幅度<预先设定的阈值,结束算法
if rnd > ROUND_LIMIT or change < THRESHOLD:
break
print('最终迭代%d轮'%rnd)
最后我们绘图来观察一下划分的结果:
colors = ['red', 'green', 'blue']
# 每个簇换一下颜色,同时迭代簇和颜色两个列表
for i, col in zip(range(k), colors):
for melon in clusters[i]:
# 绘制散点图
plt.scatter(melon[0], melon[1], color = col)
plt.show()
划分结果(由于最开始的 k k k个均值向量随机选取,每次划分的结果可能会不同):
完整代码:
import random
import numpy as np
import matplotlib.pyplot as plt
k = 3
rnd = 0
ROUND_LIMIT = 10
THRESHOLD = 1e-10
melons = []
clusters = []
f = open('melons.txt', 'r')
for line in f:
melons.append(np.array(line.split(' '), dtype = np.string_).astype(np.float64))
mean_vectors = random.sample(melons, k)
while True:
rnd += 1
change = 0
clusters = []
for i in range(k):
clusters.append([])
for melon in melons:
c = np.argmin(
list(map( lambda vec: np.linalg.norm(melon - vec, ord = 2), mean_vectors))
)
clusters[c].append(melon)
for i in range(k):
new_vector = np.zeros((1,2))
for melon in clusters[i]:
new_vector += melon
new_vector /= len(clusters[i])
change += np.linalg.norm(mean_vectors[i] - new_vector, ord = 2)
mean_vectors[i] = new_vector
if rnd > ROUND_LIMIT or change < THRESHOLD:
break
print('最终迭代%d轮'%rnd)
colors = ['red', 'green', 'blue']
for i, col in zip(range(k), colors):
for melon in clusters[i]:
plt.scatter(melon[0], melon[1], color = col)
plt.show()
sklearn库中的KMeans
这种经典算法显然不需要我们反复地造轮子,被广泛使用的python机器学习库sklearn
已经提供了该算法的实现。sklearn
的官方文档中给了我们一个示例:
>>> from sklearn.cluster import KMeans
>>> import numpy as np
>>> X = np.array([[1, 2], [1, 4], [1, 0],
... [10, 2], [10, 4], [10, 0]])
>>> kmeans = KMeans(n_clusters=2, random_state=0).fit(X)
>>> kmeans.labels_
array([1, 1, 1, 0, 0, 0], dtype=int32)
>>> kmeans.predict([[0, 0], [12, 3]])
array([1, 0], dtype=int32)
>>> kmeans.cluster_centers_
array([[10., 2.],
[ 1., 2.]])
可以看出,X
即要聚类的数据(1,2),(1,4),(1,0)
等。KMeans
类的初始化参数n_clusters
即簇数 k k k;random_state
是用于初始化选取 k k k个向量的随机数种子;kmeans.labels_
即每个点所属的簇;kmeans.predict
方法预测新的数据属于哪个簇;kmeans.cluster_centers_
返回每个簇的中心。
我们就改造一下这个简单的示例,完成对上面西瓜的聚类。
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
X = []
f = open('melons.txt', 'r')
for line in f:
X.append(np.array(line.split(' '), dtype = np.string_).astype(np.float64))
kmeans = KMeans(n_clusters = 3, random_state = 0).fit(X)
colors = ['red', 'green', 'blue']
for i, cluster in enumerate(kmeans.labels_):
plt.scatter(X[i][0], X[i][1], color = colors[cluster])
plt.show()
运行结果如下,可以看到和我们手写的聚类结果基本一致:
来源:https://blog.csdn.net/wyn1564464568/article/details/125782286


猜你喜欢
- 装饰器的应用场景附加功能数据的清理或添加:函数参数类型验证 @require_ints 类似请求前拦截数据格式转换 将函数返回字典改为 JS
- 从低版本迁移到MySQL 8后,可能由于字符集问题出现 Illegal mix of collations (utf8mb4_general
- 什么是XML?XML 指可扩展标记语言(eXtensibleMarkupLanguage)。 你可以通过本站学习XML教程XML 被设计用来
- sql 使用系统存储过程 sp_send_dbmail 发送电子邮件语法:sp_send_dbmail [ [ @profile_name
- 前言python学习之路任重而道远,要想学完说容易也容易,说难也难。很多人说python最好学了,但扪心自问,你会用python做什么了?刚
- 说明C# 调用 Python 程序有多种方式,本篇用的是第 4 种:nuget的ironPython;用 c/c++ 调用python,再封
- 首先下载源tar包可利用linux自带下载工具wget下载,如下所示:wget http://www.python.org/ftp/pyth
- MinTTY 是一个小巧但却很实用的 Cygwin 终端机,但有个严重的问题就是无法调用交互性的 Windows 原生程序,比如说 mysq
- 1、日期大小的比较,传到xml中的日期格式要符合'yyyy-MM-dd',这样才能走索引,如:'yyyy'改
- 网络上的两个程序通过一个双向的通信连接实现数据的交换,这个连接的一端称为一个socket,一般在配置部署mysql环境时都会在mysql的m
- 在pytest自动化测试中,如果只是简单的从应用的角度来说,完全可以不去了解pytest中的显示信息的部分以及原理,完全可以通过使用推荐的p
- 一个动态载入asp树源码。把 node.htc, style.css 保存与 css 目录下. index.asp subtree.asp
- 前言使用 webstrom 调试 Vue.js 单页面程序,理论上来说应该是支持所有用 webpack 构建的应用程序webstrom 版本
- SQL Server有几个版本都在使用中——4.2, 6.0, 6.5, 7.0, 2000,以及2
- 当我们使用访问一个没有声明的变量时,JS会报错;而当我们给一个没有声明的变量赋值时,JS不会报错误,相反它会认为我们是要隐式申明一个全局变量
- 先记下,免得以后想不起来又到处去找! PHP操作数据库的时候,数据库中数据使用UTF8编码,在读出来的时候,显示的全是???????问号乱码
- Python 是一种极其多样化和强大的编程语言!当需要解决一个问题时,它有着不同的方法。在本文中,将会展示列表解析式(List Compre
- 使用Python解析各种格式的数据都很方便,比如json、txt、xml、csv等。用于处理简单的数据完全足够用了,而且代码简单易懂。前段时
- 前言由于pycharm自带的pip源网站是国外网址,这就导致了许多国内用户在pycharm中下载其他软件包速度极慢,有时还会跳出下载失败的界
- 我要坦白一点。尽管我是一个应用相当广泛的公共域 Python 库的创造者,但在我的模块中引入的单元测试是非常不系统的。实际上,那些测试大部分