python中k-means和k-means++原理及实现
作者:雷恩Layne 发布时间:2022-01-17 10:52:07
前言
k-means算法是无监督的聚类算法,实现起来较为简单,k-means++可以理解为k-means的增强版,在初始化中心点的方式上比k-means更友好。
k-means原理
k-means的实现步骤如下:
从样本中随机选取k个点作为聚类中心点
对于任意一个样本点,求其到k个聚类中心的距离,然后,将样本点归类到距离最小的聚类中心,直到归类完所有的样本点(聚成k类)
对每个聚类求平均值,然后将k个均值分别作为各自聚类新的中心点
重复2、3步,直到中心点位置不在变化或者中心点的位置变化小于阈值
优点:
原理简单,实现起来比较容易
收敛速度较快,聚类效果较优
缺点:
初始中心点的选取具有随机性,可能会选取到不好的初始值。
k-means++原理
k-means++是k-means的增强版,它初始选取的聚类中心点尽可能的分散开来,这样可以有效减少迭代次数,加快运算速度,实现步骤如下:
从样本中随机选取一个点作为聚类中心
计算每一个样本点到已选择的聚类中心的距离,用D(X)表示:D(X)越大,其被选取下一个聚类中心的概率就越大
利用轮盘法的方式选出下一个聚类中心(D(X)越大,被选取聚类中心的概率就越大)
重复步骤2,直到选出k个聚类中心
选出k个聚类中心后,使用标准的k-means算法聚类
这里不得不说明一点,有的文献中把与已选择的聚类中心最大距离的点选作下一个中心点,这个说法是不太准确的,准的说是与已选择的聚类中心最大距离的点被选作下一个中心点的概率最大,但不一定就是改点,因为总是取最大也不太好(遇到特殊数据,比如有一个点离某个聚类所有点都很远)。
一般初始化部分,始终要给些随机。因为数据是随机的。
尽管计算初始点时花费了额外的时间,但是在迭代过程中,k-mean 本身能快速收敛,因此算法实际上降低了计算时间。
现在重点是利用轮盘法的方式选出下一个聚类中心,我们以一个例子说明K-means++是如何选取初始聚类中心的。
假如数据集中有8个样本,分布分布以及对应序号如下图所示:
我们先用 k-means++的步骤1选择6号点作为第一个聚类中心,然后进行第二步,计算每个样本点到已选择的聚类中心的距离D(X),如下所示:
D(X)是每个样本点与所选取的聚类中心的距离(即第一个聚类中心)
P(X)每个样本被选为下一个聚类中心的概率
Sum是概率P(x)的累加和,用于轮盘法选择出第二个聚类中心。
然后执行 k-means++的第三步:利用轮盘法的方式选出下一个聚类中心,方法是随机产生出一个0~1之间的随机数,判断它属于哪个区间,那么该区间对应的序号就是被选择出来的第二个聚类中心了。
在上图1号点区间为[0,0.2),2号点的区间为[0.2, 0.525),4号点的区间为[0.65,0.9)
从上表可以直观的看到,1号,2号,3号,4号总的概率之和为0.9,这4个点正好是离第一个初始聚类中心(即6号点)较远的四个点,因此选取的第二个聚类中心大概率会落在这4个点中的一个,其中2号点被选作为下一个聚类中心的概率最大。
k-means及k-means++代码实现
这里选择的中心点是样本的特征(不是索引),这样做是为了方便计算,选择的聚类点(中心点周围的点)是样本的索引。
k-means实现
# 定义欧式距离
import numpy as np
def get_distance(x1, x2):
return np.sqrt(np.sum(np.square(x1-x2)))
import random
# 定义中心初始化函数,中心点选择的是样本特征
def center_init(k, X):
n_samples, n_features = X.shape
centers = np.zeros((k, n_features))
selected_centers_index = []
for i in range(k):
# 每一次循环随机选择一个类别中心,判断不让centers重复
sel_index = random.choice(list(set(range(n_samples))-set(selected_centers_index)))
centers[i] = X[sel_index]
selected_centers_index.append(sel_index)
return centers
# 判断一个样本点离哪个中心点近, 返回的是该中心点的索引
## 比如有三个中心点,返回的是0,1,2
def closest_center(sample, centers):
closest_i = 0
closest_dist = float('inf')
for i, c in enumerate(centers):
# 根据欧式距离判断,选择最小距离的中心点所属类别
distance = get_distance(sample, c)
if distance < closest_dist:
closest_i = i
closest_dist = distance
return closest_i
# 定义构建聚类的过程
# 每一个聚类存的内容是样本的索引,即对样本索引进行聚类,方便操作
def create_clusters(centers, k, X):
clusters = [[] for _ in range(k)]
for sample_i, sample in enumerate(X):
# 将样本划分到最近的类别区域
center_i = closest_center(sample, centers)
# 存放样本的索引
clusters[center_i].append(sample_i)
return clusters
# 根据上一步聚类结果计算新的中心点
def calculate_new_centers(clusters, k, X):
n_samples, n_features = X.shape
centers = np.zeros((k, n_features))
# 以当前每个类样本的均值为新的中心点
for i, cluster in enumerate(clusters): # cluster为分类后每一类的索引
new_center = np.mean(X[cluster], axis=0) # 按列求平均值
centers[i] = new_center
return centers
# 获取每个样本所属的聚类类别
def get_cluster_labels(clusters, X):
y_pred = np.zeros(np.shape(X)[0])
for cluster_i, cluster in enumerate(clusters):
for sample_i in cluster:
y_pred[sample_i] = cluster_i
#print('把样本{}归到{}类'.format(sample_i,cluster_i))
return y_pred
# 根据上述各流程定义kmeans算法流程
def Mykmeans(X, k, max_iterations,init):
# 1.初始化中心点
if init == 'kmeans':
centers = center_init(k, X)
else: centers = get_kmeansplus_centers(k, X)
# 遍历迭代求解
for _ in range(max_iterations):
# 2.根据当前中心点进行聚类
clusters = create_clusters(centers, k, X)
# 保存当前中心点
pre_centers = centers
# 3.根据聚类结果计算新的中心点
new_centers = calculate_new_centers(clusters, k, X)
# 4.设定收敛条件为中心点是否发生变化
diff = new_centers - pre_centers
# 说明中心点没有变化,停止更新
if diff.sum() == 0:
break
# 返回最终的聚类标签
return get_cluster_labels(clusters, X)
# 测试执行
X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])
# 设定聚类类别为2个,最大迭代次数为10次
labels = Mykmeans(X, k = 2, max_iterations = 10,init = 'kmeans')
# 打印每个样本所属的类别标签
print("最后分类结果",labels)
## 输出为 [1. 1. 1. 0. 0.]
# 使用sklearn验证
from sklearn.cluster import KMeans
X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])
kmeans = KMeans(n_clusters=2,init = 'random').fit(X)
# 由于center的随机性,结果可能不一样
print(kmeans.labels_)
k-means++实现
## 得到kmean++中心点
def get_kmeansplus_centers(k, X):
n_samples, n_features = X.shape
init_one_center_i = np.random.choice(range(n_samples))
centers = []
centers.append(X[init_one_center_i])
dists = [ 0 for _ in range(n_samples)]
# 执行
for _ in range(k-1):
total = 0
for sample_i,sample in enumerate(X):
# 得到最短距离
closet_i = closest_center(sample,centers)
d = get_distance(X[closet_i],sample)
dists[sample_i] = d
total += d
total = total * np.random.random()
for sample_i,d in enumerate(dists): # 轮盘法选出下一个聚类中心
total -= d
if total > 0:
continue
# 选取新的中心点
centers.append(X[sample_i])
break
return centers
X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])
# 设定聚类类别为2个,最大迭代次数为10次
labels = Mykmeans(X, k = 2, max_iterations = 10,init = 'kmeans++')
print("最后分类结果",labels)
## 输出为 [1. 1. 1. 0. 0.]
# 使用sklearn验证
X = np.array([[0,2],[0,0],[1,0],[5,0],[5,2]])
kmeans = KMeans(n_clusters=2,init='k-means++').fit(X)
print(kmeans.labels_)
参考文档
K-means与K-means++
K-means原理、优化及应用
来源:https://blog.csdn.net/qq_37555071/article/details/107599459
猜你喜欢
- 本来想把之前对artTemplate源码解析的注释放上来分享下,不过隔了一年,找不到了,只好把当时分析模板引擎原理后,自己尝试写下的模板引擎
- 本文实例讲述了JS实现跟随鼠标闪烁转动色块的方法。分享给大家供大家参考。具体实现方法如下:<html><head>&
- Get Started Tutorial for Python in Visual Studio Code一、安装PythonPython简
- python中,遍历dict的方法有四种。但这四种遍历的性能如何呢?我做了如下的测试l = [(x,x) for x in xrange(1
- 利用mask(掩模)技术提取纯色背景图像ROI区域中的人和物,并将提取出来的人或物添加在其他图像上。1、实现原理先通过cv.cvtColor
- python中查找指定的字符串的方法如下:code#查询def selStr(): sStr1 = 'jsjtt.com
- 今年年初,新一季的《最强大脑》开播了,第一集选拔的时候大家做了一个数字游戏,名叫《数字华容道》,当时何猷君以二十几秒的成绩夺得该项目的冠军,
- 1 squeeze(): 去除size为1的维度,包括行和列。至于维度大于等于2时,squeeze()不起作用。行、例:>>&g
- while循环只要循环条件为True(以下例子为x > y),while循环就会一直 执行下去:u, v, x, y = 0, 0,
- 设置cookie每个cookie都是一个名/值对,可以把下面这样一个字符串赋值给document.cookie:document.cooki
- 一种很常见的写法: document.write('<scr'+'ipt src=&quo
- 本文实例讲述了Python Datetime模块和Calendar模块用法。分享给大家供大家参考,具体如下:datetime模块1.1 概述
- 今天看视频学习时学习了一种新技术,即平时我们在一个页面点击“提交”或“确认”会自动跳转到一个页面。 在网上搜了一下,关于这个技术处理有多种方
- if语句>>通用格式if语句一般形式如下:if <test1>: <statements1>elif &
- 一、前言恭喜你,学明白类,你已经学会所有基本知识了。这章算是一个娱乐篇,十分简单,了解一下pyautogui模块,这算是比较好学还趣味性十足
- 日常维护中,经常会碰到线程被阻塞,导致数据库响应非常慢,下面就看看如何获取是哪个线程导致了阻塞的。1. 环境说明RHEL 6.4 x86_6
- 跟小组里一自称小方方的卖萌90小青年聊天,IT男的坏习惯,聊着聊着就扯到技术上去了,小方方突然问 1、声明一个数值类型的变量我看到三种,区别
- //冒泡排序func mpSort(array []int) { for i:=0;i<len(array);i++ {
- 一、变量的定义程序中,数据都是临时存储在内存中,为了更快速的查找或使用这个数据,通常我们把这个数据在内存中存储之后,给整个数据定义一个名称,
- #!/usr/bin/python# -*- coding: utf-8 -*-from scapy.all import *from ti