山东大学模式识别实验(java)贝叶斯算法

贝叶斯算法,简单地说就是比较测试数据是各个类别的概率,概率最大的就判断为这个数据的类别。具体来说就是一个条件概率的计算,测试数据是X,某个类别是C,那么也就是求P(C|X)最大的C,等于P(X|C)P(C)/P(X),因为P(X)不变,所以不必引入计算。也就是说,比较的是P(X|C)P(C)。进一步,P(C)=C类数量/训练集总数,其中训练集总数不变,所以也可以忽略。P(X|C)=P(X1 | C).P(X2 | C).......也就是C类中X各个特征的值的条件概率。

首先需要一个测试集,然后求出各个类的数量和各个类的各个特征的各个值的条件概率(三层数组),接下来就可以对测试数据进行分类,使用上面的数据进行计算。

在本实验中,使用的是汽车数据集,分为四类。我将汽车封装为一个类,然后将数据集内容存储在一个txt中(放在项目根目录下,叫做a.txt),每一行对应一个汽车数据元素,然后在主类test类中还提供了txt转成汽车数组的方法。根据实验要求取数据集的部分作为训练集,全集作为测试集。由于全集数据其实都是已分类的数据,所以每次分完类后可以直接判断本次分类是否正确,从而得出正确率。

1.汽车类:

public class Car {
	private int kind,buying,maint,door,Persons,Lug_boot,Safety;

	public int getKind() {
		return kind;
	}

	public void setKind(int kind) {
		this.kind = kind;
	}

	public int getBuying() {
		return buying;
	}

	public void setBuying(int buying) {
		this.buying = buying;
	}

	public int getMaint() {
		return maint;
	}

	public void setMaint(int maint) {
		this.maint = maint;
	}

	public int getDoor() {
		return door;
	}

	public void setDoor(int door) {
		this.door = door;
	}

	public int getPersons() {
		return Persons;
	}

	public void setPersons(int persons) {
		Persons = persons;
	}

	public int getLug_boot() {
		return Lug_boot;
	}

	public void setLug_boot(int lug_boot) {
		Lug_boot = lug_boot;
	}

	public int getSafety() {
		return Safety;
	}

	public void setSafety(int safety) {
		Safety = safety;
	}
	
}

2.test类:

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;

public class Test {

	//将特征的各个值映射为Int,从0开始
	
	public Car[] cars=new Car[1728];//数据集合
	public int[] kindNumber=new int[4];//四个类别各自在训练集中的数目(为了简化计算直接使用数目)
	public double[][][] kindP=new double[4][6][4];//四个类别的各个特点的各个值的条件概率
	
	int numberAll=0;
	int numberCotrrect=0;//总测试数和正确数
	int numberTest=0;//测试集数目
	
	public void getCars(String path){//读取数据文件转为cars
		File file=new File(path);
		try {
			BufferedReader reader=new BufferedReader(new FileReader(file));
			String s=null;
			int number=0;
			while ((s=reader.readLine())!=null) {
				cars[number++]=txt2Car(s);
			}
		} catch (Exception e) {
			// TODO Auto-generated catch block
			e.printStackTrace();
		}
	}
	
	public Car txt2Car(String s){//一行文本转Car
		Car car=new Car();
		int start=0;
		int end=0;
		//1
		while(!s.substring(end, end+1).equals(",")){
			end++;
		}
		String buyings=s.substring(start,end);
		int buying=0;
		if (buyings.equals("vhigh")) {
			buying=0;
		} else {
			if (buyings.equals("high")) {
				buying=1;
			} else {
				if (buyings.equals("med")) {
					buying=2;
				} else {
					buying=3;
				}
			}
		}
		car.setBuying(buying);
		end++;
		start=end;
		//2
		while(!s.substring(end, end+1).equals(",")){
			end++;
		}
		String maints=s.substring(start,end);
		int maint=0;
		if (maints.equals("vhigh")) {
			maint=0;
		} else {
			if (maints.equals("high")) {
				maint=1;
			} else {
				if (maints.equals("med")) {
					maint=2;
				} else {
					maint=3;
				}
			}
		}
		car.setMaint(maint);
		end++;
		start=end;
		//3
		while(!s.substring(end, end+1).equals(",")){
			end++;
		}
		String doors=s.substring(start,end);
		int door=0;
		if (doors.equals("2")) {
			door=0;
		} else {
			if (doors.equals("3")) {
				door=1;
			} else {
				if (doors.equals("4")) {
					door=2;
				} else {
					door=3;
				}
			}
		}
		car.setDoor(door);
		end++;
		start=end;
		//4
		while(!s.substring(end, end+1).equals(",")){
			end++;
		}
		String persons=s.substring(start,end);
		int person=0;
		if (persons.equals("2")) {
			person=0;
		} else {
			if (persons.equals("4")) {
				person=1;
			} else {
				person=2;
			}
		}
		car.setPersons(person);
		end++;
		start=end;
		//5
		while(!s.substring(end, end+1).equals(",")){
			end++;
		}
		String lugs=s.substring(start,end);
		int lug=0;
		if (lugs.equals("small")) {
			lug=0;
		} else {
			if (lugs.equals("med")) {
				lug=1;
			} else {
				lug=2;
			}
		}
		car.setLug_boot(lug);
		end++;
		start=end;
		//6
		while(!s.substring(end, end+1).equals(",")){
			end++;
		}
		String safes=s.substring(start,end);
		int safe=0;
		if (safes.equals("low")) {
			safe=0;
		} else {
			if (safes.equals("med")) {
				safe=1;
			} else {
				safe=2;
			}
		}
		car.setSafety(safe);
		end++;
		start=end;
		//kind
		String kinds=s.substring(start,s.length());
		int kind=0;
		if (kinds.equals("unacc")) {
			kind=0;
		} else {
			if (kinds.equals("acc")) {
				kind=1;
			} else {
				if (kinds.equals("good")) {
					kind=2;
				} else {
					kind=3;
				}
			}
		}
		car.setKind(kind);
		return car;
	}
	
	public void putTest(int n){//放入训练集进行训练,参数是训练集大小
		numberTest=n;
		numberCotrrect=0;
		numberAll=0;
		double [][][] k=new double[4][6][4];//四个类别的各个特点的各个值的个数
		for(int i=0;i<n;i++){
			Car car=cars[i];
			int kind=car.getKind();
			int buying=car.getKind();
			int maint=car.getMaint();
			int door=car.getDoor();
			int persons=car.getPersons();
			int lug_Boot=car.getLug_boot();
			int safety=car.getSafety();
			kindNumber[kind]++;
			k[kind][0][buying]++;
			k[kind][1][maint]++;
			k[kind][2][door]++;
			k[kind][3][persons]++;
			k[kind][4][lug_Boot]++;
			k[kind][5][safety]++;
		}
//		for(int i=0;i<4;i++){
//			System.out.println("第"+i+"类的数目为"+kindNumber[i]);
//		}
		for(int i=0;i<4;i++){//为kindP赋值
			for(int j=0;j<6;j++){
				for(int p=0;p<4;p++){
					if (kindNumber[i]>0) {
						kindP[i][j][p]=k[i][j][p]/kindNumber[i];
						//System.out.println("第"+i+"类第"+j+"个特征的第"+p+"个值的概率是"+kindP[i][j][p]);
					}
				}
			}
		}
	}
	
	public void work(){
		for(int i=0;i<cars.length;i++){
			numberAll++;
			
			Car car=cars[i];
			int kind=car.getKind();
			int buying=car.getKind();
			int maint=car.getMaint();
			int door=car.getDoor();
			int persons=car.getPersons();
			int lug_Boot=car.getLug_boot();
			int safety=car.getSafety();
			
			int testKind=0;
			double testP=0;
			for(int j=0;j<4;j++){
				double currentP=kindNumber[j]*kindP[j][0][buying]
						*kindP[j][1][maint]*kindP[j][2][door]
						*kindP[j][3][persons]*kindP[j][4][lug_Boot]*kindP[j][5][safety];
				if (currentP>testP) {
					testP=currentP;
					testKind=j;
				}
			}
			if (testKind==kind) {
				numberCotrrect++;
			}
		}
		double result=((double)(numberCotrrect*100))/numberAll;
		System.out.println("测试集数:"+numberTest+"       测试正确数 :"+numberCotrrect+"    测试正确率:"+result+"%");
	}
	
	public static void main(String []args){
		Test test=new Test();
		test.getCars("a.txt");
		
		test.putTest(100);
		test.work();
		
		test.putTest(200);
		test.work();
		
		test.putTest(500);
		test.work();
		
		test.putTest(700);
		test.work();
		
		test.putTest(1000);
		test.work();
		
		test.putTest(1350);
		test.work();
	}
}


实验结果:

可以看到main方法里每次测试都分两步,putTest指定全集的多少作为训练集(然后计算各个值),work遍历全集分类并统计结果。

猜你喜欢

转载自blog.csdn.net/zhang___yong/article/details/79053523