版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/l1832876815/article/details/89469015
1. 算法原理
C4.5算法: 首先根据训练集求出按各属性划分后的条件信息熵info[i], 然后求出类别的信息熵infod, infod - info[i]得到每个属性的信息增益gain[i], 然后计算每个属性的分裂信息度量h[i], gain[i] / h[i]即为该属性的信息增益率。递归选择信息增益率最高的属性, 按照该属性的取值对数据集进行分裂, 并判断分裂之后的每个子数据集的类别是否是“纯”的(即所有样本属于同一类别): 如果是, 则生成一个以该类别为值的叶节点; 如果不是, 则继续递归进行分裂过程。最终训练出一棵决策树。测试过程即根据测试样例各属性的值从根节点开始遍历决策树, 直到到达叶节点, 叶节点的类别即为该测试样例的类别。
2. 代码实现
Node.java
package com.clxk1997;
/**
 * @Description Label of a decision-tree node: an attribute name ({@code field}) paired with a
 *     value. Split nodes store (attribute, attribute value); leaf nodes store the leaf marker and
 *     the predicted class.
 * @Author Clxk
 * @Date 2019/4/22 14:16
 * @Version 1.0
 */
public class Node {
    private String field;
    private String value;

    /** Creates a node whose field and value are both {@code null}. */
    public Node() {
        this(null, null);
    }

    /** Creates a node carrying the given field name and value. */
    public Node(String field, String value) {
        this.field = field;
        this.value = value;
    }

    /** @return the attribute name (or leaf marker) of this node */
    public String getField() {
        return this.field;
    }

    /** @return the value associated with {@link #getField()} */
    public String getValue() {
        return this.value;
    }

    /** Replaces the attribute name (or leaf marker). */
    public void setField(String field) {
        this.field = field;
    }

    /** Replaces the associated value. */
    public void setValue(String value) {
        this.value = value;
    }
}
DecisionTree.java
package com.clxk1997;
import java.util.ArrayList;
/**
 * @Description Decision tree: the label of this node plus its sub-trees.
 * @Author Clxk
 * @Date 2019/4/22 13:57
 * @Version 1.0
 */
public class DecisionTree {
    /** Label (attribute, value) of this node. */
    private Node node;
    /**
     * Sub-trees of this node. Starts as an empty list rather than {@code null}, so iterating the
     * children of any node - including leaves built with {@code new DecisionTree()} - cannot NPE.
     */
    private ArrayList<DecisionTree> childs = new ArrayList<>();

    public Node getNode() {
        return node;
    }

    public void setNode(Node node) {
        this.node = node;
    }

    /** @return the (mutable) list of sub-trees; empty unless children were added or replaced */
    public ArrayList<DecisionTree> getChilds() {
        return childs;
    }

    public void setChilds(ArrayList<DecisionTree> childs) {
        this.childs = childs;
    }

    /**
     * Creates a root node labelled ("root", "root") with no children.
     *
     * @return a fresh root node
     */
    public static DecisionTree init() {
        DecisionTree decisionTree = new DecisionTree();
        decisionTree.setNode(new Node("root", "root"));
        decisionTree.setChilds(new ArrayList<DecisionTree>());
        return decisionTree;
    }
}
C45.java
package com.clxk1997;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.util.*;
/**
 * C4.5 decision-tree learner.
 *
 * <p>Reads a training set from {@code lib/input_train.txt}, grows a tree by recursively splitting
 * on the attribute with the highest information gain ratio, prints the tree, and then classifies
 * every sample in {@code lib/input_test.txt}. All state lives in static fields.
 *
 * @author clxk
 */
public class C45 {
    /** Default capacity of scratch arrays (grown where the input is larger). */
    private static final int MAXN = 0x3f;
    /** Marker stored in the field of leaf nodes; the leaf's value is the predicted class. */
    private static final String LEAF = "叶节点";
    /** Attribute (column) names in input order. */
    private static List<String> fields = new ArrayList<>();
    /** Attribute names handed to the root call of {@link #train}. */
    private static List<String> curfields = new ArrayList<>();
    /** Class-label names; the first one is used when printing predictions. */
    private static List<String> classfields = new ArrayList<>();
    /** Training rows: attribute values followed by the class label in the last column. */
    @SuppressWarnings("unchecked")
    private static List<String>[] trains = new ArrayList[MAXN];
    /** Test rows (currently unused: test samples are classified as they are read). */
    @SuppressWarnings("unchecked")
    private static List<String>[] tests = new ArrayList[MAXN];
    /** Number of training rows read by {@link #input()}. */
    private static int count_train;
    /** Number of test rows declared in the test file. */
    private static int count_test;
    /** Root of the decision tree. */
    private static DecisionTree tree = DecisionTree.init();
    /** Result slot written by {@link #dfsTree}; reset to {@code null} by {@link #getAns}. */
    static String ans = null;

    /**
     * Reads attribute names, class names and training rows from {@code lib/input_train.txt},
     * echoing everything to stdout.
     *
     * <p>Layout: attribute count, one attribute name per line, class count, one class name per
     * line, row count, then one whitespace-separated row per line ending with the class label.
     *
     * @throws Exception if the file is missing or malformed
     */
    public static void input() throws Exception {
        System.out.println("请输入属性集合的数量: ");
        // try-with-resources: the original never closed the Scanner / file stream.
        try (Scanner scan = new Scanner(new FileInputStream(new File("lib/input_train.txt")))) {
            int cnt = scan.nextInt();
            System.out.println(cnt);
            System.out.println("请输入" + cnt + "个属性: ");
            scan.nextLine();
            for (int i = 0; i < cnt; i++) {
                String str = scan.nextLine().trim();
                System.out.println(str);
                fields.add(str);
                curfields.add(str);
            }
            System.out.println("请输入类别集合的数量: ");
            cnt = scan.nextInt();
            scan.nextLine();
            System.out.println(cnt);
            System.out.println("请输入" + cnt + "个类别标签: ");
            for (int i = 0; i < cnt; i++) {
                String str = scan.nextLine().trim();
                System.out.println(str);
                classfields.add(str);
            }
            System.out.println("请输入训练集的数量: ");
            cnt = scan.nextInt();
            count_train = cnt;
            System.out.println(cnt);
            scan.nextLine();
            System.out.println("请输入训练集: ");
            // Size the row array to the data; the fixed MAXN array overflowed beyond 63 rows.
            trains = newRowArray(cnt);
            for (int i = 0; i < cnt; i++) {
                String str = scan.nextLine();
                System.out.println(str);
                // trim + \s+ tolerates repeated or leading/trailing whitespace between tokens.
                trains[i] = new ArrayList<>(Arrays.asList(str.trim().split("\\s+")));
            }
        }
    }

    /**
     * Computes per-attribute conditional entropies followed by the class entropy.
     *
     * @param trains      rows to evaluate; iteration stops at the first {@code null} entry
     * @param fields      attribute names matching the leading columns of each row
     * @param count_train number of non-null rows in {@code trains}
     * @return array whose first {@code fields.size()} slots hold Info_A(D) per attribute and whose
     *     next {@code classfields.size()} slots hold the class entropy Info(D)
     */
    public static double[] getInfos(List<String>[] trains, List<String> fields, int count_train) {
        double[] infos = new double[Math.max(MAXN, fields.size() + classfields.size())];
        for (int i = 0; i < fields.size(); i++) {
            infos[i] = getInfo(i, false, trains, count_train);
        }
        for (int i = 0; i < classfields.size(); i++) {
            infos[i + fields.size()] = getInfo(i, true, trains, count_train);
        }
        return infos;
    }

    /**
     * Runs the whole pipeline: read training data, build the tree, print it, classify the tests.
     *
     * @throws Exception propagated from file input
     */
    public static void training() throws Exception {
        input();
        train(curfields, trains, count_train, tree);
        showTree(tree);
        inputTest();
    }

    /**
     * Reads {@code lib/input_test.txt} (sample count, then one whitespace-separated sample per
     * line) and prints the predicted class for each sample.
     *
     * @throws FileNotFoundException if the test file does not exist
     */
    public static void inputTest() throws FileNotFoundException {
        System.out.println("请输入测试集的个数: ");
        try (Scanner scanner = new Scanner(new FileInputStream(new File("lib/input_test.txt")))) {
            count_test = scanner.nextInt();
            scanner.nextLine();
            Map<String, String> mp = new HashMap<>();
            for (int i = 0; i < count_test; i++) {
                mp.clear();
                System.out.println("请输入第" + i + "个测试样例: ");
                String[] split = scanner.nextLine().trim().split("\\s+");
                // Short lines no longer throw: missing attributes simply match no branch.
                for (int j = 0; j < fields.size() && j < split.length; j++) {
                    mp.put(fields.get(j), split[j].trim());
                }
                String result = getAns(mp);
                System.out.println("在该测试样例情况下," + classfields.get(0) + " 应为 " + result);
            }
        }
    }

    /**
     * Classifies one sample.
     *
     * @param mp attribute name -> attribute value of the sample
     * @return the predicted class, or {@code null} when no leaf is reachable (e.g. an attribute
     *     value never seen in training)
     */
    public static String getAns(Map<String, String> mp) {
        // Reset first: otherwise an unmatched path silently returned the previous sample's answer.
        ans = null;
        dfsTree(tree, mp);
        return ans;
    }

    /**
     * Walks the tree along the branches whose (attribute, value) matches the sample and stores
     * the class of the reached leaf in {@link #ans}. Leaves {@code ans} untouched if no leaf is
     * reached.
     *
     * @param tree subtree to search
     * @param mp   attribute name -> attribute value of the sample
     */
    public static void dfsTree(DecisionTree tree, Map<String, String> mp) {
        if (LEAF.equals(tree.getNode().getField())) {
            ans = tree.getNode().getValue();
            return;
        }
        List<DecisionTree> children = tree.getChilds();
        if (children == null) {
            return;
        }
        if (children.size() == 1 && LEAF.equals(children.get(0).getNode().getField())) {
            ans = children.get(0).getNode().getValue();
            return;
        }
        for (DecisionTree child : children) {
            Node label = child.getNode();
            String actual = mp.get(label.getField());
            if (actual != null && actual.equals(label.getValue())) {
                dfsTree(child, mp);
                // Sibling branches carry distinct values, so at most one can match.
                return;
            }
        }
    }

    /**
     * Prints each node's label and its children's labels, recursing depth-first.
     *
     * @param tree subtree to print
     */
    public static void showTree(DecisionTree tree) {
        System.out.print("当前节点属性为: " + tree.getNode().getField() + " 当前属性值为: " +
                tree.getNode().getValue());
        System.out.println();
        if (tree.getChilds() == null) return;
        for (DecisionTree child : tree.getChilds()) {
            System.out.println("子节点属性为: " + child.getNode().getField()
                    + " 子节点属性值为: " + child.getNode().getValue());
        }
        for (DecisionTree child : tree.getChilds()) {
            showTree(child);
        }
    }

    /**
     * Recursively grows the tree below {@code tree} from the given rows.
     *
     * <p>Neither {@code fields} nor the rows are modified: each child receives its own copies
     * with the split column removed.
     *
     * @param fields      attribute names matching the leading columns of each row
     * @param trains      rows (class label last); iteration stops at the first {@code null}
     * @param count_train number of non-null rows in {@code trains}
     * @param tree        node under which children are attached
     */
    public static void train(List<String> fields, List<String>[] trains, int count_train, DecisionTree tree) {
        if (tree.getChilds() == null) {
            tree.setChilds(new ArrayList<>());
        }
        if (count_train <= 0 || trains.length == 0 || trains[0] == null) {
            return; // no samples reached this node
        }
        // A pure subset, or one with no attributes left, becomes a leaf holding its majority
        // class. Previously an exhausted attribute list returned before the purity check and
        // added no leaf at all, leaving that path unclassifiable.
        if (isPure(trains) || fields.isEmpty()) {
            DecisionTree leaf = new DecisionTree();
            leaf.setNode(new Node(LEAF, majorityClass(trains)));
            tree.getChilds().add(leaf);
            return;
        }

        System.out.println("计算得到所有属性的信息熵为: ");
        double[] infos = getInfos(trains, fields, count_train);
        for (int i = 0; i < fields.size(); i++) {
            System.out.format("%s : %.3f\n", fields.get(i), infos[i]);
        }
        System.out.println("计算得到类别D的信息熵为: ");
        for (int i = 0; i < classfields.size(); i++) {
            System.out.format("%s : %.3f\n", classfields.get(i), infos[i + fields.size()]);
        }

        System.out.println("计算得到所有属性的信息增益为: ");
        double[] gains = getGains(infos, infos[fields.size()], fields);
        for (int i = 0; i < fields.size(); i++) {
            System.out.format("%s : %.3f\n", fields.get(i), gains[i]);
        }

        System.out.println("计算得到所有属性的分裂信息度量为: ");
        double[] h = new double[Math.max(MAXN, fields.size())];
        for (int i = 0; i < fields.size(); i++) {
            h[i] = getH(i, trains, count_train);
        }
        for (int i = 0; i < fields.size(); i++) {
            System.out.format("%s : %.3f\n", fields.get(i), h[i]);
        }

        System.out.println("计算得到所有属性的信息增益率为");
        double[] igr = getIGR(gains, h, fields);
        for (int i = 0; i < fields.size(); i++) {
            System.out.format("%s : %.3f\n", fields.get(i), igr[i]);
        }

        // Highest gain ratio wins; ties keep the earliest attribute, default is column 0.
        int index = 0;
        double maxd = 0.0;
        for (int i = 0; i < fields.size(); i++) {
            if (igr[i] > maxd) {
                maxd = igr[i];
                index = i;
            }
        }
        String field = fields.get(index);
        System.out.println("分裂属性为: " + field);

        // Private copy: the original removed from the shared list, so later sibling branches saw
        // attribute names that no longer lined up with their row columns.
        List<String> remaining = new ArrayList<>(fields);
        remaining.remove(index);

        // Group copies of the rows (split column removed) by their value in the split column.
        Map<String, List<List<String>>> partitions = new LinkedHashMap<>();
        for (List<String> row : trains) {
            if (row == null) break;
            List<String> reduced = new ArrayList<>(row);
            String key = reduced.remove(index);
            partitions.computeIfAbsent(key, k -> new ArrayList<>()).add(reduced);
        }

        for (Map.Entry<String, List<List<String>>> entry : partitions.entrySet()) {
            List<List<String>> rows = entry.getValue();
            List<String>[] subset = newRowArray(rows.size());
            for (int i = 0; i < rows.size(); i++) {
                subset[i] = rows.get(i);
            }
            DecisionTree child = new DecisionTree();
            child.setNode(new Node(field, entry.getKey()));
            tree.getChilds().add(child);
            train(remaining, subset, rows.size(), child);
        }
    }

    /**
     * Tells whether all rows share the same class label (last column).
     *
     * @param trains rows; iteration stops at the first {@code null} entry
     * @return {@code true} if at most one distinct label occurs
     */
    public static boolean isPure(List<String>[] trains) {
        Set<String> labels = new HashSet<>();
        for (List<String> row : trains) {
            if (row == null) break;
            labels.add(row.get(row.size() - 1));
            if (labels.size() > 1) return false;
        }
        return true;
    }

    /**
     * Computes a base-2 entropy over the rows.
     *
     * @param index       attribute column (ignored when {@code isClass} is true)
     * @param isClass     {@code true}: class entropy Info(D) of the last column;
     *                    {@code false}: conditional entropy
     *                    Info_A(D) = sum over values v of |D_v|/|D| * Info(D_v) for column {@code index}
     * @param trains      rows; iteration stops at the first {@code null} entry
     * @param count_train number of non-null rows (|D|)
     * @return the entropy in bits
     */
    public static double getInfo(int index, boolean isClass, List<String>[] trains, int count_train) {
        if (isClass) {
            Map<String, Integer> classCounts = new HashMap<>();
            for (List<String> row : trains) {
                if (row == null) break;
                classCounts.merge(row.get(row.size() - 1), 1, Integer::sum);
            }
            return entropy(classCounts.values(), count_train);
        }
        Map<String, Integer> valueCounts = new HashMap<>();
        Map<String, Map<String, Integer>> classCountsByValue = new HashMap<>();
        for (List<String> row : trains) {
            if (row == null) break;
            String value = row.get(index);
            String label = row.get(row.size() - 1);
            valueCounts.merge(value, 1, Integer::sum);
            classCountsByValue.computeIfAbsent(value, k -> new HashMap<>()).merge(label, 1, Integer::sum);
        }
        double info = 0;
        for (Map.Entry<String, Integer> entry : valueCounts.entrySet()) {
            int n = entry.getValue();
            info += (double) n / count_train * entropy(classCountsByValue.get(entry.getKey()).values(), n);
        }
        return info;
    }

    /**
     * Computes information gain per attribute.
     *
     * @param infos  conditional entropies, one per attribute (as returned by {@link #getInfos})
     * @param d      class entropy Info(D)
     * @param fields attribute names; only its size is used
     * @return gains, where {@code gain[i] = d - infos[i]}
     */
    public static double[] getGains(double[] infos, double d, List<String> fields) {
        double[] gain = new double[Math.max(MAXN, fields.size())];
        for (int i = 0; i < fields.size(); i++) {
            gain[i] = d - infos[i];
        }
        return gain;
    }

    /**
     * Computes the split information (entropy of the attribute's own value distribution).
     *
     * @param index       attribute column
     * @param trains      rows; iteration stops at the first {@code null} entry
     * @param count_train number of non-null rows
     * @return split information in bits
     */
    public static double getH(int index, List<String>[] trains, int count_train) {
        Map<String, Integer> valueCounts = new HashMap<>();
        for (List<String> row : trains) {
            if (row == null) break;
            valueCounts.merge(row.get(index), 1, Integer::sum);
        }
        return entropy(valueCounts.values(), count_train);
    }

    /**
     * Computes information gain ratio per attribute.
     *
     * @param gains  information gains
     * @param h      split information values
     * @param fields attribute names; only its size is used
     * @return ratios {@code gains[i] / h[i]}, or 0 where {@code h[i] == 0}
     */
    public static double[] getIGR(double[] gains, double[] h, List<String> fields) {
        double[] ratios = new double[Math.max(MAXN, fields.size())];
        for (int i = 0; i < fields.size(); i++) {
            // h == 0: the attribute has one value here and cannot split the data. Dividing gave
            // NaN (0/0), or +/-Infinity if rounding left a tiny gain, which could win the split.
            ratios[i] = h[i] == 0 ? 0.0 : gains[i] / h[i];
        }
        return ratios;
    }

    public static void main(String[] args) {
        try {
            training();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Base-2 entropy of a frequency distribution whose counts sum to {@code total}. */
    private static double entropy(Collection<Integer> counts, int total) {
        double result = 0;
        for (int c : counts) {
            double p = (double) c / total;
            result -= p * (Math.log(p) / Math.log(2));
        }
        return result;
    }

    /** Most frequent class label (last column) among the non-null rows; first seen wins ties. */
    private static String majorityClass(List<String>[] rows) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (List<String> row : rows) {
            if (row == null) break;
            counts.merge(row.get(row.size() - 1), 1, Integer::sum);
        }
        String best = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : counts.entrySet()) {
            if (entry.getValue() > bestCount) {
                best = entry.getKey();
                bestCount = entry.getValue();
            }
        }
        return best;
    }

    /** Allocates a row array of the given size (generic array creation needs an unchecked cast). */
    @SuppressWarnings("unchecked")
    private static List<String>[] newRowArray(int size) {
        return new List[size];
    }
}
input_train.txt
4
天气
温度
湿度
是否有风
1
是否适合打网球
10
晴 热 高 否 否
晴 热 高 是 否
阴 热 高 否 是
雨 温 高 否 是
雨 凉爽 中 否 是
雨 凉爽 中 是 否
阴 凉爽 中 是 是
晴 温 高 否 否
晴 凉爽 中 否 是
雨 温 中 否 是
input_test.txt
4
晴 温 中 是
阴 温 高 是
阴 热 中 否
雨 温 高 是