本文主要分析了K-means聚类算法的基本原理,时间复杂度以及优缺点,最后用UCI数据集进行了测试,包含java实现代码,适合初学者参考。
一.算法原理
输入:聚类个数k,以及包含 n个数据对象的数据库。
输出:满足方差最小标准的k个聚类。
处理流程:
(1)从 n个数据对象任意选择 k 个对象作为初始聚类中心,即中心点。
(2)根据每个聚类的中心点,计算每个对象与这些中心点的距离;并根据最小距离重新对相应对象进行划分;
(3)重新计算每个(有变化)聚类的均值(中心点)
(4)循环(2)到(3)直到每个聚类不再发生变化为止
二.复杂度
时间复杂度:O(kntd),其中,t为迭代次数,K为簇的数目,n为数据数,d为维数
空间复杂度:O((n+K)d),其中,K为簇的数目,n为数据数,d为维数
三.算法优缺点
优点:
1.时间复杂度低,速度快。
2.对于大规模数据集,该算法是相对可扩展的,并且效率较高。
缺点:
1.必须事先给出要生成的簇数k。
2.不适合发现非凸面形状的簇和大小差别很大的簇。
3.对噪声和离群点敏感。
4.只适用于数值型数据。
5.初始点随机选取,可能导致终止于局部最优解。
四.java实现
实验数据用的是UCI上面的iris数据集,数据可以从UCI官网上面下载,下载下来是txt文件,可以自行百度将其导入到mysql数据库中。我的数据库中的字段信息如下图所示:(大家可以自己定义,原理看得明白就行)接下来直接贴实现代码。代码结构如下:
1.Point类,主要是对应数据库中字段的模型类
package Kmeans;
//模型类,对应数据库中的属性
public class Point
{
//定义iris数据集的四个属性
private double x;
private double y;
private double z;
private double w;
public double getX()
{
return x;
}
public void setX(double x)
{
this.x=x;
}
public double getY()
{
return y;
}
public void setY(double y)
{
this.y=y;
}
public double getZ()
{
return z;
}
public void setZ(double z)
{
this.z=z;
}
public double getW()
{
return w;
}
public void setW(double w)
{
this.w=w;
}
public Point()
{
}
public Point(double x,double y,double z,double w)
{
super();
this.x=x;
this.y=y;
this.z=z;
this.w=w;
}
public String toString()
{
return "Point [x="+x+",y="+y+",z="+z+",w="+w+"]";
}
//重写equals方法和hashCode方法,因为后面需要用到HashMap的containsKey(point)方法,而Point类作为其中的key参数
/*@Override
public boolean equals(Object obj)
{
Point p=(Point)obj;
if(this.getX()==p.getX()&&this.getY()==p.getY()&&this.getZ()==p.getZ()&&this.getW()==p.getW())
{
return true;
}
return false;
}
@Override
public int hashCode()
{
return (int)(this.x+this.y+this.z+this.w);
}*/
@Override
public int hashCode()
{
final int prime = 31;
int result = 1;
long temp;
temp = Double.doubleToLongBits(w);
result = prime * result + (int) (temp ^ (temp >>> 32));
temp = Double.doubleToLongBits(x);
result = prime * result + (int) (temp ^ (temp >>> 32));
temp = Double.doubleToLongBits(y);
result = prime * result + (int) (temp ^ (temp >>> 32));
temp = Double.doubleToLongBits(z);
result = prime * result + (int) (temp ^ (temp >>> 32));
return result;
}
@Override
public boolean equals(Object obj)
{
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Point other = (Point) obj;
if (Double.doubleToLongBits(w) != Double.doubleToLongBits(other.w))
return false;
if (Double.doubleToLongBits(x) != Double.doubleToLongBits(other.x))
return false;
if (Double.doubleToLongBits(y) != Double.doubleToLongBits(other.y))
return false;
if (Double.doubleToLongBits(z) != Double.doubleToLongBits(other.z))
return false;
return true;
}
}
ps:hashcode方法和equals方法必须重写,否则程序会出问题。
2.SelectData类,主要实现从数据库中读取数据到Arraylist中。
package Kmeans;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;
//从数据库中读取数据的类
public class SelectData
{
private Connection con;
private PreparedStatement ps;
private ResultSet rs;
/*
* 从数据库中取数据存放到ArrayList中
*/
public ArrayList<Point> getPoints()
{
//定义存放数据的列表
ArrayList<Point> points=new ArrayList<Point>();
try
{
//连接数据库代码,先要加载mysql驱动
Class.forName("com.mysql.jdbc.Driver").newInstance();
con = DriverManager.getConnection("jdbc:mysql://localhost:3306/uci_dataset","root","asdzxc123");
String sql="select sepal_length,sepal_width,petal_length,petal_width from iris";
ps = con.prepareStatement(sql);
rs = ps.executeQuery();
while(rs.next())
{
Point p=new Point();
p.setX(rs.getDouble("sepal_length"));
p.setY(rs.getDouble("sepal_width"));
p.setZ(rs.getDouble("petal_length"));
p.setW(rs.getDouble("petal_width"));
points.add(p);
//System.out.println("数据集为: "+p);
}
/*for(Point pp:points)
{
System.out.println(pp);
}*/
//System.out.println("ArrayList数据集: "+points);
rs.close();
ps.close();
con.close();
}
catch (Exception e)
{
e.printStackTrace();
System.out.println("数据库连接失败");
}
return points;
}
/*public static void main(String[] args)
{
new SelectData().getPoints();
}*/
}
3.ManagePoint类,主要实现对中心点的更新等操作,具体函数看代码。
package Kmeans;
import java.util.ArrayList;
import java.util.Map;
public class ManagePoint
{
/**
* 计算对象点到中心点之间的距离
* @param p 对象点
* @param q 中心点
* @return 两点之间的距离
*/
public double getDistance(Point p,Point q)
{
double dx=p.getX()-q.getX();
double dy=p.getY()-q.getY();
double dz=p.getZ()-q.getZ();
double dw=p.getW()-q.getW();
double dist=dx*dx+dy*dy+dz*dz+dw*dw;
return dist;
}
/**
* 判断新的中心点是否和前一轮旧的中心的相同
* @param lastCenterCluster旧的
* @param nowCenterCluster新的
* @return 相同则返回true,否则返回false
*/
public boolean isEqual(Map<Point,ArrayList<Point>> lastCenterCluster,Map<Point,ArrayList<Point>> nowCenterCluster,int k)
{
boolean flag;
int i=0;
if(lastCenterCluster==null)
{
//System.out.println("11111111");
return false;
}
else
{
for(Point point:nowCenterCluster.keySet())
{
//System.out.println("222222");
flag=lastCenterCluster.containsKey(point);
if(flag)
{
i++;
}
}
if(i==k) return true;
}
//System.out.println("333333");
return false;
}
/**
* 计算新的中心点
* @param value HashMap中的value,为一个ArrayList
* @return 返回新的中心点
*/
public Point getNewCenter(ArrayList<Point> value)
{
double sumX=0,sumY=0,sumZ=0,sumW=0;
for(Point point:value)
{
sumX+=point.getX();
sumY+=point.getY();
sumZ+=point.getZ();
sumW+=point.getW();
}
System.out.println("新的中心: ("+sumX/value.size()+","+sumY/value.size()+","+sumZ/value.size()+","+sumW/value.size()+")");
Point point=new Point();
point.setX(sumX/value.size());
point.setY(sumY/value.size());
point.setZ(sumZ/value.size());
point.setW(sumW/value.size());
return point;
}
}
4.KmeansMain类,实现聚类。
package Kmeans;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
public class KmeansMain
{
public Map<Point,ArrayList<Point>> executeKmeans(int k)
{
ArrayList<Point> dataList=new ArrayList<Point>();//存放从SelectData类中获取的数据库中的源数据
Map<Point,ArrayList<Point>> nowCenterClusterMap=new HashMap<Point,ArrayList<Point>>();//当前中心及其簇内的点
Map<Point,ArrayList<Point>> lastCenterClusterMap=null;//上一个中心及其簇内所有点
dataList=new SelectData().getPoints();
// 随机创建K个点作为起始中心
Random rd=new Random();
System.out.println("起始中心下标: ");
for(int i=0;i<k;i++)
{
int index=rd.nextInt(150);//public int nextInt(int n)该方法的作用是生成一个随机的int值,该值介于[0,n)的区间,也就是0到n之间的随机int值,包含0而不包含n。
System.out.println("第"+(i+1)+"个随机生成的中心 :"+index);
nowCenterClusterMap.put(dataList.get(index),new ArrayList<Point>());
}
// 输出起始中心
System.out.println("起始中心: ");
for(Point point:nowCenterClusterMap.keySet())
{
System.out.println("key: "+point);
}
// 将数据点point加入配到离其最近的map的value中
ManagePoint managePoint=new ManagePoint();
while(true)
{
for(Point point:dataList)
{
double shortestDistance = Double.MAX_VALUE;// 初始化最短距离为Double的最大值
Point key = null;
for (Entry<Point,ArrayList<Point>> entry:nowCenterClusterMap.entrySet())
{
// 计算中心与各点间的距离
double distance=managePoint.getDistance(entry.getKey(),point);
if(distance<shortestDistance)
{
shortestDistance=distance;
key=entry.getKey();
}
}
nowCenterClusterMap.get(key).add(point);
}
//如果这个判断放到上面while之后,那么return的值变为lastclustermap即可,因为在每次更新中心之后,nowclustermap里面只有key,没有value,只有执行循环之后才有value,才可以返回
// 如果新的中心与上次的中心相等,则退出整个循环
if (managePoint.isEqual(lastCenterClusterMap,nowCenterClusterMap,k))
{
System.out.println("中心相等了,聚类结束!");
//测试lastCenterClusterMap数据,因为跳出循环时,它的数据应该和nowCenterClusterMap保持一致
/*for (Entry<Point,ArrayList<Point>> entry:lastCenterClusterMap.entrySet())
{
System.out.println("\n" + "稳定的中心: "+entry.getKey());
System.out.println("该簇的大小: "+entry.getValue().size());
System.out.println("簇里的点:"+entry.getValue());
}
System.out.println("中心相等了,聚类结束!!!!");*/
break;
}
// 更新中心
lastCenterClusterMap=nowCenterClusterMap;
nowCenterClusterMap=new HashMap<Point, ArrayList<Point>>();
System.out.println("------------------------------------------------------------------");
for(Entry<Point,ArrayList<Point>> entry:lastCenterClusterMap.entrySet())
{
nowCenterClusterMap.put(managePoint.getNewCenter(entry.getValue()),new ArrayList<Point>());
}
}
return nowCenterClusterMap;
}
public static void main(String[] args)
{
long start=System.currentTimeMillis();
int K=3;// 分为三个类
Map<Point,ArrayList<Point>> result =new KmeansMain().executeKmeans(K);
// 输出分类结果
System.out.println("===========聚类结果: ============");
for (Entry<Point,ArrayList<Point>> entry:result.entrySet())
{
System.out.println("\n" + "稳定的中心: "+entry.getKey());
System.out.println("该簇的大小: "+entry.getValue().size());
System.out.println("簇里的点:"+entry.getValue());
}
long end=System.currentTimeMillis();
System.out.println("执行本段程序所花费的时间为:"+(end-start)+"ms");
}
}
结果如图所示,可进行多次运行查看聚类效果。
执行时间一般为0.4s左右。我用一百万条数据的数据集测试过,执行速度也是非常快的,聚类效率还是挺不错的。
后期会慢慢写其他聚类算法的实现,尽请期待!
特别说明:k-means算法中的初始中心点是随机选取,但为了程序方便,当然实际应用中也不会随机选取,故在初始点的选取过程中采用的是k-means++的选取方式,即从数据点中随机选取,在这里有个问题,如果选取的中心点有重复了(因为代码中未作判断),就会出bug,但是概率基本为0,但还是做下说明,免得大家运行时偶尔出问题,如果出问题了就重新运行下就ok。