决策树(C4.5) scala代码实现
maven依赖
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
<version>2.1.1</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_2.11</artifactId>
<version>2.1.1</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib_2.11</artifactId>
<version>2.1.1</version>
</dependency>
数据集
原数据集是GitHub上的, 下载速度很慢, 可以使用我下载好的数据集(网盘链接缺失, 请补充; 提取码: 8n5z)
源码(scala)
package com.train.mllib
import org.apache.spark.ml.{
Pipeline, PipelineModel, PipelineStage, Transformer}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{
IndexToString, StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}
import org.apache.spark.sql.{
DataFrame, Dataset, Row, SparkSession}
/**
* @author nineice
* @date 2021/4/22 8:23
* 决策树(C4.5)分类示例: 使用 Spark ML Pipeline 在不同 maxDepth 下训练并评估决策树模型
*/
object Demo11 {

  /**
   * Trains a C4.5-style (entropy-based) decision tree on a libsvm dataset for
   * maxDepth = 2..10 and prints the multiclass accuracy for each depth.
   *
   * @param args unused command-line arguments
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName(Demo11.getClass.getSimpleName)
      .getOrCreate()
    try {
      // Load the dataset (libsvm format: "label idx1:val1 idx2:val2 ...").
      // NOTE(review): Spark's bundled sample file is named
      // "sample_multiclass_classification_data.txt" (no trailing "s") — confirm this path.
      val dataPath: String = "C:\\Users\\nineice\\Desktop\\SparkTemp\\input\\mlData\\sample_multiclass_classification_datas.txt"
      val df: DataFrame = spark.read.format("libsvm").load(dataPath)

      // Randomly split into training and test sets with a 7:3 ratio.
      val Array(train, test) = df.randomSplit(Array(0.7, 0.3))

      /* Build the loop-invariant Pipeline stages once, outside the depth loop. */

      // 1. Re-index the raw labels into contiguous indices [0, numLabels).
      //    NOTE(review): fitted on the full df, so the index map also sees test labels.
      val labelIndexerModel: StringIndexerModel = new StringIndexer()
        .setInputCol("label")
        .setOutputCol("indexedLabel")
        .setHandleInvalid("skip")
        .fit(df)

      // 2. Index the feature vector; columns with <= 5 distinct values are
      //    treated as categorical, the rest remain continuous.
      val featureIndexerModel: VectorIndexerModel = new VectorIndexer()
        .setInputCol("features")
        .setOutputCol("indexedFeatures")
        .setMaxCategories(5)
        .fit(df)

      // 3. Map the numeric predictions back to the original label values.
      val converter: IndexToString = new IndexToString()
        .setInputCol("prediction") // prediction column produced by the classifier
        .setOutputCol("convertedPrediction") // fixed typo: was "convetedPrediction"
        .setLabels(labelIndexerModel.labels)

      // Multiclass accuracy evaluator — also loop-invariant.
      val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
        .setLabelCol("indexedLabel")
        .setPredictionCol("prediction")
        .setMetricName("accuracy")

      // Train and evaluate one tree per depth; entropy impurity ≈ C4.5 splitting.
      for (maxDepth <- 2 to 10) {
        val dtClassifier: DecisionTreeClassifier = new DecisionTreeClassifier()
          .setLabelCol("indexedLabel")
          .setFeaturesCol("indexedFeatures")
          .setMaxDepth(maxDepth)
          .setImpurity("entropy")

        // Assemble and fit the full Pipeline on the training set.
        val pipeline: Pipeline = new Pipeline()
          .setStages(Array(
            labelIndexerModel,
            featureIndexerModel,
            dtClassifier,
            converter))
        val pipelineModel: PipelineModel = pipeline.fit(train)

        // Evaluate on the held-out test set.
        val testPrediction: DataFrame = pipelineModel.transform(test)
        println("MaxDepth is : " + maxDepth)
        val accuracy: Double = evaluator.evaluate(testPrediction)
        println("accuracy is : " + accuracy)

        // Stage 2 is the fitted DecisionTreeClassificationModel.
        println(pipelineModel.stages(2))
        println("=" * 100)
      }
    } finally {
      // Release the local Spark context even if training fails.
      spark.stop()
    }
  }
}
结果
一般来说, 由于我们的训练/测试数据集是随机分配的, 所以每次的结果都不一定相同
最理想的情况下, 我们会得到一个随着树深度的增加, 准确率逐渐接近于1的结果
参考
https://www.cnblogs.com/itboys/p/8312894.html
https://www.cnblogs.com/xiguage119/p/11015677.html