未剪枝决策树


public class TreeNode {

    /** Name of the attribute this node splits on (set on internal nodes). */
    public String aName;
    /** Values of the split attribute; entry i labels the branch leading to childnext[i]. */
    public String[] attributes;
    /** Name of the attribute the parent split on to reach this node. */
    public String parentAttribute;
    /** Parent node; null for the root. */
    public TreeNode parent;
    /** Child nodes, one per entry of {@code attributes}. */
    public TreeNode[] childnext;
    /** Class label of a leaf (e.g. "是" / "否"). */
    public String type;
    /** True when this node is a leaf. */
    public boolean isLeaft;

    /** Creates an empty, non-leaf node with every reference unset. */
    TreeNode() {
        this.isLeaft = false;
        this.type = null;
        this.childnext = null;
        this.parent = null;
        this.parentAttribute = null;
        this.attributes = null;
        this.aName = null;
    }
}
import java.util.ArrayList;
import java.util.Iterator;


public class TreeGenerateFunction {

    /** Class label counted as a positive sample ("good melon"). */
    private static final String POSITIVE = "是";
    /** Class label used when positives are not a strict majority. */
    private static final String NEGATIVE = "否";

    /**
     * Column of each named attribute inside a sample row of D. Any other
     * attribute name maps to column FEATURE_COLUMNS.length (5).
     */
    private static final String[] FEATURE_COLUMNS = {"色泽", "根蒂", "敲声", "纹理", "脐部"};

    /**
     * Recursively grows an unpruned decision tree below {@code node}, splitting
     * on the attribute with the largest information gain.
     *
     * @param d    training samples; each row holds attribute values followed by
     *             the class label in the last cell
     * @param a    candidate attributes; each row is the attribute name followed
     *             by the attribute's possible values
     * @param node node to fill in: it becomes either a leaf or an internal node
     *             with one child per value of the chosen attribute
     * @throws IllegalArgumentException if {@code d} is empty
     */
    public void TreeGenerate(ArrayList<ArrayList<String>> d, ArrayList<ArrayList<String>> a, TreeNode node) {
        if (d.isEmpty()) {
            throw new IllegalArgumentException("training set D is empty");
        }

        // Stop: every sample has the same class -> leaf labelled with that class.
        if (isDclassOnlyOne(d)) {
            node.type = labelOf(d.get(0));
            node.isLeaft = true;
            return;
        }

        String majority = majorityClass(d);

        // Stop: no attribute left, or all samples share the same attribute values.
        if (a.isEmpty() || isDxarributeEquelalsA(d)) {
            node.type = majority;
            node.isLeaft = true;
            return;
        }

        // Split on the attribute with the largest information gain.
        int bestIndex = selectAttributeIndex(d, a);
        ArrayList<String> bestRow = a.get(bestIndex);
        String best = bestRow.get(0);

        node.aName = best;
        node.isLeaft = false;
        int len = bestRow.size() - 1;
        node.childnext = new TreeNode[len];
        node.attributes = new String[len];
        for (int i = 1; i <= len; i++) {
            node.attributes[i - 1] = bestRow.get(i);
        }

        // Every branch recurses with the same reduced attribute set.
        ArrayList<ArrayList<String>> remaining = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> attr : a) {
            if (!best.equals(attr.get(0))) {
                remaining.add(attr);
            }
        }

        for (int v = 1; v <= len; v++) {
            TreeNode child = new TreeNode();
            node.childnext[v - 1] = child;
            child.parent = node;
            child.parentAttribute = best;
            child.isLeaft = false;
            // Parent's majority label: final for empty branches, a fallback otherwise.
            child.type = majority;

            ArrayList<ArrayList<String>> dv = getDv(d, a, bestIndex, v);
            if (dv.isEmpty()) {
                child.isLeaft = true;
                // continue (not return): the remaining branches must still be built.
                continue;
            }
            TreeGenerate(dv, remaining, child);
        }
    }

    /**
     * Returns the index in {@code a} of the attribute with the largest
     * information gain. Index 0 is returned when no gain is positive.
     * Prints each gain and the chosen attribute to stdout.
     */
    private int selectAttributeIndex(ArrayList<ArrayList<String>> d, ArrayList<ArrayList<String>> a) {
        int bestIndex = 0;
        double bestGain = 0.00;
        for (int i = 0; i < a.size(); i++) {
            double g = gain(d, a, i);
            System.out.println("gain[" + a.get(i).get(0) + "]=" + g + " ");
            if (g > bestGain) {
                bestGain = g;
                bestIndex = i;
            }
        }
        System.out.println("属性:" + a.get(bestIndex).get(0) + " 位置:" + bestIndex);
        return bestIndex;
    }

    /**
     * Information gain of splitting {@code d} on attribute {@code a.get(aIndex)}:
     * Ent(D) minus the size-weighted entropy of each non-empty subset Dv.
     */
    private double gain(ArrayList<ArrayList<String>> d, ArrayList<ArrayList<String>> a, int aIndex) {
        double ent = getEnt(d);
        double weighted = 0.00;
        int valueCount = a.get(aIndex).size();
        for (int v = 1; v < valueCount; v++) {
            ArrayList<ArrayList<String>> dv = getDv(d, a, aIndex, v);
            if (!dv.isEmpty()) {
                weighted += ((double) dv.size() / (double) d.size()) * getEnt(dv);
            }
        }
        return ent - weighted;
    }

    /**
     * Returns the samples of {@code d} whose value for attribute
     * {@code a.get(aIndex)} equals that attribute's {@code avIndex}-th value.
     * The input list is not modified.
     */
    private ArrayList<ArrayList<String>> getDv(ArrayList<ArrayList<String>> d, ArrayList<ArrayList<String>> a,
                                               int aIndex, int avIndex) {
        ArrayList<String> attr = a.get(aIndex);
        int column = columnOf(attr.get(0));
        String value = attr.get(avIndex);

        ArrayList<ArrayList<String>> dv = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> sample : d) {
            if (value.equals(sample.get(column))) {
                dv.add(sample);
            }
        }
        return dv;
    }

    /** Maps an attribute name to its column in a sample row (unknown names -> 5). */
    private static int columnOf(String attributeName) {
        for (int i = 0; i < FEATURE_COLUMNS.length; i++) {
            if (FEATURE_COLUMNS[i].equals(attributeName)) {
                return i;
            }
        }
        return FEATURE_COLUMNS.length;
    }

    /**
     * Binary entropy (base 2) of the class distribution in {@code d}.
     * Labels other than "是" count as the negative class. Returns 0 for an
     * empty or pure set.
     */
    private double getEnt(ArrayList<ArrayList<String>> d) {
        if (d.isEmpty()) {
            return 0.0;
        }
        int goodNum = maxClassName(d);
        int total = d.size();
        double p1 = (double) goodNum / total;
        double p2 = (double) (total - goodNum) / total;
        if (p1 == 0.0 || p2 == 0.0) {
            return 0.0;
        }
        double log1 = Math.log(p1) / Math.log(2);
        double log2 = Math.log(p2) / Math.log(2);
        return -(p1 * log1 + p2 * log2);
    }

    /**
     * Counts the positive samples ("是" in the last cell) in {@code d}.
     *
     * @param d samples; may be empty (returns 0)
     * @return number of samples labelled "是"
     */
    public int maxClassName(ArrayList<ArrayList<String>> d) {
        int count = 0;
        for (ArrayList<String> sample : d) {
            if (POSITIVE.equals(labelOf(sample))) {
                count++;
            }
        }
        return count;
    }

    /**
     * Returns true when all samples share identical attribute values.
     * Only the feature columns are compared, not the class label in the last cell.
     * Columns of already-used attributes are constant within a subset built by
     * {@link #getDv}, so comparing every feature column is equivalent to
     * comparing the remaining attributes.
     *
     * @param d samples; an empty list yields true
     */
    public boolean isDxarributeEquelalsA(ArrayList<ArrayList<String>> d) {
        if (d.isEmpty()) {
            return true;
        }
        ArrayList<String> first = d.get(0);
        int featureCount = first.size() - 1;
        for (ArrayList<String> sample : d) {
            if (sample.size() != first.size()) {
                return false;
            }
            for (int j = 0; j < featureCount; j++) {
                if (!first.get(j).equals(sample.get(j))) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns true when every sample in {@code D} carries the same class label
     * (the last cell). An empty list yields true.
     */
    public boolean isDclassOnlyOne(ArrayList<ArrayList<String>> D) {
        if (D.isEmpty()) {
            return true;
        }
        String label = labelOf(D.get(0));
        for (ArrayList<String> sample : D) {
            if (!label.equals(labelOf(sample))) {
                return false;
            }
        }
        return true;
    }

    /** Majority label of {@code d}: "是" only on a strict majority, otherwise "否". */
    private String majorityClass(ArrayList<ArrayList<String>> d) {
        int positives = maxClassName(d);
        return positives > d.size() - positives ? POSITIVE : NEGATIVE;
    }

    /** Class label of a sample row: its last cell. */
    private static String labelOf(ArrayList<String> sample) {
        return sample.get(sample.size() - 1);
    }
}


import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;

import org.apache.poi.hssf.usermodel.HSSFCell;
import org.apache.poi.hssf.usermodel.HSSFRow;
import org.apache.poi.hssf.usermodel.HSSFSheet;
import org.apache.poi.hssf.usermodel.HSSFWorkbook;

@SuppressWarnings("restriction")
public class ListTrainSet {

    /**
     * Reads every sheet of an .xls workbook into a list of rows.
     * Physically absent rows are skipped. Each row's cells are read from
     * column 0 up to the first missing cell.
     *
     * @param fis stream positioned at the start of an .xls file; not closed here
     * @return all rows of all sheets, each converted to a list of strings
     * @throws IOException if the workbook cannot be read
     */
    public ArrayList<ArrayList<String>> readExcle(FileInputStream fis) throws IOException {
        ArrayList<ArrayList<String>> rows = new ArrayList<ArrayList<String>>();
        HSSFWorkbook workbook = new HSSFWorkbook(fis);

        int sheetCount = workbook.getNumberOfSheets();
        for (int s = 0; s < sheetCount; s++) {
            HSSFSheet sheet = workbook.getSheetAt(s);
            if (sheet == null) {
                continue;
            }
            int lastRow = sheet.getLastRowNum();
            for (int r = 0; r <= lastRow; r++) {
                HSSFRow row = sheet.getRow(r);
                if (row == null) {
                    continue;
                }
                rows.add(readRow(row));
            }
        }
        return rows;
    }

    /** Collects the string form of consecutive cells starting at column 0. */
    private ArrayList<String> readRow(HSSFRow row) {
        ArrayList<String> cells = new ArrayList<String>();
        for (int c = 0; row.getCell(c) != null; c++) {
            cells.add(getValue(row.getCell(c)));
        }
        return cells;
    }

    /**
     * Converts a cell to a string. Boolean and numeric cells use their typed
     * getters; any other cell type is read via getStringCellValue().
     *
     * @param hssfCell the cell to convert
     * @return the cell's value as a string
     */
    public String getValue(HSSFCell hssfCell) {
        int cellType = hssfCell.getCellType();
        if (cellType == HSSFCell.CELL_TYPE_BOOLEAN) {
            return String.valueOf(hssfCell.getBooleanCellValue());
        }
        if (cellType == HSSFCell.CELL_TYPE_NUMERIC) {
            return String.valueOf(hssfCell.getNumericCellValue());
        }
        return String.valueOf(hssfCell.getStringCellValue());
    }
}
import java.util.LinkedList;

/**
 * Breadth-first (level-order) printer for decision trees.
 * @author FES
 *
 */
public class LevelOrder {

    /**
     * Prints the tree level by level, one node per line. Internal nodes print
     * their split attribute name ({@code aName}); leaves print their class
     * label ({@code type}).
     *
     * @param root root of the tree; nothing is printed when null
     */
    public void levelIterator(TreeNode root) {
        if (root == null) {
            return;
        }
        LinkedList<TreeNode> queue = new LinkedList<TreeNode>();
        queue.offer(root);
        while (!queue.isEmpty()) {
            TreeNode current = queue.poll();
            if (current.isLeaft) {
                System.out.println(current.type);
            } else {
                System.out.println(current.aName);
            }
            if (current.childnext != null) {
                for (TreeNode child : current.childnext) {
                    // A partially built node can leave null slots; enqueueing them
                    // would make poll() return null and crash on the next iteration.
                    if (child != null) {
                        queue.offer(child);
                    }
                }
            }
        }
    }

}
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;


public class Test {

    /** Directory that holds the input workbooks D.xls, A.xls and T.xls. */
    private static final String DATA_DIR = "E:\\text\\";

    /**
     * Loads the training set D, the attribute table A and the test set T.
     * Builds an unpruned decision tree from D and A and prints it in level order.
     * T is loaded but not used yet.
     */
    public static void main(String[] args) throws IOException {
        ListTrainSet ls = new ListTrainSet();
        ArrayList<ArrayList<String>> D = readWorkbook(ls, new File(DATA_DIR + "D.xls"));
        ArrayList<ArrayList<String>> A = readWorkbook(ls, new File(DATA_DIR + "A.xls"));
        // Test set: loaded for future evaluation, currently unused.
        ArrayList<ArrayList<String>> T = readWorkbook(ls, new File(DATA_DIR + "T.xls"));
        System.out.println("D的大小:" + D.size());
        System.out.println("A的大小:" + A.size());

        TreeNode td = new TreeNode();
        TreeGenerateFunction tgf = new TreeGenerateFunction();
        tgf.TreeGenerate(D, A, td);
        new LevelOrder().levelIterator(td);
    }

    /** Reads one workbook and always closes its input stream. */
    private static ArrayList<ArrayList<String>> readWorkbook(ListTrainSet ls, File file) throws IOException {
        FileInputStream in = new FileInputStream(file);
        try {
            return ls.readExcle(in);
        } finally {
            in.close();
        }
    }

}

（原文此处有两张截图，图片未保留）


参考文献

  1. 周志华, 杨强. 机器学习及其应用[M]. 清华大学出版社, 2011.
  2. 周志华. 机器学习[M]. 北京: 清华大学出版社, 2016.（第4章 决策树：决策树学习基本算法及西瓜数据集）

猜你喜欢

转载自blog.csdn.net/u014439289/article/details/71760107