Python实现的KMeans聚类算法实例分析
作者:njulpy 发布时间:2021-06-28 04:44:49
本文实例讲述了Python实现的KMeans聚类算法。分享给大家供大家参考,具体如下:
菜鸟一枚,编程初学者,最近想使用Python3实现几个简单的机器学习分析方法,记录一下自己的学习过程。
关于KMeans算法本身就不做介绍了,下面记录一下自己遇到的问题。
一 、关于初始聚类中心的选取
初始聚类中心的选择一般有:
(1)随机选取
(2)随机选取样本中一个点作为中心点,在通过这个点选取距离其较大的点作为第二个中心点,以此类推。
(3)使用层次聚类等算法更新出初始聚类中心
我一开始是使用numpy随机产生k个聚类中心
Center = np.random.randn(k,n)
但是发现聚类的时候迭代几次以后聚类中心会出现nan,有点搞不清楚怎么回事
所以我分别尝试了:
(1)选择数据集的前K个样本做初始中心点
(2)选择随机K个样本点作为初始聚类中心
发现两者都可以完成聚类,我是用的是iris.csv数据集,在选择前K个样本点做数据集时,迭代次数是固定的,选择随机K个点时,迭代次数和随机种子的选取有关,而且聚类效果也不同,有的随机种子聚类快且好,有的慢且差。
def InitCenter(k,m,x_train):
#Center = np.random.randn(k,n)
#Center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心
Center = np.zeros([k,n]) #从样本中随机取k个点做初始聚类中心
np.random.seed(5) #设置随机数种子
for i in range(k):
x = np.random.randint(m)
Center[i] = np.array(x_train.iloc[x])
return Center
二 、关于类间距离的选取
为了简单,我直接采用了欧氏距离,目前还没有尝试其他的距离算法。
def GetDistense(x_train, k, m, Center):
Distence=[]
for j in range(k):
for i in range(m):
x = np.array(x_train.iloc[i, :])
a = x.T - Center[j]
Dist = np.sqrt(np.sum(np.square(a))) # dist = np.linalg.norm(x.T - Center)
Distence.append(Dist)
Dis_array = np.array(Distence).reshape(k,m)
return Dis_array
三 、关于终止聚类条件的选取
关于聚类的终止条件有很多选择方法:
(1)迭代一定次数
(2)聚类中心的更新小于某个给定的阈值
(3)类中的样本不再变化
我用的是前两种方法,第一种很简单,但是聚类效果不好控制,针对不同数据集,稳健性也不够。第二种比较合适,稳健性也强。第三种方法我还没有尝试,以后可以试着用一下,可能聚类精度会更高一点。
def KMcluster(x_train,k,n,m,threshold):
global axis_x, axis_y
center = InitCenter(k,m,x_train)
initcenter = center
centerChanged = True
t=0
while centerChanged:
Dis_array = GetDistense(x_train, k, m, center)
center ,axis_x,axis_y,axis_z= GetNewCenter(x_train,k,n,Dis_array)
err = np.linalg.norm(initcenter[-k:] - center)
print(err)
t+=1
plt.figure(1)
p=plt.subplot(3, 3, t)
p1,p2,p3 = plt.scatter(axis_x[0], axis_y[0], c='r'),plt.scatter(axis_x[1], axis_y[1], c='g'),plt.scatter(axis_x[2], axis_y[2], c='b')
plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
p.set_title('Iteration'+ str(t))
if err < threshold:
centerChanged = False
else:
initcenter = np.concatenate((initcenter, center), axis=0)
plt.show()
return center, axis_x, axis_y,axis_z, initcenter
err是本次聚类中心点和上次聚类中心点之间的欧氏距离。
threshold是人为设定的终止聚类的阈值,我个人一般设置为0.1或者0.01。
为了将每次迭代产生的类别显示出来我修改了上述代码,使用matplotlib展示每次迭代的散点图。
下面附上我测试数据时的图,子图设置的个数要根据迭代次数来定。
我测试了几个数据集,聚类的精度还是可以的。
使用iris数据集分析的结果为:
err of Iteration 1 is 3.11443180281
err of Iteration 2 is 1.27568813621
err of Iteration 3 is 0.198909381512
err of Iteration 4 is 0.0
Final cluster center is [[ 6.85 3.07368421 5.74210526 2.07105263]
[ 5.9016129 2.7483871 4.39354839 1.43387097]
[ 5.006 3.428 1.462 0.246 ]]
最后附上全部代码,错误之处还请多多批评,谢谢。
#encoding:utf-8
"""
Author: njulpy
Version: 1.0
Data: 2018/04/11
Project: Using Python to Implement KMeans Clustering Algorithm
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import KMeans
def InitCenter(k,m,x_train):
#Center = np.random.randn(k,n)
#Center = np.array(x_train.iloc[0:k,:]) #取数据集中前k个点作为初始中心
Center = np.zeros([k,n]) #从样本中随机取k个点做初始聚类中心
np.random.seed(15) #设置随机数种子
for i in range(k):
x = np.random.randint(m)
Center[i] = np.array(x_train.iloc[x])
return Center
def GetDistense(x_train, k, m, Center):
Distence=[]
for j in range(k):
for i in range(m):
x = np.array(x_train.iloc[i, :])
a = x.T - Center[j]
Dist = np.sqrt(np.sum(np.square(a))) # dist = np.linalg.norm(x.T - Center)
Distence.append(Dist)
Dis_array = np.array(Distence).reshape(k,m)
return Dis_array
def GetNewCenter(x_train,k,n, Dis_array):
cen = []
axisx ,axisy,axisz= [],[],[]
cls = np.argmin(Dis_array, axis=0)
for i in range(k):
train_i=x_train.loc[cls == i]
xx,yy,zz = list(train_i.iloc[:,1]),list(train_i.iloc[:,2]),list(train_i.iloc[:,3])
axisx.append(xx)
axisy.append(yy)
axisz.append(zz)
meanC = np.mean(train_i,axis=0)
cen.append(meanC)
newcent = np.array(cen).reshape(k,n)
NewCent=np.nan_to_num(newcent)
return NewCent,axisx,axisy,axisz
def KMcluster(x_train,k,n,m,threshold):
global axis_x, axis_y
center = InitCenter(k,m,x_train)
initcenter = center
centerChanged = True
t=0
while centerChanged:
Dis_array = GetDistense(x_train, k, m, center)
center ,axis_x,axis_y,axis_z= GetNewCenter(x_train,k,n,Dis_array)
err = np.linalg.norm(initcenter[-k:] - center)
t+=1
print('err of Iteration '+str(t),'is',err)
plt.figure(1)
p=plt.subplot(2, 3, t)
p1,p2,p3 = plt.scatter(axis_x[0], axis_y[0], c='r'),plt.scatter(axis_x[1], axis_y[1], c='g'),plt.scatter(axis_x[2], axis_y[2], c='b')
plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
p.set_title('Iteration'+ str(t))
if err < threshold:
centerChanged = False
else:
initcenter = np.concatenate((initcenter, center), axis=0)
plt.show()
return center, axis_x, axis_y,axis_z, initcenter
if __name__=="__main__":
#x=pd.read_csv("8.Advertising.csv") # 两组测试数据
#x=pd.read_table("14.bipartition.txt")
x=pd.read_csv("iris.csv")
x_train=x.iloc[:,1:5]
m,n = np.shape(x_train)
k = 3
threshold = 0.1
km,ax,ay,az,ddd = KMcluster(x_train, k, n, m, threshold)
print('Final cluster center is ', km)
#2-Dplot
plt.figure(2)
plt.scatter(km[0,1],km[0,2],c = 'r',s = 550,marker='x')
plt.scatter(km[1,1],km[1,2],c = 'g',s = 550,marker='x')
plt.scatter(km[2,1],km[2,2],c = 'b',s = 550,marker='x')
p1, p2, p3 = plt.scatter(axis_x[0], axis_y[0], c='r'), plt.scatter(axis_x[1], axis_y[1], c='g'), plt.scatter(axis_x[2], axis_y[2], c='b')
plt.legend(handles=[p1, p2, p3], labels=['0', '1', '2'], loc='best')
plt.title('2-D scatter')
plt.show()
#3-Dplot
plt.figure(3)
TreeD = plt.subplot(111, projection='3d')
TreeD.scatter(ax[0],ay[0],az[0],c='r')
TreeD.scatter(ax[1],ay[1],az[1],c='g')
TreeD.scatter(ax[2],ay[2],az[2],c='b')
TreeD.set_zlabel('Z') # 坐标轴
TreeD.set_ylabel('Y')
TreeD.set_xlabel('X')
TreeD.set_title('3-D scatter')
plt.show()
附:上述示例中的iris.csv文件点击此处本站下载。
希望本文所述对大家Python程序设计有所帮助。
来源:https://blog.csdn.net/njulpy/article/details/79895750


猜你喜欢
- pyinstaller打包引入自己编写的库场景使用pyinstaller打包某个文件后,生成的exe文件无法运行,提示 Module not
- 1. 动态属性名:可使用表达式来设置动态属性名或方法名:<!-- 属性name --><a :[name]=&
- 原因:使用git clone项目后,项目根路径是小写英文名称,比如cmdbapi,但是项目里面的import导入自己的相关包时,红色报错解决
- 在Qt中经常需要使用样式,为了降低耦合性(与逻辑代码分离),我们通常会定义一个QSS文件,然后编写各种控件(QLabel,QLIneEdit
- 迭代器迭代器是一个实现了迭代器协议的对象,Python中的迭代器协议就是有next方法的对象会前进到下一结果,而在一系列结果的末尾是,则会引
- 原来看到这个代码的时候,觉得功能不错,调试并整理了一下也许大家用的到,对于越来越注重用户体验的今天,这种功能一定很受欢迎,当然本例子只是一个
- 项目使用Pyqt作为UI框架,使用相机线程捕捉image,并在QGraphicsView中显示,遇到以下问题:1、采集的数据为nparray
- 关于代码调试的技巧,我之前写过很多的文章,关注比较早的同学,也应该都有看过。还没看过的同学,欢迎前往查阅:更多Python代码调试技巧其中有
- 这个是用vue-cli生成的项目下使用比如有个路由跳转时需要带两个参数:<router-link to='/tr'&g
- 一、写在前面从学 Python 的第一天起,我就知道了使用 pip 命令来安装包,从学习爬虫到学习 Web 开发,安装的库越来越多,从 re
- 前 言在开发高并发系统时,我们可能会遇到接口访问频次过高,为了保证系统的高可用和稳定性,这时候就需要做流量限制,你可能是用的 Ng
- 简介在这篇文章中我将介绍如何写一个简短(200行)的 Python 脚本,来自动地将一幅图片的脸替换为另一幅图片的脸。这个过程分四步:检测脸
- 前言ctypes是Python的外部函数库。它提供了C兼容的数据类型,并且允许调用动态链接库/共享库中的函数。它可以将这些库包装起来给Pyt
- 如果你从未为MySQL设置根用户密码,服务器在以根用户身份进行连接时不需要密码。但是,建议你为每个账户设置密码。如果你以前设置了根用户密码,
- 本文实例讲述了JavaScript使用indexOf()实现数组去重的方法。分享给大家供大家参考,具体如下:数组去重方法有多中,这里列举出自
- 1、删除目录及目录下所有的文件2、删除目录下的所有文件但目录结构保留3、删除指定文件代码如下/** +-------------------
- 看代码吧~# -*- coding:utf-8 -*- import osimport jsonimport numpy as np #fr
- Problem:Solution:参考stackoverflow给出的解决方案:https://stackoverflow.com/ques
- 前言随着微信的普及,越来越多的人开始使用微信。微信渐渐从一款单纯的社交软件转变成了一个生活方式,人们的日常沟通需要微信,工作交流也需要微信。
- 最近小爬一直思忖着如何将以前写的一些半自动化程序转为全自动化,这其中就涉及到SAP的打开和登录过程。我们都知道,SAP原生的&ldq