最近看随机游走算法时,遇到了聚类算法。 结合当时参考的论文和一些博客。 整理了下思维,写了下面的算法。
package kMeans;
import java.util.Arrays;/**
* @ClassName: KMeansM
* @Description:k-Means算法,聚类算法
实现步骤: 1. 首先是随机获取总体中的K个元素作为总体的K个中心;
2. 接下来对总体中的元素进行分类,每个元素都去判断自己到K个中心的距离,并归类到最近距离中心去;
3. 计算每个聚类的平均值,并作为新的中心点
4. 重复2,3步骤,直到这k个中线点不再变化(收敛了),或执行了足够多的迭代
* @author: muliming
* @date: 2017年11月21日 下午11:12:11
* @Copyright: 2017
*/
public class KMeansM {
//定义模拟数据源 一个三维的点 选择18个点
private static double[][] DATA = {
{5.1,3.5,1.4},{4.9,3.0,1.4},{4.7,3.2,1.3},{4.6,3.1,1.5},{5.0,3.6,1.4},{5.4,3.9,1.7},
{4.6,3.4,1.4},{5.0,3.4,1.5},{4.4,2.9,1.4},{4.9,3.1,1.5},{5.4,3.7,1.5},{4.8,3.4,1.6},
{4.8,3.0,1.4},{4.3,3.0,1.1},{5.8,4.0,1.2},{5.7,4.4,1.5},{5.4,3.9,1.3},{5.1,3.5,1.4},
{5.7,3.8,1.7},{5.1,3.8,1.5},{5.4,3.4,1.7},{5.1,3.7,1.5},{4.6,3.6,1.0},{5.1,3.3,1.7},
{4.8,3.4,1.9},{5.0,3.0,1.6},{5.0,3.4,1.6},{5.2,3.5,1.5},{5.2,3.4,1.4},{4.7,3.2,1.6},
{4.8,3.1,1.6},{5.4,3.4,1.5},{5.2,4.1,1.5},{5.5,4.2,1.4},{4.9,3.1,1.5},{5.0,3.2,1.2},
{5.5,3.5,1.3},{4.9,3.1,1.5},{4.4,3.0,1.3},{5.1,3.4,1.5},{5.0,3.5,1.3},{4.5,2.3,1.3},
{4.4,3.2,1.3},{5.0,3.5,1.6},{5.1,3.8,1.9},{4.8,3.0,1.4},{5.1,3.8,1.6},{4.6,3.2,1.4},
{5.3,3.7,1.5},{5.0,3.3,1.4}};
public int k;//选择k个中心点进行聚合
public int[][] tempClusters;//记录每个中心点下点的索引号
public int[] elementsInCenters; //记录每个中心点所属的各自的类的个数
public double[][] centers; //中心点
public int[] memberShip; //记录每个点的中心点的索引号 为了后面的比较方便
//有参构成,初始化中心点的个数
public KMeansM(int k){
this.k = k; //定义中心点的个数
}
//--主函数--------------------------------------------------------
public static void main(String[] args) {
KMeansM kmeansM = new KMeansM(3); //初始化点数
String lastMembership = "";
String nowMembership = "";
int i=0;
kmeansM.firstCenters();//初试化中心点
System.out.println("第一次选取得中心点为:");
for(int n=0;n<kmeansM.centers.length;n++){
System.out.print(Arrays.toString(kmeansM.centers[n])+",");
}
System.out.println();
boolean isEnd=true;
while(isEnd){
i++;
kmeansM.calMemberShip(); //归类 寻找每个点所属的中心点,并记录每个中心点自己最终含有的点的总数
nowMembership = Arrays.toString(kmeansM.memberShip);//把当前的赋值,记录
if(nowMembership.equals(lastMembership)){
System.out.println("");
System.out.println("一共聚类了 "+(i-1)+" 次!");
for(int n=0;n<kmeansM.centers.length;n++){
System.out.print(Arrays.toString(kmeansM.centers[n])+",");
}
isEnd=false;
}else{
kmeansM.calNewCenters();
lastMembership = nowMembership;
System.out.println("第 "+i+" 次聚合");
for(int n=0;n<kmeansM.centers.length;n++){
System.out.print(Arrays.toString(kmeansM.centers[n])+",");
}
System.out.println();
System.out.println();
}
}
}
//-------工具方法---------------------------------------------------------
/**
* @Title: firstCenters
* @Description: 初试化中心点,选取前k个点为初试中心点
* @return: 创建初始化中心点
* @throws
*/
public void firstCenters(){
centers = new double[k][DATA[0].length];
for(int i=0;i<k;i++){
for(int j=0;j<DATA[i].length;j++){
centers[i][j]=DATA[i][j];
}
}
}
/**
* @Title: manhattanDistince
* @Description: 计算临近距离 每一串数据与选出的中心点的距离,
* @param: @param paraFirstData 第一个点
* @param: @param paraSecondData 第二个点
* @return: double 返回两个点之间的距离
* @throws
*/
public double manhattanDistince(double[] paraFirstData,double[] paraSecondData){
double tempDistince = 0;
if((paraFirstData!=null && paraSecondData!=null) && paraFirstData.length==paraSecondData.length){
for(int i=0;i<paraFirstData.length;i++){
tempDistince += Math.abs(paraFirstData[i] - paraSecondData[i]);
}
}else{
System.out.println("firstData 与 secondData 数据结构不一致");
}
return tempDistince;
}
/**
* @Title: calNewCenters
* @Description: 生成新的中心点 ,每次中心点取该类中的平均值
*/
public void calNewCenters(){
double[][] tempCenters=new double[k][3]; //中心点
//求和
for(int i=0;i<k;i++){
for(int j=0;j<elementsInCenters[i];j++){
for(int k=0;k<DATA[i].length;k++){
tempCenters[i][k]+=DATA[tempClusters[i][j]][k];
}
}
}
//取平均值
for(int i=0;i<centers.length;i++){
for(int j=0;j<DATA[0].length;j++){
if(elementsInCenters[i]!=0){
tempCenters[i][j] /= elementsInCenters[i];
}else{
tempCenters[i][j] = centers[i][j];
}
}
}
centers=tempCenters;
}
/**
* @Title: calMemberShip
* @Description: 寻找每个点所属的中心点,并记录每个中心点自己最终含有的点的总数
* @return: 记录数组,和记录总数的数组
*/
public void calMemberShip(){
memberShip = new int[DATA.length];//记录每串数据的中心点在中心点数组中的索引
tempClusters = new int[k][DATA.length];//记录每个中心点 下点的索引号
elementsInCenters = new int[k];//记录每个中心点的类的个数
for(int j=0;j<DATA.length;j++){
double currentDistance = Double.MAX_VALUE;//比较变量
int currentIndex = -1;//索引位置
double[] item = DATA[j];
int i;
for(i=0;i<k;i++){//和中心点做比较
double[] tempCentersValue = centers[i]; //中心点
double distance = this.manhattanDistince(item, tempCentersValue);
if(distance<currentDistance){
currentDistance = distance;
currentIndex = i; //记录当前点的中心点
}
}
memberShip[j]=currentIndex;
tempClusters[currentIndex][elementsInCenters[currentIndex]] = j;// 把索引号存入自己的腰包
elementsInCenters[currentIndex]++;
}
}
}
(思路可能有相似的,毕竟算法就这么点代码。有参考一些,但本算法有一改进)