网络编程
位置:首页>> 网络编程>> Python编程>> python中k-means和k-means++原理及实现

python中k-means和k-means++原理及实现

作者:雷恩Layne  发布时间:2022-01-17 10:52:07 

标签:python,k-means,k-means++

前言

k-means算法是无监督的聚类算法,实现起来较为简单,k-means++可以理解为k-means的增强版,在初始化中心点的方式上比k-means更友好。

k-means原理

k-means的实现步骤如下:

  • 从样本中随机选取k个点作为聚类中心点

  • 对于任意一个样本点,求其到k个聚类中心的距离,然后,将样本点归类到距离最小的聚类中心,直到归类完所有的样本点(聚成k类)

  • 对每个聚类求平均值,然后将k个均值分别作为各自聚类新的中心点

  • 重复2、3步,直到中心点位置不在变化或者中心点的位置变化小于阈值

优点:

  • 原理简单,实现起来比较容易

  • 收敛速度较快,聚类效果较优

缺点:

  • 初始中心点的选取具有随机性,可能会选取到不好的初始值。

k-means++原理

k-means++是k-means的增强版,它初始选取的聚类中心点尽可能的分散开来,这样可以有效减少迭代次数,加快运算速度,实现步骤如下:

  • 从样本中随机选取一个点作为聚类中心

  • 计算每一个样本点到已选择的聚类中心的距离,用D(X)表示:D(X)越大,其被选取下一个聚类中心的概率就越大

  • 利用轮盘法的方式选出下一个聚类中心(D(X)越大,被选取聚类中心的概率就越大)

  • 重复步骤2,直到选出k个聚类中心

  • 选出k个聚类中心后,使用标准的k-means算法聚类

这里不得不说明一点,有的文献中把与已选择的聚类中心最大距离的点选作下一个中心点,这个说法是不太准确的,准的说是与已选择的聚类中心最大距离的点被选作下一个中心点的概率最大,但不一定就是改点,因为总是取最大也不太好(遇到特殊数据,比如有一个点离某个聚类所有点都很远)。

一般初始化部分,始终要给些随机。因为数据是随机的。

尽管计算初始点时花费了额外的时间,但是在迭代过程中,k-mean 本身能快速收敛,因此算法实际上降低了计算时间。

现在重点是利用轮盘法的方式选出下一个聚类中心,我们以一个例子说明K-means++是如何选取初始聚类中心的。

假如数据集中有8个样本,分布分布以及对应序号如下图所示:

python中k-means和k-means++原理及实现

我们先用 k-means++的步骤1选择6号点作为第一个聚类中心,然后进行第二步,计算每个样本点到已选择的聚类中心的距离D(X),如下所示:

python中k-means和k-means++原理及实现

  • D(X)是每个样本点与所选取的聚类中心的距离(即第一个聚类中心)

  • P(X)每个样本被选为下一个聚类中心的概率

  • Sum是概率P(x)的累加和,用于轮盘法选择出第二个聚类中心。

然后执行 k-means++的第三步:利用轮盘法的方式选出下一个聚类中心,方法是随机产生出一个0~1之间的随机数,判断它属于哪个区间,那么该区间对应的序号就是被选择出来的第二个聚类中心了

在上图1号点区间为[0,0.2),2号点的区间为[0.2, 0.525),4号点的区间为[0.65,0.9)

从上表可以直观的看到,1号,2号,3号,4号总的概率之和为0.9,这4个点正好是离第一个初始聚类中心(即6号点)较远的四个点,因此选取的第二个聚类中心大概率会落在这4个点中的一个,其中2号点被选作为下一个聚类中心的概率最大。

k-means及k-means++代码实现

这里选择的中心点是样本的特征(不是索引),这样做是为了方便计算,选择的聚类点(中心点周围的点)是样本的索引。

k-means实现

# 定义欧式距离
import numpy as np
def get_distance(x1, x2):
   return np.sqrt(np.sum(np.square(x1-x2)))
import random
# 定义中心初始化函数,中心点选择的是样本特征
def center_init(k, X):
   n_samples, n_features = X.shape
   centers = np.zeros((k, n_features))
   selected_centers_index = []
   for i in range(k):
       # 每一次循环随机选择一个类别中心,判断不让centers重复
       sel_index = random.choice(list(set(range(n_samples))-set(selected_centers_index)))
       centers[i] = X[sel_index]
       selected_centers_index.append(sel_index)
   return centers
# 判断一个样本点离哪个中心点近, 返回的是该中心点的索引
## 比如有三个中心点,返回的是0,1,2
def closest_center(sample, centers):
   closest_i = 0
   closest_dist = float('inf')
   for i, c in enumerate(centers):
       # 根据欧式距离判断,选择最小距离的中心点所属类别
       distance = get_distance(sample, c)
       if distance < closest_dist:
           closest_i = i
           closest_dist = distance
   return closest_i
# 定义构建聚类的过程
# 每一个聚类存的内容是样本的索引,即对样本索引进行聚类,方便操作
def create_clusters(centers, k, X):
   clusters = [[] for _ in range(k)]
   for sample_i, sample in enumerate(X):
       # 将样本划分到最近的类别区域
       center_i = closest_center(sample, centers)
       # 存放样本的索引
       clusters[center_i].append(sample_i)
   return clusters
# 根据上一步聚类结果计算新的中心点
def calculate_new_centers(clusters, k, X):
   n_samples, n_features = X.shape
   centers = np.zeros((k, n_features))
   # 以当前每个类样本的均值为新的中心点
   for i, cluster in enumerate(clusters):  # cluster为分类后每一类的索引
       new_center = np.mean(X[cluster], axis=0) # 按列求平均值
       centers[i] = new_center
   return centers
# 获取每个样本所属的聚类类别
def get_cluster_labels(clusters, X):
   y_pred = np.zeros(np.shape(X)[0])
   for cluster_i, cluster in enumerate(clusters):
       for sample_i in cluster:
           y_pred[sample_i] = cluster_i
           #print('把样本{}归到{}类'.format(sample_i,cluster_i))
   return y_pred
# 根据上述各流程定义kmeans算法流程
def Mykmeans(X, k, max_iterations,init):
   # 1.初始化中心点
   if init == 'kmeans':
       centers = center_init(k, X)
   else: centers = get_kmeansplus_centers(k, X)
   # 遍历迭代求解
   for _ in range(max_iterations):
       # 2.根据当前中心点进行聚类
       clusters = create_clusters(centers, k, X)
       # 保存当前中心点
       pre_centers = centers
       # 3.根据聚类结果计算新的中心点
       new_centers = calculate_new_centers(clusters, k, X)
       # 4.设定收敛条件为中心点是否发生变化
       diff = new_centers - pre_centers
       # 说明中心点没有变化,停止更新
       if diff.sum() == 0:
           break
   # 返回最终的聚类标签
   return get_cluster_labels(clusters, X)
# 测试执行
X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])
# 设定聚类类别为2个,最大迭代次数为10次
labels = Mykmeans(X, k = 2, max_iterations = 10,init = 'kmeans')
# 打印每个样本所属的类别标签
print("最后分类结果",labels)
## 输出为  [1. 1. 1. 0. 0.]
# 使用sklearn验证
from sklearn.cluster import KMeans
X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])
kmeans = KMeans(n_clusters=2,init = 'random').fit(X)
# 由于center的随机性,结果可能不一样
print(kmeans.labels_)

k-means++实现

## 得到kmean++中心点
def get_kmeansplus_centers(k, X):
   n_samples, n_features = X.shape
   init_one_center_i = np.random.choice(range(n_samples))
   centers = []
   centers.append(X[init_one_center_i])
   dists = [ 0 for _ in range(n_samples)]

# 执行
   for _ in range(k-1):
       total = 0
       for sample_i,sample in enumerate(X):
           # 得到最短距离
           closet_i = closest_center(sample,centers)
           d = get_distance(X[closet_i],sample)
           dists[sample_i] = d
           total += d
       total = total * np.random.random()

for sample_i,d in enumerate(dists): # 轮盘法选出下一个聚类中心
           total -= d
           if total > 0:
               continue
           # 选取新的中心点
           centers.append(X[sample_i])
           break
   return centers
X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])
# 设定聚类类别为2个,最大迭代次数为10次
labels = Mykmeans(X, k = 2, max_iterations = 10,init = 'kmeans++')
print("最后分类结果",labels)
## 输出为  [1. 1. 1. 0. 0.]
# 使用sklearn验证
X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])
kmeans = KMeans(n_clusters=2,init='k-means++').fit(X)
print(kmeans.labels_)

参考文档

K-means与K-means++
K-means原理、优化及应用

来源:https://blog.csdn.net/qq_37555071/article/details/107599459

0
投稿

猜你喜欢

  • 本文实例讲述了Python常见数据类型转换操作。分享给大家供大家参考,具体如下:类型转换主要针对几种存储工具:list、tuple、dict
  • LOAD DATA INFILE '文件地址' INTO TABLE 表名 FIELDS TERMINATED BY 
  • 标量(scalar)数据类型标量(scalar)数据类型没有内部组件,他们大致可分为以下四类:. number. character. da
  • 在Web开发的时候,经常会遇到的一种情况就是浏览器提示脚本运行时间过长,停止还是继续,无论你选择什么,相信你都会想尽一切办法让这个对话框远离
  • 静态页面是蜘蛛喜欢的,会得到蜘蛛经常光顾的,以至于网站上的内容会得到搜索引擎更多的收录。这里介绍一个asp伪静态的程序实现方法数据库是acc
  • 10个杀手级应用的Python自动化脚本重复的任务总是耗费时间和枯燥的。想象一下,逐一裁剪100张照片,或者做诸如Fetching APIs
  • 本文重在实践和测试,如果你还不了解Data URI,推荐先阅读秦歌的Data URI 和 MHTML。旺旺点灯(JS)实践经过:因为要对SR
  • 现在有一个xml,格式如下: <date> <item> <id> 1 </id> <
  • 前言前面写过一篇用Python制作PPT的博客,感兴趣的可以参考用Python制作PPT这篇是关于用Python进行数据可视化的,准备作为一
  • 上次帮朋友写过的一个简单切换效果,超级简单,但也比较适用.因为用到了CSS Sprite技术,DEMO中附带了IE6兼容png的JS.核心J
  • 前两天研究了一下textarea的直观行的换行规律,挺复杂啊:直观行怎样取不光要看cols大小,还要看网页编码方式。cols="3
  • 我就废话不多说了,大家还是直接看代码吧~one = tf.ones_like(label)zero = tf.zeros_like(labe
  • 平时我们在使用MySQL数据库的时候经常会因为操作失误造成数据丢失,MySQL数据库备份可以帮助我们避免由于各种原因造成的数据丢失或着数据库
  • 呵呵,我之前也写过一个类似的模板替换功能.>> 已实现:>、<、>=、<=、=、==等简单的运算>
  • 这阵子没有精力完整翻译和发到译言(  现下正渐入状态,预计写博客量会逐步提升回来),简短做一个概要翻译,为近期工作需要做一个参考。
  • 在CSS中我们会经常要用到“清除浮动”Clear,比较典型的就是clear:both;CSS手册上是这样说明的:该属性的值指出了不允许有浮动
  • 阅读作者的上一篇相关文章:段正淳的css笔记(3)标题右侧“更多”的实现 段正淳的css笔记(4)1、css代码的简写css缩写的语法,对新
  • 内容摘要合理使用渐变留白网格布局提高字体应用明确而有效的导航设计漂亮、有用的页脚介绍优秀设计和卓越设计之间的区别是比较小的。一般人可能无法解
  • 有时,希望除去某些记录或更改它们的内容。DELETE 和 UPDATE 语句令我们能做到这一点。用update修改记录UPDATE tbl_
  • 在网上有很多相关主题的讨论,但是一般都是用Iframe和XMLHTTP来实现。Iframe的实现可能是最常看到的。很多论坛和聊天室的无刷新效
手机版 网络编程 asp之家 www.aspxhome.com