ID3算法实现:ID3Tool.java
package id3;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import tool.TreeNode;
import tool.Matrix_2D;
import tool.ReadData;
/**
 * ID3 decision-tree learner and classifier for categorical data.
 *
 * <p>A dataset is a {@link Matrix_2D} of strings in which every column except the last holds a
 * feature value and the last column holds the class label. Feature columns are identified by the
 * names in a parallel {@code features} array (one name per feature column, in column order).
 */
public class ID3Tool {
    /** Label returned by {@link #classification} when a sample cannot be routed to a leaf. */
    private static final String NO_CLASSIFICATION = "no such classification";

    /** ln(2), used to convert natural logarithms into base-2 logarithms. */
    private static final double LN2 = Math.log(2);

    /**
     * Reads a comma-separated data file into a matrix via {@link ReadData#readDataFile}.
     *
     * @param path path of the data file
     * @return the file contents, one matrix row per data row, in the order readDataFile returns
     * @throws IOException if the file cannot be read
     */
    public Matrix_2D<String> readData(String path) throws IOException {
        ArrayList<ArrayList<String>> rows = ReadData.readDataFile(path);
        if (rows == null) {
            // readDataFile may return null (e.g. for a missing file); report that here instead of
            // failing later with a NullPointerException inside the Matrix_2D constructor.
            throw new IOException("Could not read data file: " + path);
        }
        return new Matrix_2D<String>(rows);
    }

    /**
     * Computes the Shannon entropy, in bits, of the class-label column (the last column).
     *
     * @param ds dataset whose last column holds the class labels
     * @return the entropy of the label distribution; 0 for an empty dataset
     */
    public static double calShannonEntropy(Matrix_2D<String> ds) {
        int m = ds.getRowDimension();
        if (m == 0) {
            return 0.0;
        }
        int labelCol = ds.getColDimension() - 1;
        // Count how often each class label occurs.
        Map<String, Integer> labelCounts = new HashMap<String, Integer>();
        for (int i = 0; i < m; i++) {
            String label = ds.get(i, labelCol);
            Integer count = labelCounts.get(label);
            labelCounts.put(label, count == null ? 1 : count + 1);
        }
        // H = -sum(p * log2(p)); divide in double precision (the old float division lost digits).
        double shannonEnt = 0.0;
        for (int count : labelCounts.values()) {
            double rate = count / (double) m;
            shannonEnt -= rate * Math.log(rate) / LN2;
        }
        return shannonEnt;
    }

    /**
     * Returns the rows whose column {@code axis} equals {@code value}, with that column removed.
     *
     * @param dataSet source dataset (not modified)
     * @param axis index of the column to match and drop
     * @param value value the column must equal
     * @return a new matrix containing the matching rows minus column {@code axis}
     */
    private Matrix_2D<String> splitDataSet(Matrix_2D<String> dataSet, int axis, String value) {
        Matrix_2D<String> retDataSet = new Matrix_2D<String>();
        int r = dataSet.getRowDimension();
        int c = dataSet.getColDimension();
        for (int i = 0; i < r; ++i) {
            if (dataSet.get(i, axis).equals(value)) {
                ArrayList<String> reduced = new ArrayList<String>(c);
                for (int j = 0; j < c; ++j) {
                    reduced.add(dataSet.get(i, j));
                }
                reduced.remove(axis); // remove by index
                retDataSet.putLine(reduced);
            }
        }
        return retDataSet;
    }

    /**
     * Picks the feature column with the largest information gain.
     *
     * @param dataSet dataset whose last column holds the class labels
     * @return index of the best feature column, or -1 if no feature has strictly positive gain
     */
    private int chooseBestFeatureToSplit(Matrix_2D<String> dataSet) {
        int featureNum = dataSet.getColDimension() - 1;
        int row = dataSet.getRowDimension();
        double baseEntropy = calShannonEntropy(dataSet);
        double bestInfoGain = 0.0;
        int bestFeature = -1;
        for (int col = 0; col < featureNum; ++col) {
            // Distinct values of this feature column.
            Set<String> values = new HashSet<String>();
            for (int i = 0; i < row; ++i) {
                values.add(dataSet.get(i, col));
            }
            // Expected entropy after splitting on this feature, weighted by subset size.
            double newEntropy = 0.0;
            for (String value : values) {
                Matrix_2D<String> subDataSet = splitDataSet(dataSet, col, value);
                double weight = subDataSet.getRowDimension() / (double) row;
                newEntropy += weight * calShannonEntropy(subDataSet);
            }
            double infoGain = baseEntropy - newEntropy;
            if (infoGain > bestInfoGain) {
                bestInfoGain = infoGain;
                bestFeature = col;
            }
        }
        return bestFeature;
    }

    /**
     * Returns the most frequent label; ties go to whichever label the HashMap iterates first.
     *
     * @param labels class labels
     * @return the majority label, or "" when {@code labels} is empty
     */
    private String majorityClassificationCount(String[] labels) {
        Map<String, Integer> labelCount = new HashMap<String, Integer>();
        for (String s : labels) {
            Integer seen = labelCount.get(s);
            labelCount.put(s, seen == null ? 1 : seen + 1);
        }
        int count = -1;
        String t = "";
        for (Map.Entry<String, Integer> entry : labelCount.entrySet()) {
            if (entry.getValue() > count) {
                count = entry.getValue();
                t = entry.getKey();
            }
        }
        return t;
    }

    /**
     * Recursively builds an ID3 decision tree.
     *
     * <p>Internal nodes hold a feature name and map each observed value of that feature to a
     * subtree; leaves hold a class label and have {@code null} children.
     *
     * @param dataSet non-empty dataset whose last column holds the class labels
     * @param features names of the feature columns, in column order
     * @return the root of the tree
     * @throws IllegalArgumentException if {@code dataSet} has no rows
     */
    public TreeNode creaDecTree(Matrix_2D<String> dataSet, String[] features) {
        int row = dataSet.getRowDimension(), col = dataSet.getColDimension();
        if (row == 0) {
            throw new IllegalArgumentException("cannot build a decision tree from an empty dataset");
        }
        String[] labelList = new String[row]; // class label of each row
        for (int i = 0; i < row; ++i) {
            labelList[i] = dataSet.get(i, col - 1);
        }
        // Stop if every row has the same label.
        boolean pure = true;
        for (String label : labelList) {
            if (!label.equals(labelList[0])) {
                pure = false;
                break;
            }
        }
        if (pure) {
            // Debug output: number of feature names left at this leaf.
            System.out.println("一类..." + features.length);
            return new TreeNode(labelList[0], null);
        }
        // Stop if only the label column is left: use a majority vote.
        if (col == 1) {
            System.out.println("一列" + features.length);
            return new TreeNode(majorityClassificationCount(labelList), null);
        }
        int bestFeature = chooseBestFeatureToSplit(dataSet);
        if (bestFeature < 0) {
            // No remaining feature separates the classes (e.g. identical feature vectors with
            // different labels); previously this indexed features[-1] and crashed.
            return new TreeNode(majorityClassificationCount(labelList), null);
        }
        Set<String> uniqFeatureVals = new HashSet<String>();
        for (int i = 0; i < row; ++i) {
            uniqFeatureVals.add(dataSet.get(i, bestFeature));
        }
        if (uniqFeatureVals.size() < 2) {
            // Defensive: a single-valued feature cannot split the data.
            return new TreeNode(majorityClassificationCount(labelList), null);
        }
        String bestFeatureLabel = features[bestFeature];
        String[] subFeatures = subArray(features, bestFeatureLabel);
        Map<String, TreeNode> children = new HashMap<String, TreeNode>();
        for (String value : uniqFeatureVals) {
            children.put(value, creaDecTree(splitDataSet(dataSet, bestFeature, value), subFeatures));
        }
        return new TreeNode(bestFeatureLabel, children);
    }

    /**
     * Returns a copy of {@code original} without {@code str}.
     *
     * @param original source array; expected to contain {@code str} exactly once
     * @param str element to drop
     * @return a new array of length {@code original.length - 1}
     */
    private String[] subArray(String[] original, String str) {
        String[] subArray = new String[original.length - 1];
        int k = 0;
        for (String s : original) {
            if (!s.equals(str)) {
                subArray[k++] = s;
            }
        }
        return subArray;
    }

    /**
     * Classifies one sample by walking the tree from the root to a leaf.
     *
     * @param tree root of a tree built by {@link #creaDecTree}
     * @param features the original (full) feature-name array used to build the tree
     * @param sample one data row; column i holds the value of {@code features[i]}
     * @return the leaf's class label, or "no such classification" if the sample reaches a
     *     feature value with no branch or lacks a value for a feature used by the tree
     */
    public String classification(TreeNode tree, String[] features, ArrayList<String> sample) {
        TreeNode node = tree;
        while (node != null && node.getChildren() != null) {
            int featureIndex = getIndex(features, node.getElement());
            if (sample == null || featureIndex < 0 || featureIndex >= sample.size()) {
                // The sample has no value for the feature this node splits on.
                return NO_CLASSIFICATION;
            }
            // A value never seen during training yields null here and ends the walk.
            node = node.getChildren().get(sample.get(featureIndex));
        }
        if (node == null) {
            return NO_CLASSIFICATION;
        }
        return node.getElement();
    }

    /**
     * Returns the index of {@code s} in {@code arr}, or -1 if absent or {@code arr} is null.
     */
    private int getIndex(String[] arr, String s) {
        if (arr == null) {
            return -1;
        }
        for (int i = 0; i < arr.length; ++i) {
            if (arr[i] != null && arr[i].equals(s)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Trains a tree on a data file and reports its accuracy on that same file.
     *
     * @param args optional data-file path; defaults to "AutismAdultDataPlus.txt"
     *     (another sample file is "StudentAcademicsPerformance.txt")
     * @throws IOException if the data file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String path = args.length > 0 ? args[0] : "AutismAdultDataPlus.txt";
        ID3Tool dthTool = new ID3Tool();
        Matrix_2D<String> trainingSet = dthTool.readData(path);
        final int row = trainingSet.getRowDimension(), col = trainingSet.getColDimension();
        if (row == 0 || col < 2) {
            System.out.println("Data file needs at least one row with a feature and a label column: " + path);
            return;
        }
        String[] features = new String[col - 1];
        for (int i = 0; i < features.length; ++i) {
            features[i] = "特征" + i;
        }
        TreeNode tree = dthTool.creaDecTree(trainingSet, features);
        int num = 0;
        for (int i = 0; i < row; ++i) {
            ArrayList<String> sample = trainingSet.get(i);
            String actual = sample.get(col - 1);
            String predicted = dthTool.classification(tree, features, sample);
            if (predicted.equals(actual)) {
                ++num;
            }
            System.out.println(i + "\t" + actual + "\t" + predicted);
        }
        // Accuracy on the training data itself (no held-out test set).
        System.out.println("分类精度:" + (num / (double) row));
    }
}
工具类:TreeNode.java
package tool;
import java.util.Map;
/**
 * A node of a decision tree.
 *
 * <p>As built by ID3Tool, an internal node stores the name of the feature it splits on and maps
 * each value of that feature to a child subtree, while a leaf stores a class label and has
 * {@code null} children.
 */
public class TreeNode {
    /** Feature name (internal node) or class label (leaf). */
    private final String element;

    /** Child subtrees keyed by feature value; {@code null} for a leaf. */
    private final Map<String, TreeNode> children;

    /** Creates a node with neither an element nor children. */
    public TreeNode() {
        this(null, null);
    }

    /**
     * Creates a node.
     *
     * @param e the feature name or class label stored in this node
     * @param c the children map (stored as given, not copied), or {@code null} for a leaf
     */
    public TreeNode(String e, Map<String, TreeNode> c) {
        this.element = e;
        this.children = c;
    }

    /** Returns the children map exactly as passed to the constructor (may be {@code null}). */
    public Map<String, TreeNode> getChildren() {
        return this.children;
    }

    /** Returns the stored feature name or class label (may be {@code null}). */
    public String getElement() {
        return this.element;
    }
}
工具类:Matrix_2D.java
package tool;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
 * A simple row-major table backed by nested {@link ArrayList}s.
 *
 * <p>Rows are copied when added, but {@link #get(int)} and {@link #remove(int)} hand out the
 * stored row lists themselves. The column count is taken from the first row only.
 *
 * @param <T> cell type
 */
public class Matrix_2D<T> {
    /** Row storage: each inner list is one row. */
    ArrayList<ArrayList<T>> data;

    /** Creates an empty matrix. */
    public Matrix_2D() {
        data = new ArrayList<ArrayList<T>>();
    }

    /**
     * Creates a matrix holding a copy of every row in {@code d}.
     *
     * @param d rows to copy; must not be null
     */
    public Matrix_2D(ArrayList<ArrayList<T>> d) {
        this();
        for (ArrayList<T> row : d) {
            this.putLine(row);
        }
    }

    /** Appends a copy of {@code line} as a new last row. */
    public void putLine(ArrayList<T> line) {
        data.add(new ArrayList<T>(line));
    }

    /** Returns the number of rows. */
    public int getRowDimension() {
        return data.size();
    }

    /** Returns the length of the first row, or 0 when there are no rows. */
    public int getColDimension() {
        return getRowDimension() == 0 ? 0 : data.get(0).size();
    }

    /** Returns row {@code i} itself (not a copy); modifying it modifies the matrix. */
    public ArrayList<T> get(int i) {
        return data.get(i);
    }

    /** Returns the cell at row {@code i}, column {@code j}. */
    public T get(int i, int j) {
        ArrayList<T> row = data.get(i);
        return row.get(j);
    }

    /** Removes and returns the cell at row {@code i}, column {@code j}, shortening that row. */
    public T remove(int i, int j) {
        ArrayList<T> row = data.get(i);
        return row.remove(j);
    }

    /** Removes and returns the row at {@code index}. */
    public ArrayList<T> remove(int index) {
        return data.remove(index);
    }

    /**
     * Returns a copy of {@code original} without the elements equal to {@code str}.
     *
     * <p>The result always has length {@code original.length - 1}, so {@code str} is expected to
     * occur exactly once; otherwise an exception is thrown or trailing slots stay {@code null}.
     */
    public static String[] subArray(String[] original, String str) {
        String[] result = new String[original.length - 1];
        int next = 0;
        for (int i = 0; i < original.length; i++) {
            String candidate = original[i];
            if (!candidate.equals(str)) {
                result[next] = candidate;
                next++;
            }
        }
        return result;
    }

    /** Returns a new, independent list with the same elements as {@code data}. */
    public static ArrayList<String> copyArrayList(ArrayList<String> data) {
        return new ArrayList<String>(data);
    }

    /**
     * Returns the most frequent label; ties go to whichever label the HashMap iterates first.
     *
     * @return the majority label, or "" when {@code labels} is empty
     */
    public static String majority(ArrayList<String> labels) {
        Map<String, Integer> tally = new HashMap<String, Integer>();
        for (String label : labels) {
            Integer seen = tally.get(label);
            tally.put(label, seen == null ? 1 : seen + 1);
        }
        String winner = "";
        int best = -1;
        for (Map.Entry<String, Integer> entry : tally.entrySet()) {
            if (entry.getValue() > best) {
                best = entry.getValue();
                winner = entry.getKey();
            }
        }
        return winner;
    }
}
工具类:ReadData.java
package tool;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
/** Loads comma-separated data files for the decision-tree tools. */
public class ReadData {
    /**
     * Reads a comma-separated data file into a list of rows and shuffles the rows randomly.
     *
     * <p>Each non-blank line becomes one row whose fields are produced by
     * {@code line.split(",")}, so trailing empty fields are dropped. Blank lines are skipped so
     * they do not turn into bogus single-column rows. The file is decoded with the platform
     * default charset.
     *
     * @param path path of the data file
     * @return the rows of the file in random order (never {@code null})
     * @throws FileNotFoundException if {@code path} does not name an existing regular file
     * @throws IOException if reading the file fails
     */
    public static ArrayList<ArrayList<String>> readDataFile(String path) throws IOException {
        File file = new File(path);
        if (!file.exists() || !file.isFile()) {
            // Fail fast with the resolved path; returning null only caused a later NullPointerException.
            throw new FileNotFoundException("Data file not found: " + file.getAbsolutePath());
        }
        ArrayList<ArrayList<String>> trainingSet = new ArrayList<ArrayList<String>>();
        BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(file)));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().length() == 0) {
                    continue;
                }
                String[] tokens = line.split(",");
                ArrayList<String> row = new ArrayList<String>(tokens.length);
                for (String token : tokens) {
                    row.add(token);
                }
                trainingSet.add(row);
            }
        } finally {
            // Close even if readLine throws, so the file handle is not leaked.
            reader.close();
        }
        // Shuffle the dataset: uniform and O(n), unlike the previous remove/append loop (O(n^2)).
        Collections.shuffle(trainingSet);
        return trainingSet;
    }
}
参考:
决策树分类器-Java实现
决策树分类算法:C4.5算法