ID3是经典的分类算法。要理解ID3算法,需要先了解一些基本的信息论概念,包括信息量、熵、后验熵和条件熵。ID3算法的核心思想是在每个节点选择信息增益(互信息量)最大的属性作为分割属性;这种贪心策略倾向于生成较矮的决策树,但并不能保证所建立的决策树高度全局最小。(注:下文代码注释中标注为 C4.5,但实现只使用了信息增益而非增益率,实际上是 ID3。)
树结构代码:
/**
 * A node of the decision tree.
 *
 * <p>An internal node's {@code nodeName} is the attribute it splits on and
 * {@code splitAttributes} holds the values of that attribute, one child per
 * value. A leaf's {@code nodeName} is a class label and {@code splitAttributes}
 * stays {@code null}.
 *
 * @author zhenhua.chen
 * @date 2013-3-1 上午10:47:37
 */
public class TreeNode {

    private String nodeName;                      // split attribute name, or class label for a leaf
    private List<String> splitAttributes;         // values of the split attribute (null for leaves)
    private ArrayList<TreeNode> childrenNodes;    // children, parallel to splitAttributes
    private ArrayList<ArrayList<String>> dataSet; // rows routed to this node
    private ArrayList<String> arrributeSet;       // attribute names for the columns of dataSet

    /** Creates a node with an empty (never-null) child list. */
    public TreeNode() {
        childrenNodes = new ArrayList<TreeNode>();
    }

    public String getNodeName() {
        return nodeName;
    }

    public void setNodeName(String nodeName) {
        this.nodeName = nodeName;
    }

    public List<String> getSplitAttributes() {
        return splitAttributes;
    }

    public void setSplitAttributes(List<String> splitAttributes) {
        this.splitAttributes = splitAttributes;
    }

    public ArrayList<TreeNode> getChildrenNodes() {
        return childrenNodes;
    }

    public void setChildrenNodes(ArrayList<TreeNode> childrenNodes) {
        this.childrenNodes = childrenNodes;
    }

    public ArrayList<ArrayList<String>> getDataSet() {
        return dataSet;
    }

    public void setDataSet(ArrayList<ArrayList<String>> dataSet) {
        this.dataSet = dataSet;
    }

    public ArrayList<String> getArrributeSet() {
        return arrributeSet;
    }

    public void setArrributeSet(ArrayList<String> arrributeSet) {
        this.arrributeSet = arrributeSet;
    }
}
决策树算法:
/**
 * Builds and traverses an ID3 decision tree over a tabular String data set.
 * Each row of the data set is a list of attribute values; the LAST column is
 * the target (class) attribute.
 *
 * @author zhenhua.chen
 * @date 2013-3-1 下午4:42:07
 */
public class DecisionTree {

    /**
     * Recursively builds the decision tree for the given data set.
     *
     * @param dataSet      rows of attribute values; last column is the target class
     * @param attributeSet attribute names, parallel to the columns of dataSet
     * @return root of the (sub)tree built from this data set
     */
    public TreeNode buildTree(ArrayList<ArrayList<String>> dataSet, ArrayList<String> attributeSet) {
        TreeNode node = new TreeNode();
        node.setDataSet(dataSet);
        node.setArrributeSet(attributeSet);

        int desColumn = attributeSet.size() - 1; // target class column
        // Hoisted out of the loop below: the target entropy does not depend on i.
        double targetEntropy = ComputeUtil.computeEntropy(dataSet, desColumn);

        // Choose the attribute with the largest information gain.
        int index = -1;
        double maxGain = 0;
        for (int i = 0; i < attributeSet.size() - 1; i++) {
            double gain = targetEntropy - ComputeUtil.computeConditinalEntropy(dataSet, i);
            if (gain > maxGain) {
                index = i;
                maxGain = gain;
            }
        }

        // Bug fix: the original indexed attributeSet with -1 (IndexOutOfBoundsException)
        // when no attribute had positive gain, or when recursion had consumed every
        // non-target attribute. Emit a leaf labeled with the majority class instead.
        if (index == -1) {
            node.setNodeName(majorityClass(dataSet, desColumn));
            return node;
        }

        // Values taken by the chosen attribute; one child per value.
        ArrayList<String> splitAttributes = ComputeUtil.getTypes(dataSet, index);
        node.setSplitAttributes(splitAttributes);
        node.setNodeName(attributeSet.get(index));

        for (int i = 0; i < splitAttributes.size(); i++) {
            ArrayList<ArrayList<String>> splitDataSet =
                    ComputeUtil.getDataSet(dataSet, index, splitAttributes.get(i));
            // Target column within the subset (same width as the parent rows here).
            int childDesColumn = splitDataSet.get(0).size() - 1;
            ArrayList<String> desAttributes = ComputeUtil.getTypes(splitDataSet, childDesColumn);
            TreeNode childNode;
            if (desAttributes.size() == 1) {
                // Subset is pure: make a leaf carrying the single class label.
                childNode = new TreeNode();
                childNode.setNodeName(desAttributes.get(0));
            } else {
                // Drop the chosen attribute (name and column) and recurse.
                ArrayList<String> newAttributeSet = new ArrayList<String>();
                for (String s : attributeSet) {
                    if (!s.equals(attributeSet.get(index))) {
                        newAttributeSet.add(s);
                    }
                }
                ArrayList<ArrayList<String>> newDataSet = new ArrayList<ArrayList<String>>();
                for (ArrayList<String> data : splitDataSet) {
                    ArrayList<String> tmp = new ArrayList<String>();
                    for (int j = 0; j < data.size(); j++) {
                        if (j != index) {
                            tmp.add(data.get(j));
                        }
                    }
                    newDataSet.add(tmp);
                }
                childNode = buildTree(newDataSet, newAttributeSet);
            }
            node.getChildrenNodes().add(childNode);
        }
        return node;
    }

    /** Returns the most frequent value in the given column of the data set. */
    private static String majorityClass(ArrayList<ArrayList<String>> dataSet, int column) {
        Map<String, Integer> counts = ComputeUtil.getTypeCounts(dataSet, column);
        String best = null;
        int bestCount = -1;
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    /**
     * Prints the tree depth-first, one separator line per visited node.
     *
     * @param root subtree to print
     */
    public void printTree(TreeNode root) {
        System.out.println("----------------");
        if (null != root.getSplitAttributes()) {
            System.out.print("分裂节点:" + root.getNodeName());
            for (String attr : root.getSplitAttributes()) {
                System.out.print("(" + attr + ") ");
            }
        } else {
            // Bug fix: the original printed "分裂节点" here too; a node without
            // split attributes is a leaf (consistent with searchTree below).
            System.out.print("叶子节点:" + root.getNodeName());
        }
        if (null != root.getChildrenNodes()) {
            for (TreeNode node : root.getChildrenNodes()) {
                printTree(node);
            }
        }
    }

    /**
     * Traverses the tree breadth-first (level order), printing split nodes with
     * their split values and leaves with their class labels.
     *
     * @param root tree to traverse
     */
    public void searchTree(TreeNode root) {
        Queue<TreeNode> queue = new LinkedList<TreeNode>();
        queue.offer(root);
        while (queue.size() != 0) {
            TreeNode node = queue.poll();
            if (null != node.getSplitAttributes()) {
                System.out.print("分裂节点:" + node.getNodeName() + "; ");
                for (String attr : node.getSplitAttributes()) {
                    System.out.print(" (" + attr + ") ");
                }
            } else {
                System.out.print("叶子节点:" + node.getNodeName() + "; ");
            }
            if (null != node.getChildrenNodes()) {
                for (TreeNode nod : node.getChildrenNodes()) {
                    queue.offer(nod);
                }
            }
        }
    }
}
一些util代码:
/**
 * Computation helpers for the decision-tree builder: distinct-value extraction,
 * value counting, dataset splitting, entropy and conditional entropy.
 *
 * <p>Note: the original header said "C4.5", but only information gain (not the
 * gain ratio) is ever computed here, i.e. this supports ID3.
 *
 * @author zhenhua.chen
 * @date 2013-3-1 上午10:48:47
 */
public class ComputeUtil {

    private static final double LN2 = Math.log(2);

    /** Base-2 logarithm (the original repeated Math.log(x)/Math.log(2) inline). */
    private static double log2(double x) {
        return Math.log(x) / LN2;
    }

    /**
     * Returns the distinct values of the given column, in first-occurrence order.
     *
     * @param dataSet     rows of attribute values
     * @param columnIndex column to scan
     * @return distinct values, ordered by first appearance
     */
    public static ArrayList<String> getTypes(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        ArrayList<String> list = new ArrayList<String>();
        for (ArrayList<String> data : dataSet) {
            String value = data.get(columnIndex);
            if (!list.contains(value)) {
                list.add(value);
            }
        }
        return list;
    }

    /**
     * Counts how often each distinct value occurs in the given column.
     *
     * @param dataSet     rows of attribute values
     * @param columnIndex column to count over
     * @return map from value to its occurrence count
     */
    public static Map<String, Integer> getTypeCounts(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        Map<String, Integer> map = new HashMap<String, Integer>();
        for (ArrayList<String> data : dataSet) {
            String key = data.get(columnIndex);
            // Single lookup instead of the original containsKey + get pair.
            Integer count = map.get(key);
            map.put(key, count == null ? 1 : count + 1);
        }
        return map;
    }

    /**
     * Returns the subset of rows whose value in the given column equals the
     * given class (i.e. the data subset after splitting on that value).
     *
     * @param dataSet        rows of attribute values
     * @param columnIndex    column to filter on
     * @param attributeClass value the column must equal
     * @return matching rows (shared, not copied)
     */
    public static ArrayList<ArrayList<String>> getDataSet(ArrayList<ArrayList<String>> dataSet,
            int columnIndex, String attributeClass) {
        ArrayList<ArrayList<String>> splitDataSet = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> data : dataSet) {
            if (data.get(columnIndex).equals(attributeClass)) {
                splitDataSet.add(data);
            }
        }
        return splitDataSet;
    }

    /**
     * Computes the Shannon entropy (base 2) of the given column's value
     * distribution. Returns 0 for an empty data set.
     *
     * @param dataSet     rows of attribute values
     * @param columnIndex column whose distribution is measured
     * @return entropy in bits
     */
    public static double computeEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        Map<String, Integer> counts = getTypeCounts(dataSet, columnIndex);
        double size = dataSet.size();
        double entropy = 0;
        // entrySet iteration: the original iterated keySet and re-looked-up each
        // count, with a redundant (String) cast on an already-typed iterator.
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            double prob = e.getValue() / size;
            entropy -= prob * log2(prob);
        }
        return entropy;
    }

    /**
     * Computes the conditional entropy of the target attribute (last column)
     * given the attribute in the specified column: sum over each value v of
     * P(v) * H(target | column = v).
     *
     * @param dataSet     rows of attribute values; last column is the target
     * @param columnIndex conditioning attribute column
     * @return conditional entropy in bits
     */
    public static double computeConditionalEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        Map<String, Integer> counts = getTypeCounts(dataSet, columnIndex);
        double conditionalEntropy = 0;
        for (String type : counts.keySet()) {
            ArrayList<ArrayList<String>> subset = getDataSet(dataSet, columnIndex, type);
            // subset is never empty: 'type' was observed in dataSet. The original's
            // "size() > 0 ? size()-1 : 0" guard was dead code and is removed.
            int desColumn = subset.get(0).size() - 1;
            double probY = (double) subset.size() / (double) dataSet.size();
            // The posterior entropy of the subset IS its target-column entropy;
            // the original re-implemented computeEntropy inline here.
            conditionalEntropy += probY * computeEntropy(subset, desColumn);
        }
        return conditionalEntropy;
    }

    /**
     * Misspelled legacy name, kept for existing callers.
     *
     * @deprecated use {@link #computeConditionalEntropy(ArrayList, int)}
     */
    @Deprecated
    public static double computeConditinalEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
        return computeConditionalEntropy(dataSet, columnIndex);
    }
}
测试代码:
public class Test { public static void main(String[] args) { File f = new File("D:/test.txt"); BufferedReader reader = null; try { reader = new BufferedReader(new FileReader(f)); String str = null; try { str = reader.readLine(); ArrayList<String> attributeList = new ArrayList<String>(); String[] attributes = str.split("\t"); for(int i = 0; i < attributes.length; i++) { attributeList.add(attributes[i]); } ArrayList<ArrayList<String>> dataSet = new ArrayList<ArrayList<String>>(); while((str = reader.readLine()) != null) { ArrayList<String> tmpList = new ArrayList<String>(); String[] s = str.split("\t"); for(int i = 0; i < s.length; i++) { tmpList.add(s[i]); } dataSet.add(tmpList); } DecisionTree dt = new DecisionTree(); TreeNode root = dt.buildTree(dataSet, attributeList); // dt.printTree(root); dt.searchTree(root); } catch (IOException e) { e.printStackTrace(); } } catch (FileNotFoundException e) { e.printStackTrace(); } } }