闲来无事,自己琢磨了一下 KMeans 算法,在此记录。原理网上资料很多,不再赘述。
# -*- coding: utf-8 -*-
"""
Created on Wed May 16 23:02:51 2018
@author: mz
"""
import math
import random
from sklearn import datasets
import numpy as np
import copy
class Cluster(object):
    """Base class for clustering algorithms, providing shared distance helpers."""

    def center(self, data, minPoints, k, n):
        """Cluster ``data``; meant to be overridden. The base version does nothing."""
        return None

    def distance(self, source: list, center: list) -> float:
        """Return the Euclidean distance between two points of any dimension.

        Coordinates are paired with ``zip``; if the points differ in length,
        the extra coordinates of the longer one are ignored.
        """
        squares = (math.pow(a - b, 2) for a, b in zip(source, center))
        return math.sqrt(sum(squares, 0.0))

    def minDistance(self, point, centroids):
        """Locate the centroid closest to ``point``.

        Returns a ``(distance, index)`` pair. On ties the lowest index wins.
        When ``centroids`` is empty (or no distance compares below infinity)
        the result is ``(math.inf, -1)``.
        """
        best_dist, best_idx = math.inf, -1
        for idx, centroid in enumerate(centroids):
            candidate = self.distance(centroid, point)
            if candidate < best_dist:
                best_dist, best_idx = candidate, idx
        return best_dist, best_idx
class KMeans(Cluster):
    """K-means clustering (Lloyd's algorithm) with k-means++ seeding.

    Arguments of ``center``:
        data      : the points to cluster, e.g. an ``(m, d)`` numpy array or a
                    list of equal-length numeric sequences.
        minPoints : intended minimum number of points per cluster; it is not
                    used by this implementation and is kept only so existing
                    callers keep working.
        k         : the number of centroids (clusters) to find.
        n         : the maximum number of update iterations.

    Spreading the initial centroids apart (k-means++) generally gives better
    clusterings than purely random initial centroids.
    """

    def randomCenter(self, data, k):
        """Choose ``k`` initial centroids from ``data`` using k-means++ seeding.

        The first centroid is drawn uniformly at random. Each further centroid
        is drawn by roulette-wheel selection with probability proportional to
        the squared distance from a point to its nearest already-chosen
        centroid, so far-away points are favoured.

        Returns a list of ``k`` elements of ``data`` (for a numpy array these
        are row views into ``data``, not copies); an empty list if ``k`` < 1.
        Raises IndexError (from ``random.choice``) if ``data`` is empty.
        """
        if k < 1:
            return []
        centers = [random.choice(data)]
        for _ in range(1, k):
            # Squared distance of every point to its nearest chosen centroid.
            # Only real centroids are considered (no placeholder entries).
            weights = [self.minDistance(point, centers)[0] ** 2 for point in data]
            threshold = sum(weights) * random.random()
            # Fallback in case float rounding leaves the threshold slightly
            # above zero after every weight has been subtracted.
            chosen = data[-1]
            for point, weight in zip(data, weights):
                threshold -= weight
                if threshold <= 0:
                    chosen = point
                    break
            centers.append(chosen)
        return centers

    def center(self, data, minPoints, k, n, tol=0.0001):
        """Cluster ``data`` into ``k`` groups.

        Args:
            data: points to cluster (see the class docstring).
            minPoints: unused; kept for interface compatibility.
            k: number of clusters.
            n: maximum number of update iterations (0 returns the seeds and
                their assignment).
            tol: stop early once the mean centroid movement within one
                iteration drops below this value (default 1e-4).

        Returns:
            ``(centers, clusters)``. ``centers[i]`` is a float centroid (a
            float64 numpy array when the seeds are numpy arrays, otherwise a
            list of floats). ``clusters[i]`` lists the points of ``data``
            assigned to ``centers[i]`` in the most recent assignment step. A
            centroid whose cluster is empty keeps its previous position.

        Raises:
            ValueError: if ``data`` is empty or ``k`` < 1.
        """
        if len(data) == 0:
            raise ValueError("data must contain at least one point")
        if k < 1:
            raise ValueError("k must be at least 1")
        # Work on float copies: integer data must not truncate the means, and
        # the caller's array must never be modified through row views.
        centers = [self._asFloatPoint(c) for c in self.randomCenter(data, k)]
        clusters = self._assignPoints(data, centers, k)
        for _ in range(n):
            newCenters = [self._meanPoint(group, old)
                          for group, old in zip(clusters, centers)]
            # Mean Euclidean shift of the centroids in this iteration.
            shift = sum(self.distance(old, new)
                        for old, new in zip(centers, newCenters)) / k
            centers = newCenters
            if shift < tol:
                break
            clusters = self._assignPoints(data, centers, k)
        return centers, clusters

    def _asFloatPoint(self, point):
        """Return a float copy of ``point``: float64 ndarray for ndarray input, else a list."""
        if isinstance(point, np.ndarray):
            return np.array(point, dtype=float)
        return [float(v) for v in point]

    def _assignPoints(self, data, centers, k):
        """Group the points of ``data`` by the index of their nearest centroid."""
        clusters = [[] for _ in range(k)]
        for point in data:
            _, index = self.minDistance(point, centers)
            clusters[index].append(point)
        return clusters

    def _meanPoint(self, group, previous):
        """Return the coordinate-wise mean of ``group``, or ``previous`` if it is empty."""
        if not group:
            return previous
        means = [float(sum(column)) / len(group) for column in zip(*group)]
        if isinstance(previous, np.ndarray):
            return np.array(means, dtype=float)
        return means
if __name__ == "__main__":
    # Demo: cluster the iris dataset into k=3 groups (minPoints=10 is passed
    # through, at most 500 iterations) and print each centroid and group.
    iris = datasets.load_iris()
    model = KMeans()
    centers, clusters = model.center(iris.data, 10, 3, 500)
    print("centers = ", centers)
    labels = ("cluster[0] := ", "cluster[1] := ", "cluster[2] := ")
    for label, members in zip(labels, clusters):
        print(label, members, "length := ", len(members))
运行结果:
centers =
[array([ 6.85384615, 3.07692308, 5.71538462, 2.05384615]),
array([ 5.88360656, 2.74098361, 4.38852459, 1.43442623]),
array([ 5.006, 3.418, 1.464, 0.244])]
cluster[0] := [array([ 7. , 3.2, 4.7, 1.4]), array([ 6.9, 3.1, 4.9, 1.5]), array([ 6.7, 3. , 5. , 1.7]), array([ 6.3, 3.3, 6. , 2.5]), array([ 7.1, 3. , 5.9, 2.1]), array([ 6.3, 2.9, 5.6, 1.8]), array([ 6.5, 3. , 5.8, 2.2]), array([ 7.6, 3. , 6.6, 2.1]), array([ 7.3, 2.9, 6.3, 1.8]), array([ 6.7, 2.5, 5.8, 1.8]), array([ 7.2, 3.6, 6.1, 2.5]), array([ 6.5, 3.2, 5.1, 2. ]), array([ 6.4, 2.7, 5.3, 1.9]), array([ 6.8, 3. , 5.5, 2.1]), array([ 6.4, 3.2, 5.3, 2.3]), array([ 6.5, 3. , 5.5, 1.8]), array([ 7.7, 3.8, 6.7, 2.2]), array([ 7.7, 2.6, 6.9, 2.3]), array([ 6.9, 3.2, 5.7, 2.3]), array([ 7.7, 2.8, 6.7, 2. ]), array([ 6.7, 3.3, 5.7, 2.1]), array([ 7.2, 3.2, 6. , 1.8]), array([ 6.4, 2.8, 5.6, 2.1]), array([ 7.2, 3. , 5.8, 1.6]), array([ 7.4, 2.8, 6.1, 1.9]), array([ 7.9, 3.8, 6.4, 2. ]), array([ 6.4, 2.8, 5.6, 2.2]), array([ 6.1, 2.6, 5.6, 1.4]), array([ 7.7, 3. , 6.1, 2.3]), array([ 6.3, 3.4, 5.6, 2.4]), array([ 6.4, 3.1, 5.5, 1.8]), array([ 6.9, 3.1, 5.4, 2.1]), array([ 6.7, 3.1, 5.6, 2.4]), array([ 6.9, 3.1, 5.1, 2.3]), array([ 6.8, 3.2, 5.9, 2.3]), array([ 6.7, 3.3, 5.7, 2.5]), array([ 6.7, 3. , 5.2, 2.3]), array([ 6.5, 3. , 5.2, 2. ]), array([ 6.2, 3.4, 5.4, 2.3])] length := 39
cluster[1] := [array([ 6.4, 3.2, 4.5, 1.5]), array([ 5.5, 2.3, 4. , 1.3]), array([ 6.5, 2.8, 4.6, 1.5]), array([ 5.7, 2.8, 4.5, 1.3]), array([ 6.3, 3.3, 4.7, 1.6]), array([ 4.9, 2.4, 3.3, 1. ]), array([ 6.6, 2.9, 4.6, 1.3]), array([ 5.2, 2.7, 3.9, 1.4]), array([ 5. , 2. , 3.5, 1. ]), array([ 5.9, 3. , 4.2, 1.5]), array([ 6. , 2.2, 4. , 1. ]), array([ 6.1, 2.9, 4.7, 1.4]), array([ 5.6, 2.9, 3.6, 1.3]), array([ 6.7, 3.1, 4.4, 1.4]), array([ 5.6, 3. , 4.5, 1.5]), array([ 5.8, 2.7, 4.1, 1. ]), array([ 6.2, 2.2, 4.5, 1.5]), array([ 5.6, 2.5, 3.9, 1.1]), array([ 5.9, 3.2, 4.8, 1.8]), array([ 6.1, 2.8, 4. , 1.3]), array([ 6.3, 2.5, 4.9, 1.5]), array([ 6.1, 2.8, 4.7, 1.2]), array([ 6.4, 2.9, 4.3, 1.3]), array([ 6.6, 3. , 4.4, 1.4]), array([ 6.8, 2.8, 4.8, 1.4]), array([ 6. , 2.9, 4.5, 1.5]), array([ 5.7, 2.6, 3.5, 1. ]), array([ 5.5, 2.4, 3.8, 1.1]), array([ 5.5, 2.4, 3.7, 1. ]), array([ 5.8, 2.7, 3.9, 1.2]), array([ 6. , 2.7, 5.1, 1.6]), array([ 5.4, 3. , 4.5, 1.5]), array([ 6. , 3.4, 4.5, 1.6]), array([ 6.7, 3.1, 4.7, 1.5]), array([ 6.3, 2.3, 4.4, 1.3]), array([ 5.6, 3. , 4.1, 1.3]), array([ 5.5, 2.5, 4. , 1.3]), array([ 5.5, 2.6, 4.4, 1.2]), array([ 6.1, 3. , 4.6, 1.4]), array([ 5.8, 2.6, 4. , 1.2]), array([ 5. , 2.3, 3.3, 1. ]), array([ 5.6, 2.7, 4.2, 1.3]), array([ 5.7, 3. , 4.2, 1.2]), array([ 5.7, 2.9, 4.2, 1.3]), array([ 6.2, 2.9, 4.3, 1.3]), array([ 5.1, 2.5, 3. , 1.1]), array([ 5.7, 2.8, 4.1, 1.3]), array([ 5.8, 2.7, 5.1, 1.9]), array([ 4.9, 2.5, 4.5, 1.7]), array([ 5.7, 2.5, 5. , 2. ]), array([ 5.8, 2.8, 5.1, 2.4]), array([ 6. , 2.2, 5. , 1.5]), array([ 5.6, 2.8, 4.9, 2. ]), array([ 6.3, 2.7, 4.9, 1.8]), array([ 6.2, 2.8, 4.8, 1.8]), array([ 6.1, 3. , 4.9, 1.8]), array([ 6.3, 2.8, 5.1, 1.5]), array([ 6. , 3. , 4.8, 1.8]), array([ 5.8, 2.7, 5.1, 1.9]), array([ 6.3, 2.5, 5. , 1.9]), array([ 5.9, 3. , 5.1, 1.8])] length := 61
cluster[2] := [array([ 5.1, 3.5, 1.4, 0.2]), array([ 4.9, 3. , 1.4, 0.2]), array([ 4.7, 3.2, 1.3, 0.2]), array([ 4.6, 3.1, 1.5, 0.2]), array([ 5. , 3.6, 1.4, 0.2]), array([ 5.4, 3.9, 1.7, 0.4]), array([ 4.6, 3.4, 1.4, 0.3]), array([ 5. , 3.4, 1.5, 0.2]), array([ 4.4, 2.9, 1.4, 0.2]), array([ 4.9, 3.1, 1.5, 0.1]), array([ 5.4, 3.7, 1.5, 0.2]), array([ 4.8, 3.4, 1.6, 0.2]), array([ 4.8, 3. , 1.4, 0.1]), array([ 4.3, 3. , 1.1, 0.1]), array([ 5.8, 4. , 1.2, 0.2]), array([ 5.7, 4.4, 1.5, 0.4]), array([ 5.4, 3.9, 1.3, 0.4]), array([ 5.1, 3.5, 1.4, 0.3]), array([ 5.7, 3.8, 1.7, 0.3]), array([ 5.1, 3.8, 1.5, 0.3]), array([ 5.4, 3.4, 1.7, 0.2]), array([ 5.1, 3.7, 1.5, 0.4]), array([ 4.6, 3.6, 1. , 0.2]), array([ 5.1, 3.3, 1.7, 0.5]), array([ 4.8, 3.4, 1.9, 0.2]), array([ 5. , 3. , 1.6, 0.2]), array([ 5. , 3.4, 1.6, 0.4]), array([ 5.2, 3.5, 1.5, 0.2]), array([ 5.2, 3.4, 1.4, 0.2]), array([ 4.7, 3.2, 1.6, 0.2]), array([ 4.8, 3.1, 1.6, 0.2]), array([ 5.4, 3.4, 1.5, 0.4]), array([ 5.2, 4.1, 1.5, 0.1]), array([ 5.5, 4.2, 1.4, 0.2]), array([ 4.9, 3.1, 1.5, 0.1]), array([ 5. , 3.2, 1.2, 0.2]), array([ 5.5, 3.5, 1.3, 0.2]), array([ 4.9, 3.1, 1.5, 0.1]), array([ 4.4, 3. , 1.3, 0.2]), array([ 5.1, 3.4, 1.5, 0.2]), array([ 5. , 3.5, 1.3, 0.3]), array([ 4.5, 2.3, 1.3, 0.3]), array([ 4.4, 3.2, 1.3, 0.2]), array([ 5. , 3.5, 1.6, 0.6]), array([ 5.1, 3.8, 1.9, 0.4]), array([ 4.8, 3. , 1.4, 0.3]), array([ 5.1, 3.8, 1.6, 0.2]), array([ 4.6, 3.2, 1.4, 0.2]), array([ 5.3, 3.7, 1.5, 0.2]), array([ 5. , 3.3, 1.4, 0.2])] length := 50