KNN算法就是把待分类数据放在训练集里找出离他最近的K个元素(欧氏距离),然后看看其中哪个类最多,就将这个元素分为这个类。在本实验中,使用数字数据集。每个数字含有一个二维数组表示其中的像素点,可以认为拥有M*N个特征,只不过每个特征只有0和1两种值,表示该像素点是否绘制。
将下载的训练集和测试集放在项目根目录下,因为测试集中每个元素也是已标记数据,所以每次分类后可以判断分类是否正确,从而得出一个正确率。
在拿到待分类元素的K个邻居后,最简单的处理是每个邻居具有相等的投票权,考虑增大离得近的元素的影响力,也就是为他们的投票权设置权值。这里我的权值设置是离得最近的具有K票,第二近的具有K-1票,依次递减,比较容易理解,直接看代码。
1.封装的数字类:
public class Number { private int[][] data=new int[32][32]; private int kind; public int[][] getData() { return data; } public void setData(int[][] data) { this.data = data; } public int getKind() { return kind; } public void setKind(int kind) { this.kind = kind; } }2.test类:
import java.io.BufferedReader; import java.io.File; import java.io.FileNotFoundException; import java.io.FileReader; import java.util.ArrayList; import java.util.List; public class Test { private int k; private List<Number> testDatas; public void putTestData(String path){//放入测试数据 File folder=new File(path); File[] files=folder.listFiles(); testDatas=new ArrayList<>(); for(File file:files){ testDatas.add(txt2Number(file)); } } public void work(String path){//开始测试 File folder=new File(path); File[] files=folder.listFiles(); int numberAll=0;//测试总数 int numCorrect=0;//测试正确数 double result=0;//正确率 for(File file:files){ Number num=txt2Number(file); int[] minDistances=new int[k]; int[] resultKinds=new int[k]; for (int i = 0; i < k; i++) { minDistances[i]=Integer.MAX_VALUE; resultKinds[i]=0; } for(Number nu:testDatas){ int currentDis=calcu(num, nu); int currentKind=nu.getKind(); for (int i = 0; i < k; i++) {//将当前测试数据与邻居数组中每一个进行比对,看看是否可以替换掉一个 if (currentDis<minDistances[i]) { resultKinds[i]=currentKind; minDistances[i]=currentDis; break; } } } int []kinds=new int[10];//10个类别的个数 for (int i = 0; i < k; i++) { kinds[resultKinds[i]]+=add(minDistances, i);//加权后累加 } int resultKind=0; int resultKindNum=0; for (int i = 0; i < 10; i++) { if (kinds[i]>resultKindNum) { resultKind=i; resultKindNum=kinds[i]; } } numberAll++; if (resultKind==num.getKind()) { numCorrect++; } } result=((double)(numCorrect*100))/numberAll; System.out.println("k是:"+getK()+" 测试总数:"+numberAll+" " + "正确数:"+numCorrect+" 正确率"+result); } public void workOne(String path){//测试单个 File fileTest=new File(path); Number num=txt2Number(fileTest); int[] minDistances=new int[k]; int[] resultKinds=new int[k]; for (int i = 0; i < k; i++) { minDistances[i]=Integer.MAX_VALUE; resultKinds[i]=0; } for(Number nu:testDatas){ int currentDis=calcu(num, nu); int currentKind=nu.getKind(); for (int i = 0; i < k; i++) {//将当前测试数据与邻居数组中每一个进行比对,看看是否可以替换掉一个 if (currentDis<minDistances[i]) { resultKinds[i]=currentKind; minDistances[i]=currentDis; break; } } } int []kinds=new int[10];//10个类别的个数 for (int i = 0; i < k; i++) { kinds[resultKinds[i]]++; } int resultKind=0; int resultKindNum=0; for (int i = 0; i < 10; i++) { if (kinds[i]>resultKindNum) { resultKind=i; resultKindNum=kinds[i]; } } System.out.println("识别文件"+path+"为:"+resultKind+" 实际类型为:"+num.getKind()); } public int calcu(Number a,Number b){//计算两张图的欧氏距离,为了简化计算不开根号 int result=0; for(int i=0;i<32;i++){ for (int j = 0; j < 32; j++) { int[][] d1=a.getData(); int[][] d2=b.getData(); int dis=d1[i][j]-d2[i][j]; result+=dis*dis; } } return result; } public int getK() { return k; } public void setK(int k) { this.k = k; } public Number txt2Number(File file){//txt文件转Number对象 Number num=new Number(); int[][] data=new int[32][32]; String fileName=file.getName(); int kind =Integer.valueOf(fileName.substring(0,1)); num.setKind(kind); try { BufferedReader reader=new BufferedReader(new FileReader(file)); String s=null; for (int i = 0; i < 32; i++) { s=reader.readLine(); for (int j = 0; j < 32; j++) { data[i][j]=Integer.valueOf(s.substring(j, j+1)); } } } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } num.setData(data); return num; } public int add(int []a,int index){//获取这个邻居元素的权值 int re=0; for (int i = 0; i < a.length; i++) { if (a[index]<=a[i]) { re++; } } return re; } public static void main (String[] args) { Test test=new Test(); test.putTestData("testDigits"); test.setK(10); test.work("trainingDigits"); test.setK(1); test.work("trainingDigits"); } }
实验结果: