贝叶斯算法,简单地说就是比较测试数据是各个类别的概率,概率最大的就判断为这个数据的类别。具体来说就是一个条件概率的计算,测试数据是X,某个类别是C,那么也就是求P(C|X)最大的C,等于P(X|C)P(C)/P(X),因为P(X)不变,所以不必引入计算。也就是说,比较的是P(X|C)P(C)。进一步,P(C)=C类数量/训练集总数,其中训练集总数不变,所以也可以忽略。P(X|C)=P(X1 | C).P(X2 | C).......也就是C类中X各个特征的值的条件概率。
首先需要一个测试集,然后求出各个类的数量和各个类的各个特征的各个值的条件概率(三层数组),接下来就可以对测试数据进行分类,使用上面的数据进行计算。
在本实验中,使用的是汽车数据集,分为四类。我将汽车封装为一个类,然后将数据集内容存储在一个txt中(放在项目根目录下,叫做a.txt),每一行对应一个汽车数据元素,然后在主类test类中还提供了txt转成汽车数组的方法。根据实验要求取数据集的部分作为训练集,全集作为测试集。由于全集数据其实都是已分类的数据,所以每次分完类后可以直接判断本次分类是否正确,从而得出正确率。
1.汽车类:
public class Car { private int kind,buying,maint,door,Persons,Lug_boot,Safety; public int getKind() { return kind; } public void setKind(int kind) { this.kind = kind; } public int getBuying() { return buying; } public void setBuying(int buying) { this.buying = buying; } public int getMaint() { return maint; } public void setMaint(int maint) { this.maint = maint; } public int getDoor() { return door; } public void setDoor(int door) { this.door = door; } public int getPersons() { return Persons; } public void setPersons(int persons) { Persons = persons; } public int getLug_boot() { return Lug_boot; } public void setLug_boot(int lug_boot) { Lug_boot = lug_boot; } public int getSafety() { return Safety; } public void setSafety(int safety) { Safety = safety; } }
2.test类:
import java.io.BufferedReader; import java.io.File; import java.io.FileNotFoundException; import java.io.FileReader; public class Test { //将特征的各个值映射为Int,从0开始 public Car[] cars=new Car[1728];//数据集合 public int[] kindNumber=new int[4];//四个类别各自在训练集中的数目(为了简化计算直接使用数目) public double[][][] kindP=new double[4][6][4];//四个类别的各个特点的各个值的条件概率 int numberAll=0; int numberCotrrect=0;//总测试数和正确数 int numberTest=0;//测试集数目 public void getCars(String path){//读取数据文件转为cars File file=new File(path); try { BufferedReader reader=new BufferedReader(new FileReader(file)); String s=null; int number=0; while ((s=reader.readLine())!=null) { cars[number++]=txt2Car(s); } } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } } public Car txt2Car(String s){//一行文本转Car Car car=new Car(); int start=0; int end=0; //1 while(!s.substring(end, end+1).equals(",")){ end++; } String buyings=s.substring(start,end); int buying=0; if (buyings.equals("vhigh")) { buying=0; } else { if (buyings.equals("high")) { buying=1; } else { if (buyings.equals("med")) { buying=2; } else { buying=3; } } } car.setBuying(buying); end++; start=end; //2 while(!s.substring(end, end+1).equals(",")){ end++; } String maints=s.substring(start,end); int maint=0; if (maints.equals("vhigh")) { maint=0; } else { if (maints.equals("high")) { maint=1; } else { if (maints.equals("med")) { maint=2; } else { maint=3; } } } car.setMaint(maint); end++; start=end; //3 while(!s.substring(end, end+1).equals(",")){ end++; } String doors=s.substring(start,end); int door=0; if (doors.equals("2")) { door=0; } else { if (doors.equals("3")) { door=1; } else { if (doors.equals("4")) { door=2; } else { door=3; } } } car.setDoor(door); end++; start=end; //4 while(!s.substring(end, end+1).equals(",")){ end++; } String persons=s.substring(start,end); int person=0; if (persons.equals("2")) { person=0; } else { if (persons.equals("4")) { person=1; } else { person=2; } } car.setPersons(person); end++; start=end; //5 while(!s.substring(end, end+1).equals(",")){ end++; } String lugs=s.substring(start,end); int lug=0; if (lugs.equals("small")) { lug=0; } else { if (lugs.equals("med")) { lug=1; } else { lug=2; } } car.setLug_boot(lug); end++; start=end; //6 while(!s.substring(end, end+1).equals(",")){ end++; } String safes=s.substring(start,end); int safe=0; if (safes.equals("low")) { safe=0; } else { if (safes.equals("med")) { safe=1; } else { safe=2; } } car.setSafety(safe); end++; start=end; //kind String kinds=s.substring(start,s.length()); int kind=0; if (kinds.equals("unacc")) { kind=0; } else { if (kinds.equals("acc")) { kind=1; } else { if (kinds.equals("good")) { kind=2; } else { kind=3; } } } car.setKind(kind); return car; } public void putTest(int n){//放入训练集进行训练,参数是训练集大小 numberTest=n; numberCotrrect=0; numberAll=0; double [][][] k=new double[4][6][4];//四个类别的各个特点的各个值的个数 for(int i=0;i<n;i++){ Car car=cars[i]; int kind=car.getKind(); int buying=car.getKind(); int maint=car.getMaint(); int door=car.getDoor(); int persons=car.getPersons(); int lug_Boot=car.getLug_boot(); int safety=car.getSafety(); kindNumber[kind]++; k[kind][0][buying]++; k[kind][1][maint]++; k[kind][2][door]++; k[kind][3][persons]++; k[kind][4][lug_Boot]++; k[kind][5][safety]++; } // for(int i=0;i<4;i++){ // System.out.println("第"+i+"类的数目为"+kindNumber[i]); // } for(int i=0;i<4;i++){//为kindP赋值 for(int j=0;j<6;j++){ for(int p=0;p<4;p++){ if (kindNumber[i]>0) { kindP[i][j][p]=k[i][j][p]/kindNumber[i]; //System.out.println("第"+i+"类第"+j+"个特征的第"+p+"个值的概率是"+kindP[i][j][p]); } } } } } public void work(){ for(int i=0;i<cars.length;i++){ numberAll++; Car car=cars[i]; int kind=car.getKind(); int buying=car.getKind(); int maint=car.getMaint(); int door=car.getDoor(); int persons=car.getPersons(); int lug_Boot=car.getLug_boot(); int safety=car.getSafety(); int testKind=0; double testP=0; for(int j=0;j<4;j++){ double currentP=kindNumber[j]*kindP[j][0][buying] *kindP[j][1][maint]*kindP[j][2][door] *kindP[j][3][persons]*kindP[j][4][lug_Boot]*kindP[j][5][safety]; if (currentP>testP) { testP=currentP; testKind=j; } } if (testKind==kind) { numberCotrrect++; } } double result=((double)(numberCotrrect*100))/numberAll; System.out.println("测试集数:"+numberTest+" 测试正确数 :"+numberCotrrect+" 测试正确率:"+result+"%"); } public static void main(String []args){ Test test=new Test(); test.getCars("a.txt"); test.putTest(100); test.work(); test.putTest(200); test.work(); test.putTest(500); test.work(); test.putTest(700); test.work(); test.putTest(1000); test.work(); test.putTest(1350); test.work(); } }
可以看到main方法里每次测试都分两步,putTest指定全集的多少作为训练集(然后计算各个值),work遍历全集分类并统计结果。