使用spark ML创建机器学习流水线,ml包和mllib包的区别
spark中ml包和mllib包的区别
- mllib,主要针对RDD
- ml,主要针对dataSet
-
建议使用ml,它比mllib新,而且dataSet可用spark SQL操作,比较灵活.
下面是一个机器学习的Demo,使用DataFrame作为数据集
package edu.zhku.mllib.base
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.SparkSession
/*
* @author : 钱伟健 [email protected]
* @version : 2018/4/30 20:24.
* 说明:
*/
/**
* <pre>使用ML创建机器学习流水线</pre>
*/
object MLAssemblyLine {
// 创建一个特征Cass类
case class Feature(v: Vector)
System.setProperty("hadoop.home.dir", "D:\\program\\hadoop")
def main(args: Array[String]): Unit = {
val spark = SparkSession
.builder()
.appName("HotReView")
.master("local[2]")
.getOrCreate()
import spark.implicits._
/**
* 判断某人是否是篮球运动员
*/
// 定义啊,a, b, c, d 四个人,其中d并非篮球运动员,特征分别是身高和体重
val a = LabeledPoint(1.0,Vectors.dense(80.0, 250.0))
val b = LabeledPoint(0.0, Vectors.dense(70.0, 150.0))
val c = LabeledPoint(1.0, Vectors.dense(80.0, 207.0))
val d = LabeledPoint(0.0, Vectors.dense(65.0, 120.0))
// 创建一个训练集
val trainingRDD = spark.sparkContext.parallelize(List(a, b, c, d))
// 创建一个DataFrame
val trainingDF = trainingRDD.toDF()
// 创建一个LogisticRegresssion估算
val estimator = new LogisticRegression
// 创建一个拟合故居算和训练集DAtaFrame的转换器
val transformer = estimator.fit(trainingDF)
// 创建测试数据,篮球运动员 i
val i = Vectors.dense(90.0, 270.0)
// 另一个非篮球运动员 j
val j = Vectors.dense(62.0, 120.0)
//创建训练集RDD
val testRDD = spark.sparkContext.parallelize(List(i,j))
// 映射testRDD和featureRDD
val featuresRDD = testRDD.map( v => Feature(v))
// 将featuresRDD 转换为列名叫"features"的DataFrame
val featuresDF = featuresRDD.toDF("features")
// 在featureDF中增加预测列
val predictionsDF = transformer.transform(featuresDF)
// 打印predictionsDF
predictionsDF.collect().foreach(println)
// PredictionsDF新增了3列--行预测,可能性和预测.我们只选择特性和预测列
val shorterPredictionsDF = predictionsDF.select("features","prediction")
// 将预测列重命名为isBasketBallPlayer
val playerDF = shorterPredictionsDF.toDF("features","isBasketBallPlayer")
// 打印playerDF的数据结构
playerDF.printSchema()
println("==========结果===========")
playerDF.collect().foreach(println)
}
}