Learning xgboost by example: binary classification

While playing around on Kaggle I kept hearing how powerful this library is and how accurate its models are, so I decided to study it. xgboost is short for extreme gradient boosting. Gradient boosting itself is a whole family of methods that I won't rehash here; you can refer to this introduction to gradient boosting (link).


This post focuses on walking through the example; I'll fill in the theory later.

The official binary classification example is here: (link)


First, what does the example actually do? It's simple: you are given a labeled dataset of mushrooms, where the label takes one of two values: poisonous (p) or edible (e). Each mushroom has 22 attributes, and each attribute has several possible values. The official description:

7. Attribute Information: (classes: edible=e, poisonous=p)
     1. cap-shape:                bell=b,conical=c,convex=x,flat=f,
                                  knobbed=k,sunken=s
     2. cap-surface:              fibrous=f,grooves=g,scaly=y,smooth=s
     3. cap-color:                brown=n,buff=b,cinnamon=c,gray=g,green=r,
                                  pink=p,purple=u,red=e,white=w,yellow=y
     4. bruises?:                 bruises=t,no=f
     5. odor:                     almond=a,anise=l,creosote=c,fishy=y,foul=f,
                                  musty=m,none=n,pungent=p,spicy=s
     6. gill-attachment:          attached=a,descending=d,free=f,notched=n
     7. gill-spacing:             close=c,crowded=w,distant=d
     8. gill-size:                broad=b,narrow=n
     9. gill-color:               black=k,brown=n,buff=b,chocolate=h,gray=g,
                                  green=r,orange=o,pink=p,purple=u,red=e,
                                  white=w,yellow=y
    10. stalk-shape:              enlarging=e,tapering=t
    11. stalk-root:               bulbous=b,club=c,cup=u,equal=e,
                                  rhizomorphs=z,rooted=r,missing=?
    12. stalk-surface-above-ring: fibrous=f,scaly=y,silky=k,smooth=s
    13. stalk-surface-below-ring: fibrous=f,scaly=y,silky=k,smooth=s
    14. stalk-color-above-ring:   brown=n,buff=b,cinnamon=c,gray=g,orange=o,
                                  pink=p,red=e,white=w,yellow=y
    15. stalk-color-below-ring:   brown=n,buff=b,cinnamon=c,gray=g,orange=o,
                                  pink=p,red=e,white=w,yellow=y
    16. veil-type:                partial=p,universal=u
    17. veil-color:               brown=n,orange=o,white=w,yellow=y
    18. ring-number:              none=n,one=o,two=t
    19. ring-type:                cobwebby=c,evanescent=e,flaring=f,large=l,
                                  none=n,pendant=p,sheathing=s,zone=z
    20. spore-print-color:        black=k,brown=n,buff=b,chocolate=h,green=r,
                                  orange=o,purple=u,white=w,yellow=y
    21. population:               abundant=a,clustered=c,numerous=n,
                                  scattered=s,several=v,solitary=y
    22. habitat:                  grasses=g,leaves=l,meadows=m,paths=p,
                                  urban=u,waste=w,woods=d

The dataset in its original format looks like this:

e,x,y,y,t,l,f,c,b,n,e,r,s,y,w,w,p,w,o,p,n,s,g
p,f,s,n,t,p,f,c,n,p,e,e,s,s,w,w,p,w,o,p,k,s,g
e,f,f,w,f,n,f,w,b,h,t,e,f,s,w,w,p,w,o,e,n,s,g
e,f,f,n,f,n,f,w,b,n,t,e,f,f,w,w,p,w,o,e,k,a,g
p,f,s,n,t,p,f,c,n,w,e,e,s,s,w,w,p,w,o,p,n,v,g
e,x,s,n,f,n,f,w,b,k,t,e,f,f,w,w,p,w,o,e,n,a,g
e,x,f,n,f,n,f,w,b,n,t,e,s,f,w,w,p,w,o,e,k,s,g
e,x,f,e,t,n,f,c,b,u,t,b,s,s,g,w,p,w,o,p,k,v,d
e,f,f,n,f,n,f,w,b,k,t,e,s,s,w,w,p,w,o,e,k,a,g


The first column is the label; the remaining 22 columns are the attribute values.

Looks tidy, right? Not tidy enough for the xgboost CLI, though. xgboost expects data in this format:

0 3:1 9:1 20:1 21:1 24:1 34:1 36:1 39:1 45:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 120:1
0 3:1 9:1 20:1 21:1 23:1 34:1 36:1 39:1 42:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 122:1
0 1:1 10:1 20:1 21:1 23:1 34:1 36:1 39:1 51:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 117:1 120:1
1 3:1 9:1 19:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 118:1 124:1
0 3:1 7:1 11:1 22:1 29:1 34:1 37:1 39:1 42:1 54:1 58:1 65:1 66:1 77:1 86:1 88:1 92:1 95:1 98:1 105:1 114:1 120:1

What is this? If you've seen it before, you'll recognize the libsvm format. In short, the first column is the label; in each A:B pair that follows, A is the feature index and B is the feature value. Features with value 0 are simply omitted, which is why this is called a sparse representation. See here for details (link).
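To make the format concrete, here is a minimal parsing sketch (my own, not part of the demo) that turns one libsvm line back into a label and a feature dict:

def parse_libsvm_line(line):
    # '0 3:1 9:1 ...' -> (0, {3: 1.0, 9: 1.0, ...}); omitted indices are 0
    parts = line.split()
    label = int(parts[0])
    feats = {}
    for tok in parts[1:]:
        idx, val = tok.split(':')
        feats[int(idx)] = float(val)
    return label, feats

print(parse_libsvm_line('0 3:1 9:1 20:1'))  # (0, {3: 1.0, 9: 1.0, 20: 1.0})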

Since each mushroom attribute can take more than two possible values, we first turn every attribute value into a one-hot indicator feature, then write the result in the sparse format above. That's the first step; here's the code:

mapfeat.py:

#!/usr/bin/python

def loadfmap(fname):
    # parse the attribute description file:
    #   fmap[attribute index][value code] -> one-hot feature index
    #   nmap[one-hot feature index - 1]  -> readable name like 'odor=a'
    fmap = {}
    nmap = {}
    for l in open(fname):
        arr = l.split()
        if arr[0].find('.') != -1:
            # a line like '5. odor: almond=a,anise=l,...' starts a new attribute
            idx = int(arr[0].strip('.'))
            assert idx not in fmap
            fmap[idx] = {}
            ftype = arr[1].strip(':')
            content = arr[2]
        else:
            # continuation line listing more values of the current attribute
            content = arr[0]
        for it in content.split(','):
            if it.strip() == '':
                continue
            k, v = it.split('=')
            fmap[idx][v] = len(nmap) + 1
            nmap[len(nmap)] = ftype + '=' + k
    return fmap, nmap

def write_nmap(fo, nmap):
    # feature map file: <index>\t<name>\t<type>; 'i' means indicator feature
    for i in range(len(nmap)):
        fo.write('%d\t%s\ti\n' % (i, nmap[i]))

# start here
fmap, nmap = loadfmap('agaricus-lepiota.fmap')
fo = open('featmap.txt', 'w')
write_nmap(fo, nmap)
fo.close()

fo = open('agaricus.txt', 'w')
for l in open('agaricus-lepiota.data'):
    arr = l.split(',')
    if arr[0] == 'p':         # poisonous -> label 1
        fo.write('1')
    else:
        assert arr[0] == 'e'  # edible -> label 0
        fo.write('0')
    for i in range(1, len(arr)):
        # each attribute value becomes a one-hot indicator set to 1
        fo.write(' %d:1' % fmap[i][arr[i].strip()])
    fo.write('\n')
fo.close()

Not long, and quite simple: it's all string-processing logic that converts each record into the one-hot libsvm representation. The readable feature map is written to featmap.txt, and the converted data (read from agaricus-lepiota.data) is written to agaricus.txt.
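As an aside, roughly the same one-hot expansion can be done with pandas. This is a hypothetical alternative sketch, not part of the official demo; the column names f1..f22 are invented for illustration:

import pandas as pd

# read the raw comma-separated mushroom data; invented column names f1..f22
cols = ['label'] + ['f%d' % i for i in range(1, 23)]
df = pd.read_csv('agaricus-lepiota.data', header=None, names=cols)

y = (df['label'] == 'p').astype(int)          # poisonous -> 1, edible -> 0
X = pd.get_dummies(df.drop(columns='label'))  # one 0/1 indicator column per value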


Next, split all the input data into a training set and a test set.

mknfold.py:

#!/usr/bin/python
import sys
import random

if len(sys.argv) < 3:
    print('Usage: <filename> <k> [nfold = 5]')
    exit(0)

random.seed(10)

k = int(sys.argv[2])
if len(sys.argv) > 3:
    nfold = int(sys.argv[3])
else:
    nfold = 5

fi = open(sys.argv[1], 'r')
ftr = open(sys.argv[1] + '.train', 'w')
fte = open(sys.argv[1] + '.test', 'w')
for l in fi:
    # each line lands in fold k with probability 1/nfold
    if random.randint(1, nfold) == k:
        fte.write(l)
    else:
        ftr.write(l)
fi.close()
ftr.close()
fte.close()

The logic is not complicated either: each line is assigned a random fold in 1..nfold, lines landing in fold k go to the test file, and everything else goes to the train file, so on average 1/nfold of the data ends up in the test set.
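A quick sanity check of that claim (my own snippet, not part of the demo): randint(1, nfold) picks one of nfold folds uniformly, so each line independently lands in the test file with probability 1/nfold.

import random

# empirical check: with nfold = 5 about 20% of draws hit fold k
random.seed(10)
nfold, k = 5, 1
n = 100000
hits = sum(random.randint(1, nfold) == k for _ in range(n))
print(float(hits) / n)  # ~0.2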

At this point the dataset is ready. Now we train with xgboost; using the CLI version, the invocation is:

../../xgboost mushroom.conf
So what is this mushroom.conf?

It's simply the configuration file for the xgboost run; among other things, it contains:

# General Parameters, see comment for each definition
# choose the booster, can be gbtree or gblinear
booster = gbtree
# choose logistic regression loss function for binary classification
objective = binary:logistic

# Tree Booster Parameters
# step size shrinkage
eta = 1.0
# minimum loss reduction required to make a further partition
gamma = 1.0
# minimum sum of instance weight(hessian) needed in a child
min_child_weight = 1
# maximum depth of a tree
max_depth = 3
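
The excerpt above covers only the general and tree booster parameters. The demo's config also carries task parameters telling the CLI which data files to use and how many boosting rounds to run; from memory they look roughly like this (the file names match what mknfold.py produces):

# Task Parameters
# the number of boosting rounds
num_round = 2
# 0 means do not save any model apart from the final round model
save_period = 0
# the path of training data
data = "agaricus.txt.train"
# the path of validation data, used to monitor training progress
eval[test] = "agaricus.txt.test"
# the path of test data
test:data = "agaricus.txt.test"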


For the complete set of configurable parameters, see here (link).


Next, a command like ../../xgboost mushroom.conf task=dump model_in=0002.model name_dump=dump.raw.txt dumps the trained trees, which look like this:

booster[0]:
0:[f29<-9.53674e-07] yes=1,no=2,missing=1
        1:[f56<-9.53674e-07] yes=3,no=4,missing=3
                3:[f60<-9.53674e-07] yes=7,no=8,missing=7
                        7:leaf=1.90175
                        8:leaf=-1.95062
                4:[f21<-9.53674e-07] yes=9,no=10,missing=9
                        9:leaf=1.77778
                        10:leaf=-1.98104
        2:[f109<-9.53674e-07] yes=5,no=6,missing=5
                5:[f67<-9.53674e-07] yes=11,no=12,missing=11
                        11:leaf=-1.98531
                        12:leaf=0.808511
                6:leaf=1.85965
booster[1]:
0:[f29<-9.53674e-07] yes=1,no=2,missing=1
        1:[f21<-9.53674e-07] yes=3,no=4,missing=3
                3:leaf=1.1457
                4:[f36<-9.53674e-07] yes=7,no=8,missing=7
                        7:leaf=-6.87558
                        8:leaf=-0.127376
        2:[f109<-9.53674e-07] yes=5,no=6,missing=5
                5:[f39<-9.53674e-07] yes=9,no=10,missing=9
                        9:leaf=-0.0386054
                        10:leaf=-1.15275
                6:leaf=0.994744
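How to read this: fN is the one-hot feature with index N. Because every feature here is a 0/1 indicator and zero entries are omitted from the libsvm file (hence "missing"), a test like f29<-9.53674e-07 is effectively asking whether f29 is absent, and absent instances follow the missing branch. The leaves hold raw scores; under binary:logistic the scores from all boosters are summed and passed through a sigmoid. A hand-computed sanity check (my own illustration, not part of the demo):

import math

# for an instance where f29, f56, f60 and f21 are all absent, booster[0]
# follows nodes 0 -> 1 -> 3 -> 7 (leaf 1.90175) and booster[1] follows
# nodes 0 -> 1 -> 3 (leaf 1.1457); sum the raw scores, then apply a sigmoid
margin = 1.90175 + 1.1457
prob = 1.0 / (1.0 + math.exp(-margin))
print(prob)  # ~0.95, so this instance is predicted poisonous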



The complete sequence of steps, end to end:

#!/bin/bash
# map feature using indicator encoding, also produce featmap.txt
python mapfeat.py
# split train and test
python mknfold.py agaricus.txt 1
# training and output the models
../../xgboost mushroom.conf
# output prediction task=pred 
../../xgboost mushroom.conf task=pred model_in=0002.model
# print the boosters of 0002.model in dump.raw.txt
../../xgboost mushroom.conf task=dump model_in=0002.model name_dump=dump.raw.txt 
# use the feature map in printing for better visualization
../../xgboost mushroom.conf task=dump model_in=0002.model fmap=featmap.txt name_dump=dump.nice.txt
cat dump.nice.txt
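If I recall the CLI defaults correctly, the task=pred step writes one probability per line to pred.txt (the default value of name_pred); a quick way to inspect the results (my own snippet):

# assumes the CLI's default prediction output file pred.txt
preds = [float(l) for l in open('pred.txt')]
labels = [int(p > 0.5) for p in preds]  # threshold binary:logistic probabilities
print(preds[:5], labels[:5])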


Reposted from blog.csdn.net/kwame211/article/details/80900536