决策树分类算法:ID3算法

【ID3算法在每个节点上选择信息增益最大的特征项Ai进行划分,并对各划分子集递归地建立决策树】

ID3算法实现:ID3Tool.java

package id3;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import tool.TreeNode;
import tool.Matrix_2D;
import tool.ReadData;

/**
 * ID3 decision-tree classifier.
 *
 * <p>A data set is a {@code Matrix_2D<String>} in which every row is one sample, the last
 * column is the class label and all other columns are categorical feature values. The tree
 * is grown by repeatedly splitting on the feature with the largest information gain.
 */
public class ID3Tool {

    /** Result of {@link #classification} when the tree has no branch matching the sample. */
    private static final String NO_SUCH_CLASSIFICATION = "no such classification";

    /** Data file read by {@link #main} when no path is passed on the command line. */
    private static final String DEFAULT_DATA_FILE = "AutismAdultDataPlus.txt";

    /**
     * Loads a comma-separated data file into a matrix.
     *
     * @param path location of the data file
     * @return the rows of the file, as returned by {@code ReadData.readDataFile}
     * @throws IOException if the file cannot be read; this includes the case where
     *         {@code ReadData.readDataFile} reports a missing file by returning {@code null}
     */
    public Matrix_2D<String> readData(String path) throws IOException {
        ArrayList<ArrayList<String>> rows = ReadData.readDataFile(path);
        if (rows == null) {
            // Fail with a descriptive error instead of a NullPointerException further down.
            throw new IOException("Cannot read data file: " + path);
        }
        return new Matrix_2D<String>(rows);
    }

    /**
     * Computes the Shannon entropy (base 2) of the class labels stored in the last column.
     *
     * @param ds data set whose last column is the class label
     * @return the entropy in bits; {@code 0} for an empty data set
     */
    public static double calShannonEntropy(Matrix_2D<String> ds) {
        int rows = ds.getRowDimension();
        int labelCol = ds.getColDimension() - 1;
        // Count how often each class label occurs.
        Map<String, Integer> labelCounts = new HashMap<String, Integer>();
        for (int i = 0; i < rows; i++) {
            String label = ds.get(i, labelCol);
            Integer seen = labelCounts.get(label);
            labelCounts.put(label, seen == null ? 1 : seen + 1);
        }
        double shannonEnt = 0;
        for (int count : labelCounts.values()) {
            // Divide in double precision; float rounding noise would otherwise leak into
            // the information-gain comparisons.
            double rate = count / (double) rows;
            shannonEnt -= rate * Math.log(rate) / Math.log(2);
        }
        return shannonEnt;
    }

    /**
     * Selects the rows whose value in column {@code axis} equals {@code value} and drops
     * that column from each selected row.
     */
    private Matrix_2D<String> splitDataSet(Matrix_2D<String> dataSet, int axis, String value) {
        Matrix_2D<String> retDataSet = new Matrix_2D<String>();
        int rows = dataSet.getRowDimension();
        int cols = dataSet.getColDimension();
        for (int i = 0; i < rows; ++i) {
            if (dataSet.get(i, axis).equals(value)) {
                ArrayList<String> reduced = new ArrayList<String>();
                for (int j = 0; j < cols; ++j) {
                    reduced.add(dataSet.get(i, j));
                }
                reduced.remove(axis); // int argument: removes by index, not by value
                retDataSet.putLine(reduced);
            }
        }
        return retDataSet;
    }

    /**
     * Returns the index of the feature column with the largest information gain.
     *
     * <p>A feature is always selected when the data set has at least one feature column,
     * even if no split reduces the entropy; ties keep the lowest column index. Returns
     * {@code -1} only when there is no feature column.
     */
    private int chooseBestFeatureToSplit(Matrix_2D<String> dataSet) {
        int featureNum = dataSet.getColDimension() - 1;
        int row = dataSet.getRowDimension();
        double baseShannonEntropy = calShannonEntropy(dataSet);
        // Start below every possible gain so that a zero (or marginally negative, due to
        // rounding) gain still yields a valid column instead of -1.
        double bestInfoGain = Double.NEGATIVE_INFINITY;
        int bestFeature = -1;
        for (int col = 0; col < featureNum; ++col) {
            // Distinct values taken by this feature.
            Set<String> values = new HashSet<String>();
            for (int i = 0; i < row; ++i) {
                values.add(dataSet.get(i, col));
            }
            // Entropy remaining after the split, weighted by subset size.
            double newShannonEntropy = 0.0;
            for (String value : values) {
                Matrix_2D<String> subDataSet = splitDataSet(dataSet, col, value);
                double weight = subDataSet.getRowDimension() / (double) row;
                newShannonEntropy += weight * calShannonEntropy(subDataSet);
            }
            double infoGain = baseShannonEntropy - newShannonEntropy;
            if (infoGain > bestInfoGain) {
                bestInfoGain = infoGain;
                bestFeature = col;
            }
        }
        return bestFeature;
    }

    /** Returns the most frequent label, or the empty string when {@code labels} is empty. */
    private String majorityClassificationCount(String[] labels) {
        Map<String, Integer> labelCount = new HashMap<String, Integer>();
        for (String s : labels) {
            Integer seen = labelCount.get(s);
            labelCount.put(s, seen == null ? 1 : seen + 1);
        }
        int count = -1;
        String winner = "";
        for (Map.Entry<String, Integer> entry : labelCount.entrySet()) {
            if (entry.getValue() > count) {
                count = entry.getValue();
                winner = entry.getKey();
            }
        }
        return winner;
    }

    /**
     * Recursively builds the decision tree for {@code dataSet}.
     *
     * <p>Recursion stops with a leaf when all rows share one class label, or when only the
     * label column is left (the leaf then carries the majority label). A line is printed
     * to standard output each time a leaf is created.
     *
     * @param dataSet  training rows; the last column is the class label
     * @param features names of the feature columns, in column order
     * @return the root of the (sub)tree; inner nodes carry a feature name and map each
     *         feature value to a child, leaves carry a class label and have no children
     * @throws IllegalArgumentException if {@code dataSet} has no rows
     */
    public TreeNode creaDecTree(Matrix_2D<String> dataSet, String[] features) {
        int row = dataSet.getRowDimension();
        int col = dataSet.getColDimension();
        if (row == 0) {
            throw new IllegalArgumentException("dataSet must contain at least one row");
        }
        String[] labelList = new String[row]; // class label of every row
        for (int i = 0; i < row; ++i) {
            labelList[i] = dataSet.get(i, col - 1);
        }
        int num = 0;
        for (String str : labelList) {
            if (str.equals(labelList[0])) {
                ++num;
            }
        }
        if (num == labelList.length) {
            System.out.println("一类..." + features.length);
            return new TreeNode(labelList[0], null);
        }
        if (col == 1) {
            System.out.println("一列" + features.length);
            return new TreeNode(majorityClassificationCount(labelList), null);
        }
        // col >= 2 here, so a valid feature index is always returned.
        int bestFeature = chooseBestFeatureToSplit(dataSet);
        String bestFeatureLabel = features[bestFeature];
        // Drop by position so the names stay aligned with the remaining columns.
        String[] subFeatures = removeAt(features, bestFeature);
        Set<String> uniqFeatureVals = new HashSet<String>();
        for (int i = 0; i < row; ++i) {
            uniqFeatureVals.add(dataSet.get(i, bestFeature));
        }
        Map<String, TreeNode> children = new HashMap<String, TreeNode>();
        for (String value : uniqFeatureVals) {
            // Every value was observed in dataSet, so each subset is non-empty.
            children.put(value, creaDecTree(splitDataSet(dataSet, bestFeature, value), subFeatures));
        }
        return new TreeNode(bestFeatureLabel, children);
    }

    /** Returns a copy of {@code original} without the element at {@code index}. */
    private String[] removeAt(String[] original, int index) {
        String[] remaining = new String[original.length - 1];
        System.arraycopy(original, 0, remaining, 0, index);
        System.arraycopy(original, index + 1, remaining, index, original.length - index - 1);
        return remaining;
    }

    /**
     * Classifies one sample by walking the tree from the root to a leaf.
     *
     * @param tree     root of the decision tree
     * @param features feature names of the full data set, in column order
     * @param sample   feature values of the sample, in the same column order
     * @return the class label stored in the leaf that was reached, or
     *         {@code "no such classification"} when the walk cannot continue (unknown
     *         feature name, value without a branch, or sample too short)
     */
    public String classification(TreeNode tree, String[] features, ArrayList<String> sample) {
        TreeNode node = tree;
        while (node != null && node.getChildren() != null) {
            if (features == null || sample == null) {
                return NO_SUCH_CLASSIFICATION;
            }
            int index = getIndex(features, node.getElement());
            if (index < 0 || index >= sample.size()) {
                return NO_SUCH_CLASSIFICATION;
            }
            // A value never seen during training has no branch and yields null.
            node = node.getChildren().get(sample.get(index));
        }
        if (node == null) {
            return NO_SUCH_CLASSIFICATION;
        }
        return node.getElement();
    }

    /** Returns the position of {@code s} in {@code arr}, or {@code -1} when absent. */
    private int getIndex(String[] arr, String s) {
        for (int i = 0; i < arr.length; ++i) {
            if (s != null && s.equals(arr[i])) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Trains a tree on a data file, classifies every training row with it and prints the
     * per-row result followed by the training accuracy.
     *
     * @param args optional; {@code args[0]} is the data file path, defaulting to
     *             {@code AutismAdultDataPlus.txt}
     * @throws IOException if the data file cannot be read
     */
    public static void main(String[] args) throws IOException {
        String path = (args != null && args.length > 0) ? args[0] : DEFAULT_DATA_FILE;
        ID3Tool dthTool = new ID3Tool();
        Matrix_2D<String> trainingSet = dthTool.readData(path);
        // Synthetic feature names, one per non-label column.
        String[] features = new String[trainingSet.getColDimension() - 1];
        for (int i = 0; i < features.length; ++i) {
            features[i] = "特征" + String.valueOf(i);
        }
        TreeNode tree = dthTool.creaDecTree(trainingSet, features);
        int num = 0;
        final int row = trainingSet.getRowDimension();
        final int col = trainingSet.getColDimension();
        for (int i = 0; i < row; ++i) {
            String expected = trainingSet.get(i).get(col - 1);
            String predicted = dthTool.classification(tree, features, trainingSet.get(i));
            if (predicted.equals(expected)) {
                ++num;
            }
            System.out.println(i + "\t" + expected + "\t" + predicted);
        }
        System.out.println("分类精度:" + (num / (double) row));
    }
}

工具类:TreeNode.java

package tool;

import java.util.Map;

/**
 * Node of a decision tree: a text payload plus an optional map from branch key to child
 * node. A node whose child map is {@code null} has no branches.
 */
public class TreeNode {
	private final String element;
	private final Map<String, TreeNode> children;

	/** Creates a node with neither payload nor children (both {@code null}). */
	public TreeNode() {
		this(null, null);
	}

	/**
	 * Creates a node.
	 *
	 * @param e payload of the node
	 * @param c children keyed by branch value; stored by reference, may be {@code null}
	 */
	public TreeNode(String e, Map<String, TreeNode> c) {
		this.element = e;
		this.children = c;
	}

	/** Returns the child map exactly as it was passed to the constructor. */
	public Map<String, TreeNode> getChildren() {
		return this.children;
	}

	/** Returns the payload of this node. */
	public String getElement() {
		return this.element;
	}
}

工具类:Matrix_2D.java

package tool;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * Growable two-dimensional table stored as a list of row lists, plus a few static
 * helpers for string arrays and lists.
 *
 * @param <T> cell type
 */
public class Matrix_2D<T> {
	ArrayList<ArrayList<T>> data;

	/** Creates a table without rows. */
	public Matrix_2D() {
		data = new ArrayList<ArrayList<T>>();
	}

	/**
	 * Creates a table holding a copy of every row of {@code d}; later changes to the
	 * source lists do not affect this table.
	 */
	public Matrix_2D(ArrayList<ArrayList<T>> d) {
		this();
		for (ArrayList<T> row : d) {
			putLine(row);
		}
	}

	/** Appends a copy of {@code line} as the new last row. */
	public void putLine(ArrayList<T> line) {
		data.add(new ArrayList<T>(line));
	}

	/** Returns the number of rows. */
	public int getRowDimension() {
		return data.size();
	}

	/** Returns the length of the first row, or {@code 0} when there are no rows. */
	public int getColDimension() {
		return getRowDimension() > 0 ? data.get(0).size() : 0;
	}

	/** Returns row {@code i} itself (not a copy). */
	public ArrayList<T> get(int i) {
		return data.get(i);
	}

	/** Returns the cell at row {@code i}, column {@code j}. */
	public T get(int i, int j) {
		ArrayList<T> row = data.get(i);
		return row.get(j);
	}

	/** Removes the cell at row {@code i}, column {@code j} from that row and returns it. */
	public T remove(int i, int j) {
		ArrayList<T> row = data.get(i);
		return row.remove(j);
	}

	/** Removes row {@code index} and returns it. */
	public ArrayList<T> remove(int index) {
		return data.remove(index);
	}

	/**
	 * Copies {@code original} into an array one element shorter, leaving out the
	 * entries equal to {@code str}.
	 */
	public static String[] subArray(String[] original, String str) {
		String[] kept = new String[original.length - 1];
		int next = 0;
		for (int i = 0; i < original.length; i++) {
			if (original[i].equals(str)) {
				continue;
			}
			kept[next++] = original[i];
		}
		return kept;
	}

	/** Returns a new list with the same elements as {@code data}, in the same order. */
	public static ArrayList<String> copyArrayList(ArrayList<String> data) {
		return new ArrayList<String>(data);
	}

	/** Returns the most frequent entry of {@code labels}, or {@code ""} when it is empty. */
	public static String majority(ArrayList<String> labels) {
		Map<String, Integer> tally = new HashMap<String, Integer>();
		for (String label : labels) {
			Integer seen = tally.get(label);
			tally.put(label, seen == null ? 1 : seen + 1);
		}
		int best = -1;
		String winner = "";
		for (Map.Entry<String, Integer> entry : tally.entrySet()) {
			// Strict comparison: on a tie the entry met first in map order wins.
			if (entry.getValue() > best) {
				best = entry.getValue();
				winner = entry.getKey();
			}
		}
		return winner;
	}
}

工具类:ReadData.java

package tool;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;

/** Loader for comma-separated data files. */
public class ReadData {
	/**
	 * Reads a comma-separated text file into a list of rows and shuffles the rows.
	 *
	 * <p>Each line is split on {@code ","} and becomes one row. The file is decoded with
	 * the platform default charset.
	 *
	 * @param path location of the data file
	 * @return the rows in random order, or {@code null} when {@code path} is not an
	 *         existing regular file (the absolute path is then printed to standard output)
	 * @throws IOException if reading the file fails
	 */
	public static ArrayList<ArrayList<String>> readDataFile(String path) throws IOException {
		File file = new File(path);
		if (!file.exists() || !file.isFile()) {
			System.out.println(file.getAbsolutePath());
			return null;
		}
		ArrayList<ArrayList<String>> trainingSet = new ArrayList<ArrayList<String>>();
		BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(file)));
		try {
			String line;
			while ((line = reader.readLine()) != null) {
				String[] tokens = line.split(",");
				ArrayList<String> row = new ArrayList<String>(tokens.length);
				for (String token : tokens) {
					row.add(token);
				}
				trainingSet.add(row);
			}
		} finally {
			// Release the file handle even when readLine fails part-way through.
			reader.close();
		}
		// Shuffle: on pass i, move a random one of the first (size - i) rows to the end,
		// so rows that were already moved are not picked again.
		for (int i = 0; i < trainingSet.size(); ++i) {
			int pick = (int) ((trainingSet.size() - i) * Math.random());
			trainingSet.add(trainingSet.remove(pick));
		}
		return trainingSet;
	}
}

参考:
决策树分类器-Java实现
决策树分类算法:C4.5算法

发布了13 篇原创文章 · 获赞 0 · 访问量 1412

猜你喜欢

转载自blog.csdn.net/qq_34262612/article/details/104104640