核心决策理论:选择后验概率最高的一类作为决策。即:当出现一个需要分类的新点 (x, y) 时,只需计算 max(p(c1|x,y), p(c2|x,y), p(c3|x,y), …, p(cn|x,y)),其中概率最大的类别标签就是这个新点的分类。
package baseNaiveBayesian;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
 * Base class for a naive Bayes classifier over categorical (string-valued) attributes.
 *
 * <p>Every training row is a list of attribute values whose LAST element is the class label.
 * Subclasses provide the loading logic; this class provides grouping, prediction and two
 * simple accuracy reports that print to standard output.
 */
public abstract class NaiveBayesianBase {
    /** Training rows; each row is {@code [attribute0, attribute1, ..., classLabel]}. */
    protected ArrayList<ArrayList<String>> trainingSet;

    /** Creates a classifier with an empty training set. */
    public NaiveBayesianBase() {
        trainingSet = new ArrayList<ArrayList<String>>();
    }

    /**
     * Loads training rows from an implementation-defined interactive source into {@link #trainingSet}.
     *
     * @return an implementation-defined status code
     * @throws IOException if reading fails
     */
    public abstract int inputTrainingSet() throws IOException;

    /**
     * Loads training rows from the given path into {@link #trainingSet}.
     *
     * @param path location of the training data
     * @return an implementation-defined status code
     * @throws IOException if reading fails
     */
    public abstract int readTrainingSet(String path) throws IOException;

    /**
     * Groups rows by their last element (the class label).
     *
     * @param data rows to group; every row must contain at least one element
     * @return a map from class label to the rows carrying that label
     */
    public static Map<String, ArrayList<ArrayList<String>>> dataClassification(ArrayList<ArrayList<String>> data) {
        Map<String, ArrayList<ArrayList<String>>> grouped = new HashMap<String, ArrayList<ArrayList<String>>>();
        for (ArrayList<String> row : data) {
            String label = row.get(row.size() - 1);
            ArrayList<ArrayList<String>> rowsOfLabel = grouped.get(label);
            if (rowsOfLabel == null) {
                rowsOfLabel = new ArrayList<ArrayList<String>>();
                grouped.put(label, rowsOfLabel);
            }
            rowsOfLabel.add(row);
        }
        return grouped;
    }

    /**
     * Predicts the class label whose prior times the product of per-attribute conditional
     * probabilities is largest.
     *
     * <p>Only the first {@code min(featureCount, testSet.size())} values of {@code testSet} are
     * used, where {@code featureCount} is the first training row's size minus one; a labelled row
     * can therefore be passed directly and its trailing label is ignored.
     *
     * @param testSet attribute values of the sample to classify
     * @return the predicted class label
     * @throws IllegalStateException if the training set is empty
     */
    public String predictClassification(ArrayList<String> testSet) {
        Map<String, ArrayList<ArrayList<String>>> rowsByLabel = dataClassification(trainingSet);
        if (rowsByLabel.isEmpty()) {
            // Previously this fell through to an array access at index -1.
            throw new IllegalStateException("training set is empty; cannot predict a classification");
        }
        int featureCount = Math.min(trainingSet.get(0).size() - 1, testSet.size());
        String bestLabel = null;
        // Start below any reachable score so that some label is always selected,
        // even if every score is 0.
        double bestScore = -1.0;
        for (Map.Entry<String, ArrayList<ArrayList<String>>> entry : rowsByLabel.entrySet()) {
            String label = entry.getKey();
            int classSize = entry.getValue().size();
            // Exact ratio: rounding the prior to 3 decimals could turn a rare class's prior into 0.
            double score = (double) classSize / trainingSet.size();
            for (int k = 0; k < featureCount; ++k) {
                double pCA = pOfClassificationAttributes(testSet.get(k), k, label);
                // An attribute value never seen with this class gets 1/classSize instead of 0,
                // so a single unseen value does not zero out the whole product.
                if (pCA <= 0.0) {
                    pCA = 1.0 / classSize;
                }
                score *= pCA;
            }
            if (score > bestScore) {
                bestScore = score;
                bestLabel = label;
            }
        }
        return bestLabel;
    }

    /**
     * Estimates the conditional probability that attribute {@code index} equals {@code attribute}
     * among the training rows labelled {@code classificationclass}.
     *
     * @param attribute attribute value to look for
     * @param index position of the attribute inside a row
     * @param classificationclass class label to condition on
     * @return matching rows divided by rows of that class, or {@code 0.0} when the class has no rows
     */
    public double pOfClassificationAttributes(String attribute, int index, String classificationclass) {
        int count = 0;
        int total = 0;
        for (ArrayList<String> row : trainingSet) {
            if (row.get(row.size() - 1).equals(classificationclass)) {
                ++total;
                if (row.get(index).equals(attribute)) {
                    ++count;
                }
            }
        }
        if (total == 0) {
            // Previously a BigDecimal division by zero (ArithmeticException).
            return 0.0;
        }
        return (double) count / total;
    }

    /**
     * Computes the fraction of rows whose predicted label equals their actual (last-element) label.
     *
     * @param samples labelled rows to score against the current training set
     * @return accuracy in {@code [0, 1]}; {@code 1.0} for an empty sample list (no errors were made)
     * @throws IllegalStateException if {@code samples} is non-empty and the training set is empty
     */
    protected double evaluateAccuracy(ArrayList<ArrayList<String>> samples) {
        if (samples.isEmpty()) {
            return 1.0;
        }
        int correct = 0;
        for (ArrayList<String> row : samples) {
            String actual = row.get(row.size() - 1);
            if (actual.equals(predictClassification(row))) {
                ++correct;
            }
        }
        return (double) correct / samples.size();
    }

    /**
     * Holds out a random fraction {@code d} of the training rows, predicts them with the
     * remaining rows and prints the resulting accuracy. The held-out rows are appended back to
     * the training set afterwards, even if evaluation fails. Does nothing when {@code d} is
     * outside {@code [0, 1]}.
     *
     * @param d fraction of the training set to use as the test set
     * @throws IllegalStateException if holding out the rows leaves the training set empty
     */
    public void reportModel(double d) {
        if (d < 0.0 || d > 1.0) return;
        ArrayList<ArrayList<String>> testSet = new ArrayList<ArrayList<String>>();
        int testSetCount = (int) (trainingSet.size() * d);
        try {
            for (int i = 0; i < testSetCount; ++i) {
                // Scale by size (not size-1) so the last row can be drawn as well.
                int pick = (int) (Math.random() * trainingSet.size());
                testSet.add(trainingSet.remove(pick));
            }
            System.out.println("模型准确率为:" + evaluateAccuracy(testSet));
        } finally {
            trainingSet.addAll(testSet);
        }
    }

    /**
     * Predicts every training row with the full training set and prints the resulting accuracy.
     */
    public void reportModelSelf() {
        System.out.println("模型准确率为:" + evaluateAccuracy(trainingSet));
    }
}
`
//病人数据资料分类
package classification;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import baseNaiveBayesian.NaiveBayesianBase;
/**
 * Naive Bayes classifier for comma-separated patient records.
 *
 * <p>Each raw record is converted by {@link #dataDeal(ArrayList)} into categorical codes before
 * it is stored in the inherited training set.
 */
public class PatientClassification extends NaiveBayesianBase {
    /** Number of leading raw columns (indices 0..27) that dataDeal reads; a trailing label column is optional. */
    private static final int RAW_ATTRIBUTE_COUNT = 28;

    /**
     * Reads comma-separated records from standard input until a blank line or end of input.
     *
     * @return always {@code 0}
     * @throws IOException if reading standard input fails
     */
    @Override
    public int inputTrainingSet() throws IOException {
        // System.in is intentionally left open: closing the reader would close it for the whole process.
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
        String str;
        // readLine() returns null at end of input; check that before testing for a blank line.
        while ((str = reader.readLine()) != null && !str.trim().isEmpty()) {
            trainingSet.add(dataDeal(splitFields(str)));
        }
        return 0;
    }

    /**
     * Reads comma-separated records from a file, skipping blank lines.
     *
     * @param path file to read
     * @return {@code 0} on success; {@code -1} (after printing the absolute path) when the path is
     *         missing or is not a regular file
     * @throws IOException if reading the file fails
     */
    @Override
    public int readTrainingSet(String path) throws IOException {
        File file = new File(path);
        if (!file.exists() || !file.isFile()) {
            System.out.println(file.getAbsolutePath());
            return -1;
        }
        // try-with-resources closes the reader even when a record fails to parse.
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(file)))) {
            String str;
            while ((str = reader.readLine()) != null) {
                if (str.trim().isEmpty()) {
                    continue;
                }
                trainingSet.add(dataDeal(splitFields(str)));
            }
        }
        return 0;
    }

    /** Splits one record on commas and trims whitespace around every field. */
    private static ArrayList<String> splitFields(String record) {
        String[] tokens = record.split(",");
        ArrayList<String> fields = new ArrayList<String>(tokens.length);
        for (String token : tokens) {
            fields.add(token.trim());
        }
        return fields;
    }

    /** Returns "1" below {@code low}, "2" up to {@code high} (inclusive only when {@code highInclusive}), else "3". */
    private static String band(double value, double low, double high, boolean highInclusive) {
        if (value < low) {
            return "1";
        }
        if (value < high || (highInclusive && value == high)) {
            return "2";
        }
        return "3";
    }

    /**
     * Converts one raw record into categorical codes.
     *
     * @param line raw fields; indices 0..27 are read, and when the record is longer than the
     *             28 produced codes its last field is appended as the class label
     * @return the coded row (28 values, plus the label when present)
     * @throws IllegalArgumentException if the record has fewer than 28 fields or a numeric field
     *         cannot be parsed ({@link NumberFormatException})
     */
    private static ArrayList<String> dataDeal(ArrayList<String> line) {
        if (line.size() < RAW_ATTRIBUTE_COUNT) {
            throw new IllegalArgumentException(
                    "expected at least " + RAW_ATTRIBUTE_COUNT + " fields but got " + line.size() + ": " + line);
        }
        ArrayList<String> newLine = new ArrayList<String>();

        // Age (0): 5-wide bins centred on multiples of 5; everything up to 32 shares code "1".
        switch ((Integer.parseInt(line.get(0)) + 2) / 5) {
            case 0: case 1: case 2: case 3: case 4: case 5: case 6: newLine.add("1"); break;
            case 7: newLine.add("2"); break;
            case 8: newLine.add("3"); break;
            case 9: newLine.add("4"); break;
            case 10: newLine.add("5"); break;
            case 11: newLine.add("6"); break;
            case 12: newLine.add("7"); break;
            default: newLine.add("-1"); break;
        }

        // Gender (1): copied as-is.
        newLine.add(line.get(1));

        // BMI (2): <18 -> "1", <25 -> "2", then one code per 5 units (25..29 -> "3", 30..34 -> "4", ...).
        int bmi = Integer.parseInt(line.get(2));
        if (bmi < 18) {
            newLine.add("1");
        } else if (bmi < 25) {
            newLine.add("2");
        } else {
            newLine.add(String.valueOf(bmi / 5 - 2));
        }

        // Fever, Nausea/Vomiting, Headache, Diarrhea, Fatigue & Bone ache, Jaundice,
        // Epigastria pain (3..9): seven absent/present flags copied as-is.
        for (int i = 0; i < 7; ++i) {
            newLine.add(line.get(i + 3));
        }

        // WBC (10)
        newLine.add(band(Integer.parseInt(line.get(10)), 4000, 11000, false));

        // RBC (11)
        newLine.add(band(Double.parseDouble(line.get(11)), 3000000.00, 5000000.00, false));

        // HGB (12): the band depends on the gender code copied above ("1" uses 14..17, otherwise 12..15).
        int hgb = Integer.parseInt(line.get(12));
        if (newLine.get(1).equals("1")) {
            newLine.add(band(hgb, 14, 17, true));
        } else {
            newLine.add(band(hgb, 12, 15, true));
        }

        // Plat (13)
        newLine.add(band(Double.parseDouble(line.get(13)), 100000.00, 255000, false));

        // AST1, ALT1, ALT4, ALT12, ALT24, ALT36, ALT48, ALT after 24 (14..21)
        for (int i = 0; i < 8; ++i) {
            newLine.add(band(Double.parseDouble(line.get(i + 14)), 20.00, 40.00, true));
        }

        // RNA Base, RNA 4, RNA 12, RNA EOT, RNA EF (22..26): two codes split at 5.
        for (int i = 0; i < 5; ++i) {
            newLine.add(Double.parseDouble(line.get(22 + i)) <= 5.00 ? "1" : "2");
        }

        // Baseline Histological Grading (27): copied as-is.
        newLine.add(line.get(27));

        // Class label: the last raw field, present only when the record carries more than the 28 attributes.
        if (line.size() > newLine.size()) {
            newLine.add(line.get(line.size() - 1));
        }
        return newLine;
    }

    /**
     * Loads {@code patientData.txt} and prints the self-evaluation and a 10% hold-out evaluation.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        PatientClassification patientClassification = new PatientClassification();
        try {
            if (patientClassification.readTrainingSet("patientData.txt") != 0) {
                // Without data the reports would be meaningless, so stop here.
                System.err.println("Training data file not found; nothing to evaluate.");
                return;
            }
            patientClassification.reportModelSelf();
            patientClassification.reportModel(0.10);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
//参考数据:[测试数据](https://archive.ics.uci.edu/ml/datasets/Hepatitis+C+Virus+%28HCV%29+for+Egyptian+patients)