向导
MAVEN
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>jpmml-sparkml</artifactId>
<version>1.5.0</version>
<!-- <exclusions>-->
<!-- <exclusion>-->
<!-- <groupId>org.jpmml</groupId>-->
<!-- <artifactId>jpmml-converter</artifactId>-->
<!-- </exclusion>-->
<!-- </exclusions>-->
</dependency>
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>jpmml-lightgbm</artifactId>
<version>1.2.3</version>
</dependency>
测试数据
http://archive.ics.uci.edu/ml/machine-learning-databases/00275/Bike-Sharing-Dataset.zip
hour.csv和day.csv都有如下属性,只是day.csv文件中没有hr属性(hr仅存在于hour.csv中)
- instant: 记录ID
- dteday : 时间日期
- season : 季节 (1:春季, 2:夏季, 3:秋季, 4:冬季)
- yr : 年份 (0: 2011, 1:2012)
- mnth : 月份 ( 1 to 12)
- hr : 当天时刻 (0 to 23)
- holiday : 当天是否是节假日(extracted from http://dchr.dc.gov/page/holiday-schedule)
- weekday : 周几
- workingday : 工作日 is 1, 其他 is 0.
- weathersit : 天气
- 1: Clear, Few clouds, Partly cloudy, Partly cloudy
- 2: Mist + Cloudy, Mist + Broken clouds, Mist + Few clouds, Mist
- 3: Light Snow, Light Rain + Thunderstorm + Scattered clouds, Light Rain + Scattered clouds
- 4: Heavy Rain + Ice Pellets + Thunderstorm + Mist, Snow + Fog
- temp : 气温 Normalized temperature in Celsius. The values are divided to 41 (max)
- atemp: 体感温度 Normalized feeling temperature in Celsius. The values are divided to 50 (max)
- hum: 湿度 Normalized humidity. The values are divided to 100 (max)
- windspeed: 风速Normalized wind speed. The values are divided to 67 (max)
- casual: 临时用户数count of casual users
- registered: 注册用户数count of registered users
- cnt: 目标变量,每小时的自行车的租用量,包括临时用户和注册用户count of total rental bikes including both casual and registered
代码示例,以二分类为例
package com.bigblue.lightgbm
import java.io.FileOutputStream
import com.bigblue.utils.LightGBMUtils
import com.microsoft.ml.spark.lightgbm.{LightGBMClassificationModel, LightGBMClassifier}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.types.{DoubleType, IntegerType}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.jpmml.lightgbm.GBDT
import org.jpmml.model.MetroJAXBUtil
/**
 * Created by TheBigBlue on 2020/3/6.
 *
 * End-to-end example: reads the bike-sharing `hour.csv` file, trains a LightGBM
 * binary classifier that predicts the `workingday` column, prints the evaluator
 * metric on a held-out split and exports the trained booster to PMML.
 */
object LightGBMClassificationTest {

  /**
   * Runs the example. Input and output locations are hard-coded below.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("test-lightgbm").master("local[2]").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    try {
      val rawData: DataFrame = spark.read
        .option("header", "true") // use the first row as the header
        .option("inferSchema", "true") // let Spark infer the column types
        // .csv("/home/hdfs/hour.csv")
        .csv("file:///D:/Cache/ProgramCache/TestData/dataSource/lightgbm/hour.csv")

      val labelCol = "workingday"
      // categorical feature columns
      val cateCols = Array("season", "yr", "mnth", "hr")
      // continuous feature columns
      val conCols: Array[String] = Array("temp", "atemp", "hum", "casual", "cnt")
      // all columns that are assembled into the feature vector
      val vecCols = conCols ++ cateCols

      // Cast every feature column to double without mutating a shared variable.
      val typedFeatures: DataFrame = vecCols.foldLeft(rawData) { (df, name) =>
        df.withColumn(name, df(name).cast(DoubleType))
      }
      val originalData: DataFrame =
        typedFeatures.withColumn(labelCol, typedFeatures(labelCol).cast(IntegerType))

      val assembler = new VectorAssembler().setInputCols(vecCols).setOutputCol("features")
      val classifier: LightGBMClassifier = new LightGBMClassifier().setNumIterations(100).setNumLeaves(31)
        .setBoostFromAverage(false).setFeatureFraction(1.0).setMaxDepth(-1).setMaxBin(255)
        .setLearningRate(0.1).setMinSumHessianInLeaf(0.001).setLambdaL1(0.0).setLambdaL2(0.0)
        .setBaggingFraction(1.0).setBaggingFreq(0).setBaggingSeed(1).setObjective("binary")
        .setLabelCol(labelCol).setCategoricalSlotNames(cateCols).setFeaturesCol("features")
        .setBoostingType("gbdt") // alternatives: rf, dart, goss
      val pipeline: Pipeline = new Pipeline().setStages(Array(assembler, classifier))

      // 70% / 30% train/test split. The previous weights (0.7, .03) were a typo:
      // randomSplit normalises its weights, so the test set was only ~4% of the data.
      val Array(tr, te) = originalData.randomSplit(Array(0.7, 0.3), 666)
      val model = pipeline.fit(tr)
      val modelDF = model.transform(te)

      // Score on the raw prediction column instead of the hard 0/1 "prediction" label,
      // so the ranking-based metric is computed from actual scores.
      // NOTE(review): assumes the classifier emits the default "rawPrediction" column - confirm.
      val evaluator = new BinaryClassificationEvaluator().setLabelCol(labelCol).setRawPredictionCol("rawPrediction")
      println(evaluator.evaluate(modelDF))

      // Export the trained booster to PMML. The model stage is looked up by type
      // rather than by position plus an unchecked cast.
      val classificationModel = model.stages
        .collectFirst { case m: LightGBMClassificationModel => m }
        .getOrElse(throw new IllegalStateException("No LightGBMClassificationModel stage in the fitted pipeline"))
      LightGBMUtils.saveToPmml(classificationModel.getModel, "D://Download/classificationModel.xml")
    } finally {
      // Release the local Spark resources even when training or export fails.
      spark.stop()
    }
  }
}
package com.bigblue.utils
import java.io.{ByteArrayInputStream, FileOutputStream}
import com.microsoft.ml.spark.lightgbm.LightGBMBooster
import org.jpmml.lightgbm.LightGBMUtil
import org.jpmml.model.MetroJAXBUtil
/**
 * Created by TheBigBlue on 2020/3/20.
 *
 * Helpers for exporting LightGBM boosters to PMML.
 */
object LightGBMUtils {

  /**
   * Converts the given booster to a PMML document and writes it to `path`.
   *
   * The export is best-effort: any exception is printed to standard error and
   * not rethrown, so callers are not interrupted by a failed export.
   *
   * @param booster trained booster whose `model` text is parsed by JPMML-LightGBM
   * @param path    destination file for the PMML document
   */
  def saveToPmml(booster: LightGBMBooster, path: String): Unit = {
    try {
      // Encode explicitly so the bytes do not depend on the platform default charset.
      val modelBytes = booster.model.getBytes(java.nio.charset.StandardCharsets.UTF_8)
      val gbdt = LightGBMUtil.loadGBDT(new ByteArrayInputStream(modelBytes))

      // Explicit Java map instead of the deprecated implicit JavaConversions.
      val options: java.util.Map[String, AnyRef] =
        java.util.Collections.singletonMap[String, AnyRef]("compact", java.lang.Boolean.TRUE)
      // Target name and categories are passed as null, as before.
      val pmml = gbdt.encodePMML(null, null, options)

      // Always close the stream so the file is flushed and the handle released.
      val out = new FileOutputStream(path)
      try {
        MetroJAXBUtil.marshalPMML(pmml, out)
      } finally {
        out.close()
      }
    } catch {
      case e: Exception => e.printStackTrace()
    }
  }
}