未剪枝决策树
/**
 * One node of the decision tree.
 *
 * <p>All state is held in public mutable fields that the tree builder and the
 * traversal code read and write directly. A freshly constructed node has every
 * reference field set to {@code null} and {@link #isLeaft} set to {@code false}.
 */
public class TreeNode {

    /** Name of the attribute this node splits on. */
    public String aName;

    /** Values of the split attribute; entry {@code i} belongs to child {@code i}. */
    public String[] attributes;

    /** Name of the attribute the parent node split on. */
    public String parentAttribute;

    /** Parent node. */
    public TreeNode parent;

    /** Child nodes. */
    public TreeNode[] childnext;

    /** Class label attached to this node. */
    public String type;

    /** Whether this node is a leaf. */
    public boolean isLeaft;

    /**
     * Creates an empty node. Kept package-private on purpose; the fields rely on
     * Java's default values ({@code null} / {@code false}).
     */
    TreeNode() {
    }
}
import java.util.ArrayList;
import java.util.Iterator;
/**
 * Builds an unpruned decision tree for a two-class problem by recursively
 * splitting on the attribute with the largest information gain.
 *
 * <p>Data layout used throughout this class:
 * <ul>
 *   <li>a sample set is a list of rows; the last element of a row is its class
 *       label, and the label {@code "是"} is counted as the positive class;</li>
 *   <li>an attribute set is a list of rows; element 0 of a row is the attribute
 *       name and the remaining elements are the values that attribute can take.</li>
 * </ul>
 */
public class TreeGenerateFunction {

    /** Label counted as the positive class. */
    private static final String POSITIVE = "是";

    /** Label assigned when the positive class is not a strict majority. */
    private static final String NEGATIVE = "否";

    /**
     * Recursively grows the subtree rooted at {@code node}.
     *
     * <p>The node becomes a leaf when all samples share one label, or when no
     * attribute is left / all samples have identical attribute values (then it is
     * labelled with the majority class). Otherwise the node splits on the best
     * attribute and one child is created per attribute value. A child whose
     * partition is empty becomes a leaf labelled with the majority class of
     * {@code d}; generation then continues with the next value.
     *
     * @param d    samples reaching this node; must not be empty
     * @param a    attributes still available for splitting
     * @param node node to fill in
     * @throws IndexOutOfBoundsException if {@code d} is empty
     */
    public void TreeGenerate(ArrayList<ArrayList<String>> d, ArrayList<ArrayList<String>> a, TreeNode node) {
        // Case 1: every sample has the same label -> leaf with that label.
        if (isDclassOnlyOne(d)) {
            ArrayList<String> first = d.get(0);
            node.type = first.get(first.size() - 1);
            node.isLeaft = true;
            return;
        }
        // Case 2: nothing left to split on -> leaf with the majority label.
        if (a.isEmpty() || isDxarributeEquelalsA(d)) {
            node.type = majorityLabel(d);
            node.isLeaft = true;
            return;
        }

        // Case 3: split on the attribute with the highest gain.
        String bestName = selectAttribute(d, a);
        int bestIndex = 0;
        for (int i = 0; i < a.size(); i++) {
            if (a.get(i).get(0).equals(bestName)) {
                bestIndex = i;
            }
        }
        ArrayList<String> bestAttribute = a.get(bestIndex);
        int branchCount = bestAttribute.size() - 1;

        node.aName = bestAttribute.get(0);
        node.isLeaft = false;
        node.childnext = new TreeNode[branchCount];
        node.attributes = new String[branchCount];
        for (int i = 1; i <= branchCount; i++) {
            node.attributes[i - 1] = bestAttribute.get(i);
        }

        // Attributes handed to the children: everything except the one just used.
        // This does not depend on the branch, so it is built once.
        ArrayList<ArrayList<String>> remaining = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> attribute : a) {
            if (!bestAttribute.get(0).equals(attribute.get(0))) {
                remaining.add(attribute);
            }
        }
        String majority = majorityLabel(d);

        for (int v = 1; v <= branchCount; v++) {
            TreeNode child = new TreeNode();
            node.childnext[v - 1] = child;
            child.parent = node;
            child.parentAttribute = bestAttribute.get(0);
            // Every child starts with the majority label of the parent's samples;
            // a leaf produced by the recursive call overwrites it.
            child.type = majority;

            ArrayList<ArrayList<String>> dv = getDv(d, a, bestIndex, v);
            if (dv.isEmpty()) {
                // No sample takes this value: leaf with the parent's majority label.
                // Continue (not return) so the remaining values still get a branch.
                child.isLeaft = true;
                continue;
            }
            child.isLeaft = false;
            TreeGenerate(dv, remaining, child);
        }
    }

    /**
     * Returns the label of the majority class of {@code d}: {@code "是"} when the
     * positive samples are a strict majority, {@code "否"} otherwise (ties included).
     */
    private String majorityLabel(ArrayList<ArrayList<String>> d) {
        int positives = maxClassName(d);
        return positives > (d.size() - positives) ? POSITIVE : NEGATIVE;
    }

    /**
     * Picks the attribute with the largest information gain and prints each gain
     * to standard output. If no gain is greater than zero the first attribute wins.
     *
     * @param d samples
     * @param a candidate attributes; must not be empty
     * @return name of the selected attribute
     */
    private String selectAttribute(ArrayList<ArrayList<String>> d, ArrayList<ArrayList<String>> a) {
        int bestIndex = 0;
        double bestGain = 0.00;
        for (int i = 0; i < a.size(); i++) {
            double g = gain(d, a, i);
            System.out.println("gain[" + a.get(i).get(0) + "]=" + g + " ");
            if (g > bestGain) {
                bestGain = g;
                bestIndex = i;
            }
        }
        System.out.println("属性:" + a.get(bestIndex).get(0) + " 位置:" + bestIndex);
        return a.get(bestIndex).get(0);
    }

    /**
     * Computes the information gain of splitting {@code d} on attribute
     * {@code aIndex}: the entropy of {@code d} minus the size-weighted entropy of
     * each partition.
     *
     * @param d      samples
     * @param a      attributes
     * @param aIndex index into {@code a} of the attribute to evaluate
     * @return information gain
     */
    private double gain(ArrayList<ArrayList<String>> d, ArrayList<ArrayList<String>> a, int aIndex) {
        int valueCount = a.get(aIndex).size();
        double weightedEntropy = 0.00;
        for (int v = 1; v < valueCount; v++) {
            ArrayList<ArrayList<String>> dv = getDv(d, a, aIndex, v);
            if (!dv.isEmpty()) {
                weightedEntropy += ((double) dv.size() / (double) d.size()) * getEnt(dv);
            }
        }
        return getEnt(d) - weightedEntropy;
    }

    /**
     * Returns the samples of {@code d} whose value for attribute {@code aIndex}
     * equals that attribute's value number {@code avIndex}.
     *
     * @param d       samples
     * @param a       attributes
     * @param aIndex  index into {@code a} of the attribute
     * @param avIndex index (1-based within the attribute row) of the value
     * @return a new list holding the matching sample rows
     */
    private ArrayList<ArrayList<String>> getDv(ArrayList<ArrayList<String>> d, ArrayList<ArrayList<String>> a, int aIndex, int avIndex) {
        ArrayList<String> attribute = a.get(aIndex);
        int column = columnOf(attribute.get(0));
        String wanted = attribute.get(avIndex);

        ArrayList<ArrayList<String>> dv = new ArrayList<ArrayList<String>>();
        for (ArrayList<String> sample : d) {
            if (wanted.equals(sample.get(column))) {
                dv.add(sample);
            }
        }
        return dv;
    }

    /**
     * Maps an attribute name to the column that holds it in a sample row.
     * The mapping is fixed: the five names below use columns 0-4 and any other
     * name is mapped to column 5.
     * NOTE(review): this ties the class to one specific column order in the
     * training data - confirm against the data file before reusing elsewhere.
     */
    private static int columnOf(String attributeName) {
        if ("色泽".equals(attributeName)) {
            return 0;
        }
        if ("根蒂".equals(attributeName)) {
            return 1;
        }
        if ("敲声".equals(attributeName)) {
            return 2;
        }
        if ("纹理".equals(attributeName)) {
            return 3;
        }
        if ("脐部".equals(attributeName)) {
            return 4;
        }
        return 5;
    }

    /**
     * Computes the binary entropy (base 2) of the class labels in {@code d}.
     *
     * @param d samples
     * @return entropy; 0 when {@code d} is empty or contains a single class
     */
    private double getEnt(ArrayList<ArrayList<String>> d) {
        int total = d.size();
        if (total == 0) {
            return 0;
        }
        int positives = maxClassName(d);
        double p1 = (double) positives / total;
        double p2 = (double) (total - positives) / total;
        if (p1 == 0.0 || p2 == 0.0) {
            // log(0) is undefined; a pure set has zero entropy.
            return 0;
        }
        double log1 = Math.log(p1) / Math.log(2);
        double log2 = Math.log(p2) / Math.log(2);
        return -(p1 * log1 + p2 * log2);
    }

    /**
     * Counts the positive samples in {@code d}, i.e. rows whose label equals
     * {@code "是"}. The label position is taken from the first row (its last
     * element) and applied to every row.
     *
     * @param d samples
     * @return number of positive samples; 0 when {@code d} is empty
     */
    public int maxClassName(ArrayList<ArrayList<String>> d) {
        if (d.isEmpty()) {
            return 0;
        }
        int labelIndex = d.get(0).size() - 1;
        int count = 0;
        for (ArrayList<String> sample : d) {
            if (POSITIVE.equals(sample.get(labelIndex))) {
                count++;
            }
        }
        return count;
    }

    /**
     * Tells whether all samples in {@code d} have identical attribute values.
     * Only the attribute columns are compared; the class label (last element of
     * the first row) is excluded, otherwise the check could never succeed for a
     * sample set that still contains more than one class.
     *
     * @param d samples
     * @return {@code true} when every row matches the first row on all attribute
     *         columns, or when {@code d} is empty
     */
    public boolean isDxarributeEquelalsA(ArrayList<ArrayList<String>> d) {
        if (d.isEmpty()) {
            return true;
        }
        ArrayList<String> first = d.get(0);
        int labelIndex = first.size() - 1;
        for (ArrayList<String> sample : d) {
            for (int col = 0; col < labelIndex; col++) {
                if (!sample.get(col).equals(first.get(col))) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Tells whether all samples in {@code D} carry the same class label.
     *
     * @param D samples
     * @return {@code true} when every row has the same label as the first row,
     *         or when {@code D} is empty
     */
    public boolean isDclassOnlyOne(ArrayList<ArrayList<String>> D) {
        if (D.isEmpty()) {
            return true;
        }
        int labelIndex = D.get(0).size() - 1;
        String expected = D.get(0).get(labelIndex);
        for (ArrayList<String> sample : D) {
            if (!expected.equals(sample.get(labelIndex))) {
                return false;
            }
        }
        return true;
    }
}
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import org.apache.poi.hssf.usermodel.HSSFCell;
import org.apache.poi.hssf.usermodel.HSSFRow;
import org.apache.poi.hssf.usermodel.HSSFSheet;
import org.apache.poi.hssf.usermodel.HSSFWorkbook;
@SuppressWarnings("restriction")
public class ListTrainSet {

    /**
     * Reads an .xls workbook into a list of rows, each row being the list of its
     * cell values rendered as strings. Rows of all sheets are appended in sheet
     * order to one result list.
     *
     * <p>Within a row, cells are read from column 0 upwards and reading stops at
     * the first missing cell, so anything after a gap is not included. Missing
     * rows are skipped.
     *
     * @param fis stream positioned at the start of the workbook; it is not closed here
     * @return all rows of all sheets
     * @throws IOException if the workbook cannot be read
     */
    public ArrayList<ArrayList<String>> readExcle(FileInputStream fis) throws IOException {
        ArrayList<ArrayList<String>> rows = new ArrayList<ArrayList<String>>();
        HSSFWorkbook workbook = new HSSFWorkbook(fis);
        for (int s = 0; s < workbook.getNumberOfSheets(); s++) {
            HSSFSheet sheet = workbook.getSheetAt(s);
            if (sheet != null) {
                appendRows(sheet, rows);
            }
        }
        return rows;
    }

    /** Appends every existing row of {@code sheet} to {@code out}. */
    private void appendRows(HSSFSheet sheet, ArrayList<ArrayList<String>> out) {
        for (int r = 0; r <= sheet.getLastRowNum(); r++) {
            HSSFRow row = sheet.getRow(r);
            if (row == null) {
                continue;
            }
            ArrayList<String> values = new ArrayList<String>();
            // Walk the cells left to right until the first missing one.
            for (int c = 0; ; c++) {
                HSSFCell cell = row.getCell(c);
                if (cell == null) {
                    break;
                }
                values.add(getValue(cell));
            }
            out.add(values);
        }
    }

    /**
     * Renders a cell as a string: boolean and numeric cells through
     * {@link String#valueOf}, every other cell type through its string value.
     *
     * @param hssfCell cell to convert; must not be {@code null}
     * @return the cell content as text
     */
    public String getValue(HSSFCell hssfCell) {
        if (hssfCell.getCellType() == HSSFCell.CELL_TYPE_BOOLEAN) {
            return String.valueOf(hssfCell.getBooleanCellValue());
        }
        if (hssfCell.getCellType() == HSSFCell.CELL_TYPE_NUMERIC) {
            return String.valueOf(hssfCell.getNumericCellValue());
        }
        return String.valueOf(hssfCell.getStringCellValue());
    }
}
import java.util.LinkedList;
/**
 * Level-order (breadth-first) traversal of a decision tree.
 *
 * @author FES
 */
public class LevelOrder {

    /**
     * Prints the tree level by level to standard output, one node per line:
     * the split attribute name for an internal node, the class label for a leaf.
     * Missing ({@code null}) children are skipped instead of being dereferenced.
     *
     * @param root root of the tree; nothing is printed when it is {@code null}
     */
    public void levelIterator(TreeNode root) {
        if (root == null) {
            return;
        }
        LinkedList<TreeNode> queue = new LinkedList<TreeNode>();
        queue.offer(root);
        while (!queue.isEmpty()) {
            TreeNode current = queue.poll();
            if (current.isLeaft) {
                System.out.println(current.type);
            } else {
                System.out.println(current.aName);
            }
            if (current.childnext == null) {
                continue;
            }
            for (TreeNode child : current.childnext) {
                // A child slot may be unset; queueing null would fail on poll.
                if (child != null) {
                    queue.offer(child);
                }
            }
        }
    }
}
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.ArrayList;
/**
 * Command-line driver: loads the sample and attribute workbooks, builds the
 * decision tree and prints it level by level.
 */
public class Test {

    /**
     * Reads {@code D.xls} (samples), {@code A.xls} (attributes) and {@code T.xls}
     * from {@code E:\text}, prints the sizes of the first two, builds the tree
     * and prints it.
     *
     * @param args ignored
     * @throws IOException if any of the three workbooks cannot be opened or read
     */
    public static void main(String[] args) throws IOException {
        ListTrainSet ls = new ListTrainSet();
        ArrayList<ArrayList<String>> D = readWorkbook(ls, "E:\\text\\D.xls");
        ArrayList<ArrayList<String>> A = readWorkbook(ls, "E:\\text\\A.xls");
        // T.xls is still read, so a missing or unreadable file fails as before,
        // but its rows are not used by this driver.
        readWorkbook(ls, "E:\\text\\T.xls");

        System.out.println("D的大小:" + D.size());
        System.out.println("A的大小:" + A.size());

        TreeNode td = new TreeNode();
        TreeGenerateFunction tgf = new TreeGenerateFunction();
        tgf.TreeGenerate(D, A, td);
        new LevelOrder().levelIterator(td);
    }

    /**
     * Opens the workbook at {@code path}, reads it with {@code reader} and always
     * closes the underlying file stream, even when reading fails.
     */
    private static ArrayList<ArrayList<String>> readWorkbook(ListTrainSet reader, String path) throws IOException {
        FileInputStream in = new FileInputStream(new File(path));
        try {
            return reader.readExcle(in);
        } finally {
            in.close();
        }
    }
}
“`
参考文献
- 周志华, 杨强. 机器学习及其应用[M]. 清华大学出版社, 2011.