Java 实现 ID3 算法

ID3 是经典的分类算法。要理解 ID3 算法，需要先了解一些基本的信息论概念，包括信息量、熵、后验熵和条件熵。ID3 算法的核心思想是：在每个节点上选择互信息量（即信息增益，等于目标属性的熵减去给定该属性后的条件熵）最大的属性作为分裂属性。这是一种贪心策略，通常能得到较小的决策树，但并不能保证所建立的决策树高度最小。

树结构代码:

/**
 * A node of the decision tree assembled by {@code DecisionTree.buildTree}.
 *
 * <p>Internal nodes carry the name of the attribute they split on, plus that attribute's
 * distinct values in {@link #getSplitAttributes()}. Leaves carry a class label as their name
 * and leave the split attributes unset ({@code null}). All state is plain mutable data
 * populated through the setters.
 *
 * @author zhenhua.chen
 * @date 2013-3-1 上午10:47:37
 */
public class TreeNode {
	/** Split-attribute name for internal nodes, class label for leaves. */
	private String nodeName;
	/** Values of the split attribute; {@code null} until set. */
	private List<String> splitAttributes;
	/** Child subtrees; starts as an empty list. */
	private ArrayList<TreeNode> childrenNodes;
	/** Rows routed to this node; {@code null} until set. */
	private ArrayList<ArrayList<String>> dataSet;
	/** Attribute names of the columns in {@code dataSet}; {@code null} until set. */
	private ArrayList<String> arrributeSet;

	/** Creates a node whose child list is empty and ready to be appended to. */
	public TreeNode() {
		this.childrenNodes = new ArrayList<TreeNode>();
	}

	/** @return the attribute name (internal node) or class label (leaf) */
	public String getNodeName() {
		return this.nodeName;
	}

	/** @param nodeName attribute name or class label to store */
	public void setNodeName(String nodeName) {
		this.nodeName = nodeName;
	}

	/** @return the stored split values, or {@code null} if never set */
	public List<String> getSplitAttributes() {
		return this.splitAttributes;
	}

	/** @param splitAttributes split values to store (the reference is kept, not copied) */
	public void setSplitAttributes(List<String> splitAttributes) {
		this.splitAttributes = splitAttributes;
	}

	/** @return the live, mutable list of children */
	public ArrayList<TreeNode> getChildrenNodes() {
		return this.childrenNodes;
	}

	/** @param childrenNodes replacement child list (the reference is kept, not copied) */
	public void setChildrenNodes(ArrayList<TreeNode> childrenNodes) {
		this.childrenNodes = childrenNodes;
	}

	/** @return the rows stored on this node, or {@code null} if never set */
	public ArrayList<ArrayList<String>> getDataSet() {
		return this.dataSet;
	}

	/** @param dataSet rows to store (the reference is kept, not copied) */
	public void setDataSet(ArrayList<ArrayList<String>> dataSet) {
		this.dataSet = dataSet;
	}

	/** @return the column names stored on this node, or {@code null} if never set */
	public ArrayList<String> getArrributeSet() {
		return this.arrributeSet;
	}

	/** @param arrributeSet column names to store (the reference is kept, not copied) */
	public void setArrributeSet(ArrayList<String> arrributeSet) {
		this.arrributeSet = arrributeSet;
	}
}

决策树算法:

/**
 * Builds an ID3 decision tree and prints it.
 *
 * <p>Each data row is a list of attribute values whose last column is the target (class)
 * attribute. At every node the attribute with the largest information gain (target entropy
 * minus conditional entropy) is chosen as the split.
 *
 * @author zhenhua.chen
 * @date 2013-3-1 下午4:42:07
 */
public class DecisionTree {
	/**
	 * Gains not above this are treated as zero, so floating-point rounding on an attribute that
	 * cannot separate the classes never triggers a useless split.
	 */
	private static final double MIN_GAIN = 1e-12;

	/**
	 * Recursively builds the (sub)tree for {@code dataSet}.
	 *
	 * <p>The result is a leaf (name = class label, split attributes {@code null}) when the rows
	 * all share one label, or when no remaining attribute has positive information gain. That
	 * includes the case where only the target column is left. In the second case the leaf takes
	 * the majority label. Otherwise the node is named after the chosen attribute. Its split
	 * attributes are that attribute's distinct values in first-appearance order, and its i-th
	 * child is the subtree for the i-th value.
	 *
	 * @param dataSet rows of attribute values, target last; assumed to have one value per attribute
	 * @param attributeSet attribute names, one per column, target last
	 * @return the root of the built (sub)tree
	 * @throws IllegalArgumentException if {@code dataSet} is null or empty
	 */
	public TreeNode buildTree(ArrayList<ArrayList<String>> dataSet, ArrayList<String> attributeSet) {
		if (dataSet == null || dataSet.isEmpty()) {
			throw new IllegalArgumentException("dataSet must contain at least one row");
		}
		TreeNode node = new TreeNode();
		node.setDataSet(dataSet);
		node.setArrributeSet(attributeSet);

		int targetColumn = attributeSet.size() - 1;

		// Already pure (only reachable at the root; pure subsets are handled in the loop below).
		ArrayList<String> labels = ComputeUtil.getTypes(dataSet, targetColumn);
		if (labels.size() == 1) {
			node.setNodeName(labels.get(0));
			return node;
		}

		// Pick the attribute with maximal information gain; the base entropy is loop-invariant.
		double baseEntropy = ComputeUtil.computeEntropy(dataSet, targetColumn);
		int index = -1;
		double maxGain = MIN_GAIN;
		for (int i = 0; i < targetColumn; i++) {
			double gain = baseEntropy - ComputeUtil.computeConditinalEntropy(dataSet, i);
			if (gain > maxGain) {
				index = i;
				maxGain = gain;
			}
		}

		// No attribute left, or none separates the classes (e.g. identical features with
		// different labels): without this, index -1 would be used to look up the attribute.
		if (index < 0) {
			node.setNodeName(majorityLabel(dataSet, targetColumn));
			return node;
		}

		ArrayList<String> splitAttributes = ComputeUtil.getTypes(dataSet, index); // values of the split attribute
		node.setSplitAttributes(splitAttributes);
		node.setNodeName(attributeSet.get(index));

		for (String value : splitAttributes) {
			ArrayList<ArrayList<String>> splitDataSet = ComputeUtil.getDataSet(dataSet, index, value);
			ArrayList<String> childLabels = ComputeUtil.getTypes(splitDataSet, targetColumn);
			TreeNode childNode;
			if (childLabels.size() == 1) {
				// Pure subset: a bare leaf holding only the class label.
				childNode = new TreeNode();
				childNode.setNodeName(childLabels.get(0));
			} else {
				// Drop the used attribute by position so header and row columns stay aligned.
				ArrayList<ArrayList<String>> childDataSet = new ArrayList<ArrayList<String>>();
				for (ArrayList<String> row : splitDataSet) {
					childDataSet.add(withoutColumn(row, index));
				}
				childNode = buildTree(childDataSet, withoutColumn(attributeSet, index));
			}
			node.getChildrenNodes().add(childNode);
		}
		return node;
	}

	/** Returns the most frequent value of {@code column}; ties go to the value seen first. */
	private static String majorityLabel(ArrayList<ArrayList<String>> dataSet, int column) {
		String best = null;
		int bestCount = 0;
		for (String label : ComputeUtil.getTypes(dataSet, column)) {
			int count = ComputeUtil.getDataSet(dataSet, column, label).size();
			if (count > bestCount) {
				best = label;
				bestCount = count;
			}
		}
		return best;
	}

	/** Returns a copy of {@code values} with the element at position {@code column} removed. */
	private static ArrayList<String> withoutColumn(ArrayList<String> values, int column) {
		ArrayList<String> copy = new ArrayList<String>(values);
		copy.remove(column); // remove(int): by index, not by value
		return copy;
	}

	/**
	 * Prints the tree depth-first (pre-order). Each node is preceded by a separator line.
	 * Internal nodes print their name followed by their split values in parentheses; leaves
	 * print their class label.
	 *
	 * @param root the node to start printing from
	 */
	public void printTree(TreeNode root) {
		System.out.println("----------------");
		if (null != root.getSplitAttributes()) {
			System.out.print("分裂节点:" + root.getNodeName());
			for (String attr : root.getSplitAttributes()) {
				System.out.print("(" + attr + ") ");
			}
		} else {
			// Leaves were previously printed as split nodes; label them as in searchTree.
			System.out.print("叶子节点:" + root.getNodeName());
		}

		if (null != root.getChildrenNodes()) {
			for (TreeNode node : root.getChildrenNodes()) {
				printTree(node);
			}
		}
	}

	/**
	 * Prints the tree breadth-first (level order) on one line. Internal nodes show their name
	 * and split values; leaves show their class label.
	 *
	 * @param root the node to start the traversal from
	 */
	public void searchTree(TreeNode root) {
		Queue<TreeNode> queue = new LinkedList<TreeNode>();
		queue.offer(root);

		while (!queue.isEmpty()) {
			TreeNode node = queue.poll();
			if (null != node.getSplitAttributes()) {
				System.out.print("分裂节点:" + node.getNodeName() + "; ");
				for (String attr : node.getSplitAttributes()) {
					System.out.print(" (" + attr + ") ");
				}
			} else {
				System.out.print("叶子节点:" + node.getNodeName() + "; ");
			}

			if (null != node.getChildrenNodes()) {
				for (TreeNode child : node.getChildrenNodes()) {
					queue.offer(child);
				}
			}
		}
	}
}

一些util代码:

/**
 * Static helpers for the decision-tree builder. They cover distinct values and value counts
 * of a column, row filtering by column value, and Shannon entropy (in bits) of a column, plain
 * or conditioned on another column.
 *
 * <p>Data sets are lists of rows; each row is a list of column values.
 *
 * @author zhenhua.chen
 * @date 2013-3-1 上午10:48:47
 */
public class ComputeUtil {

	/**
	 * Lists the distinct values found in one column.
	 *
	 * @param dataSet rows to scan
	 * @param columnIndex column to read from every row
	 * @return distinct values in order of first appearance
	 */
	public static ArrayList<String> getTypes(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
		ArrayList<String> distinct = new ArrayList<String>();
		for (int r = 0; r < dataSet.size(); r++) {
			String value = dataSet.get(r).get(columnIndex);
			if (distinct.indexOf(value) < 0) {
				distinct.add(value);
			}
		}
		return distinct;
	}

	/**
	 * Counts how many rows carry each value of one column.
	 *
	 * @param dataSet rows to scan
	 * @param columnIndex column to read from every row
	 * @return map from column value to its occurrence count (a {@code HashMap})
	 */
	public static Map<String, Integer> getTypeCounts(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
		Map<String, Integer> counts = new HashMap<String, Integer>();
		for (ArrayList<String> row : dataSet) {
			String value = row.get(columnIndex);
			Integer seen = counts.get(value);
			counts.put(value, seen == null ? 1 : seen + 1);
		}
		return counts;
	}

	/**
	 * Selects the rows whose value in one column equals the given value.
	 *
	 * @param dataSet rows to filter
	 * @param columnIndex column to compare
	 * @param attribueClass value to match (compared with {@code equals})
	 * @return the matching rows (same row objects, original order)
	 */
	public static ArrayList<ArrayList<String>> getDataSet(ArrayList<ArrayList<String>> dataSet, int columnIndex, String attribueClass) {
		ArrayList<ArrayList<String>> matching = new ArrayList<ArrayList<String>>();
		for (ArrayList<String> row : dataSet) {
			if (!row.get(columnIndex).equals(attribueClass)) {
				continue;
			}
			matching.add(row);
		}
		return matching;
	}

	/**
	 * Computes the Shannon entropy, in bits, of the value distribution of one column.
	 *
	 * @param dataSet rows to scan; an empty data set yields 0
	 * @param columnIndex column whose distribution is measured
	 * @return {@code -sum(p * log2 p)} over the column's distinct values
	 */
	public static double computeEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
		double total = dataSet.size();
		double entropy = 0;
		for (Integer count : getTypeCounts(dataSet, columnIndex).values()) {
			double p = count / total;
			entropy += -p * Math.log(p) / Math.log(2);
		}
		return entropy;
	}

	/**
	 * Computes the conditional entropy of the target column given one attribute column. The
	 * target column is taken to be the last column of the rows.
	 *
	 * @param dataSet rows to scan
	 * @param columnIndex attribute column to condition on
	 * @return {@code sum over values v of P(v) * H(target | column == v)}, in bits
	 */
	public static double computeConditinalEntropy(ArrayList<ArrayList<String>> dataSet, int columnIndex) {
		double conditional = 0;
		for (String value : getTypeCounts(dataSet, columnIndex).keySet()) {
			ArrayList<ArrayList<String>> subset = getDataSet(dataSet, columnIndex, value);
			// Target = last column of the subset's first row (0 if that row were empty).
			int targetColumn = Math.max(0, subset.get(0).size() - 1);
			double weight = (double) subset.size() / (double) dataSet.size();
			conditional += weight * computeEntropy(subset, targetColumn);
		}
		return conditional;
	}
}

 测试代码:

/**
 * Demo driver: reads a tab-separated training file, builds an ID3 tree and prints it in level
 * order.
 *
 * <p>The first line holds the attribute names, with the target attribute last. Every following
 * non-blank line holds one training row. The file path may be given as the first program
 * argument and defaults to {@code D:/test.txt}.
 */
public class Test {
	/** Input file used when no path argument is supplied. */
	private static final String DEFAULT_PATH = "D:/test.txt";

	public static void main(String[] args) {
		File f = new File(args != null && args.length > 0 ? args[0] : DEFAULT_PATH);
		BufferedReader reader = null;
		try {
			reader = new BufferedReader(new FileReader(f));

			String header = reader.readLine();
			if (header == null) {
				System.err.println("Empty input file: " + f.getPath());
				return;
			}
			ArrayList<String> attributeList = splitFields(header);

			ArrayList<ArrayList<String>> dataSet = new ArrayList<ArrayList<String>>();
			String line;
			while ((line = reader.readLine()) != null) {
				// A blank (e.g. trailing) line would otherwise become a one-column row
				// and break column indexing in the tree builder.
				if (line.trim().isEmpty()) {
					continue;
				}
				dataSet.add(splitFields(line));
			}
			if (dataSet.isEmpty()) {
				System.err.println("No data rows in: " + f.getPath());
				return;
			}

			DecisionTree dt = new DecisionTree();
			TreeNode root = dt.buildTree(dataSet, attributeList);
			dt.searchTree(root);
		} catch (IOException e) { // also covers FileNotFoundException
			e.printStackTrace();
		} finally {
			// The reader was previously never closed.
			if (reader != null) {
				try {
					reader.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}

	/** Splits a tab-separated line into a mutable list of fields. */
	private static ArrayList<String> splitFields(String line) {
		ArrayList<String> fields = new ArrayList<String>();
		for (String field : line.split("\t")) {
			fields.add(field);
		}
		return fields;
	}
}

猜你喜欢

转载自czhsuccess.iteye.com/blog/1864652