Decision Tree Source Code Analysis
ml.tree.impl
DecisionTreeMetadata.scala
Builds the decision tree metadata.
def buildMetadata(
    input: RDD[LabeledPoint],
    strategy: Strategy): DecisionTreeMetadata = {
  buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all")
}

def buildMetadata(
    input: RDD[LabeledPoint],
    strategy: Strategy,
    numTrees: Int,
    featureSubsetStrategy: String): DecisionTreeMetadata = {
It does the following work:
(1) Get the total number of features, numFeatures (the maximum local-vector length among the RDD's LabeledPoint elements);
(2) Get the number of training examples, numExamples;
(3) Get the number of classes, numClasses: = 2 (binary classification), > 2 (multiclass), = 0 (regression);
(4) Compute maxPossibleBins = math.min(strategy.maxBins, numExamples)
    Note: strategy.maxBins should not exceed numExamples; otherwise a warning is logged;
(5) For the categorical features, categoricalFeaturesInfo: Map[featureIndex, numCategories]:
    maxCategoriesPerFeature = the largest numCategories,
    maxCategory = the featureIndex whose numCategories is largest;
(6) Define two collections:
    val unorderedFeatures = new mutable.HashSet[Int]()
    val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
    Multiclass case:
    From the configured maxBins we obtained maxPossibleBins above; the largest
    numCategories satisfying 2 * ((1 << numCategories - 1) - 1) <= maxPossibleBins
    serves as the threshold between unordered and ordered treatment.
    Iterate over categoricalFeaturesInfo: Map[featureIndex, numCategories]:
    - If numCategories <= the threshold, treat the feature as unordered:
      add featureIndex to the HashSet ---> unorderedFeatures.add(featureIndex)
      record the bin count derived from numCategories ---> numBins(featureIndex) = numUnorderedBins(numCategories)
    - If numCategories > the threshold, treat the feature as ordered:
      numBins(featureIndex) = numCategories
    Note: ordered or unordered, the bin count is written into numBins at the feature's
    index; only features that additionally appear in the HashSet are treated as unordered.
Binary classification or regression case:
    // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957
    if (numCategories > 1) {
      numBins(featureIndex) = numCategories
    }
(7) Set the number of features considered per node (for random forests).
}
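The unordered/ordered bookkeeping in step (6) can be sketched without Spark. The helper names `numUnorderedSplits`, `maxUnorderedArity`, and `binsFor` below are hypothetical, not the actual private members of DecisionTreeMetadata; the arithmetic follows the 2 * ((1 << numCategories - 1) - 1) bound quoted above.

```scala
// Number of candidate splits for an unordered categorical feature with
// `arity` categories: each split is a nonempty proper subset of the
// categories, counted once per complementary pair => 2^(arity-1) - 1.
def numUnorderedSplits(arity: Int): Int = (1 << (arity - 1)) - 1

// Largest arity M that still respects the bin budget:
// 2 * numUnorderedSplits(M) <= maxPossibleBins.
def maxUnorderedArity(maxPossibleBins: Int): Int = {
  var m = 1
  while (2 * numUnorderedSplits(m + 1) <= maxPossibleBins) m += 1
  m
}

// Bin count for a categorical feature, and whether it is treated as
// unordered (which only happens in the multiclass setting).
def binsFor(arity: Int, maxPossibleBins: Int, multiclass: Boolean): (Int, Boolean) =
  if (multiclass && arity <= maxUnorderedArity(maxPossibleBins))
    (2 * numUnorderedSplits(arity), true)  // unordered: two bins per subset split
  else
    (arity, false)                         // ordered: one bin per category
```

For example, with maxPossibleBins = 32 the threshold arity is 5, so a 3-category feature is unordered (6 bins) while a 10-category feature stays ordered (10 bins).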
RandomForest.scala
run method
Trains a random forest and returns an array whose elements are decision tree models.
It does the following work:
(1) Build the decision tree metadata;
(2) Find the splits and the corresponding bins (intervals between the splits) using a sample of the input data;
(3)
findSplits method:
protected[tree] def findSplits(
    input: RDD[LabeledPoint],
    metadata: DecisionTreeMetadata,
    seed: Long): Array[Array[Split]] = {
  logDebug("isMulticlass = " + metadata.isMulticlass)
  val numFeatures = metadata.numFeatures
  // indices of the continuous features
  val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
  // Sample the input only if there are continuous features.
  val sampledInput = if (continuousFeatures.nonEmpty) {
    // Calculate the number of samples for approximate quantile calculation.
    val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
    val fraction = if (requiredSamples < metadata.numExamples) {
      requiredSamples.toDouble / metadata.numExamples
    } else {
      1.0
    }
    logDebug("fraction of data used for calculating quantiles = " + fraction)
    // sample without replacement
    input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
  } else {
    // no continuous features: use an empty RDD
    input.sparkContext.emptyRDD[LabeledPoint]
  }
  findSplitsBySorting(sampledInput, metadata, continuousFeatures)
}
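The sampling-fraction logic above is self-contained enough to pull out. A minimal sketch (the helper name `quantileSampleFraction` is made up for illustration):

```scala
// Fraction of the input to sample for approximate quantile calculation:
// aim for at least maxBins^2 samples (with a floor of 10000), but never
// more than the whole dataset.
def quantileSampleFraction(maxBins: Int, numExamples: Long): Double = {
  val requiredSamples = math.max(maxBins * maxBins, 10000)
  if (requiredSamples < numExamples) requiredSamples.toDouble / numExamples
  else 1.0
}
```

So with the default maxBins = 32 and one million rows, only 1% of the data is scanned to place the continuous thresholds.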
findSplitsBySorting method:
// finds the split thresholds for the continuous features
private def findSplitsBySorting(
    input: RDD[LabeledPoint],  // subset sampled without replacement
    metadata: DecisionTreeMetadata,
    continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
  val continuousSplits: scala.collection.Map[Int, Array[Split]] = {
    // reduce the parallelism for split computations when there are less
    // continuous features than input partitions. this prevents tasks from
    // being spun up that will definitely do no work.
    val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
    input
      .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
      .groupByKey(numPartitions)
      .map { case (idx, samples) =>
        val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
        val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
        logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
        (idx, splits)
      }.collectAsMap()
  }
  val numFeatures = metadata.numFeatures
  val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
    case i if metadata.isContinuous(i) =>
      val split = continuousSplits(i)
      metadata.setNumSplits(i, split.length)
      split
    case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
      // Unordered features
      // 2^(maxFeatureValue - 1) - 1 combinations
      val featureArity = metadata.featureArity(i)
      Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
        val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
        new CategoricalSplit(i, categories.toArray, featureArity)
      }
    case i if metadata.isCategorical(i) =>
      // Ordered features
      // Splits are constructed as needed during training.
      Array.empty[Split]
  }
  splits
}
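In the unordered branch, extractMultiClassCategories decodes the split index as a bitmask over the categories: split index 1..2^(arity-1)-1 enumerates every subset that can go to the left child. A Spark-free sketch of that decoding (the name `extractCategories` and the exact return order are illustrative; the real method's details may differ):

```scala
// Decode `input` as a bitmask over `arity` categories: category j belongs
// to the left side of the split iff bit j of `input` is set.
def extractCategories(input: Int, arity: Int): List[Double] = {
  var cats = List.empty[Double]
  var bits = input
  var j = 0
  while (j < arity) {
    if ((bits & 1) == 1) cats = j.toDouble :: cats  // prepend, so higher categories come first
    bits >>= 1
    j += 1
  }
  cats
}
```

For a 3-category feature, split index 5 = 0b101 sends categories {0, 2} left, which is why 2^(arity-1) - 1 = 3 indices cover all distinct binary partitions.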
TreePoint.scala
labeledPointToTreePoint method
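labeledPointToTreePoint replaces each feature value of a LabeledPoint with a bin index, so that later statistics aggregation works on small integers instead of raw doubles. For a continuous feature this amounts to locating the value among the sorted split thresholds. A minimal sketch of that lookup (the name `findBin` is illustrative; Spark's implementation does its own binary search and its tie-breaking may differ):

```scala
// Bin index of a continuous value given the sorted split thresholds:
// a value at or below thresholds(i) (and above all earlier thresholds)
// lands in bin i; a value above every threshold lands in the last bin,
// which has index thresholds.length.
def findBin(value: Double, thresholds: Array[Double]): Int = {
  val idx = java.util.Arrays.binarySearch(thresholds, value)
  if (idx >= 0) idx          // exact threshold hit
  else -idx - 1              // negative result encodes the insertion point
}
```

With thresholds Array(1.0, 2.0) there are three bins: values up to 1.0 map to bin 0, values in (1.0, 2.0] to bin 1, and anything larger to bin 2.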
mllib.tree
DecisionTree.scala
1. Entry point
Returns: DecisionTreeModel
def trainClassifier(
    input: RDD[LabeledPoint],
    numClasses: Int,
    categoricalFeaturesInfo: Map[Int, Int],
    impurity: String,
    maxDepth: Int,
    maxBins: Int): DecisionTreeModel = {
  val impurityType = Impurities.fromString(impurity)
  train(input, Classification, impurityType, maxDepth, numClasses, maxBins, Sort,
    categoricalFeaturesInfo)
}
2. Step into the train method
- Builds the decision tree parameters (a Strategy) from the provided configuration;
- Trains a decision tree model on the given data (the model can then be used for prediction).
def train(
    input: RDD[LabeledPoint],
    algo: Algo,
    impurity: Impurity,
    maxDepth: Int,
    numClasses: Int,
    maxBins: Int,
    quantileCalculationStrategy: QuantileStrategy,
    categoricalFeaturesInfo: Map[Int, Int]): DecisionTreeModel = {
  val strategy = new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
    quantileCalculationStrategy, categoricalFeaturesInfo)
  new DecisionTree(strategy).run(input)
}
3. Step into the run method
- Creates a RandomForest with numTrees = 1 and featureSubsetStrategy = "all";
- Trains the random forest model (composed of n decision tree models);
- Returns its first decision tree model.
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
  val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = seed)
  val rfModel = rf.run(input)
  rfModel.trees(0)
}
RandomForest.scala
1. run method
def run(input: RDD[LabeledPoint]): RandomForestModel = {
  val trees: Array[NewDTModel] = NewRandomForest.run(input.map(_.asML), strategy, numTrees,
    featureSubsetStrategy, seed.toLong, None)
  new RandomForestModel(strategy.algo, trees.map(_.toOld))
}