机器学习结果加ID插入数据库源码

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

import org.apache.spark.mllib.linalg.Vectors

import org.apache.spark.mllib.regression.LabeledPoint

import org.apache.spark.mllib.tree.GradientBoostedTrees

import org.apache.spark.mllib.tree.configuration.BoostingStrategy

import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel

import org.apache.spark.sql.{Row, SaveMode}

import org.apache.spark.sql.hive.HiveContext

import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}

import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable.ArrayBuffer

object v4score20180123 {

  /**
   * Spark driver: loads a pre-trained GBDT model, scores every row of the
   * Hive validation table, and writes (order_id, apply_time, label, score)
   * back to Hive, overwriting the target table.
   */
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf().setAppName("v4model20180123")
    val sc = new SparkContext(sparkConf)
    val hc = new HiveContext(sc)

    try {
      // Columns 0-2 are order_id / apply_time / label; everything after is a
      // numeric feature. Build (id..., featureVector) tuples per row.
      val dataInstance = hc.sql("select * from lkl_card_score.fqz_score_dataset_04vals").map { row =>
        // Convert each feature cell to Double. Nulls and strings map to 0.0.
        // Unknown runtime types also map to 0.0 so every row yields a vector
        // of the same length (the original for-loop silently skipped them,
        // which could produce inconsistently-sized vectors).
        val features = (3 until row.size).map { i =>
          if (row.isNullAt(i)) 0.0
          else row.get(i) match {
            case v: Int    => v.toDouble
            case v: Double => v
            case v: Long   => v.toDouble
            case _: String => 0.0
            case _         => 0.0
          }
        }
        (row(0), row(1), row(2), Vectors.dense(features.toArray))
      }

      // Pre-trained gradient-boosted-trees model persisted on HDFS.
      val model = GradientBoostedTreesModel.load(sc, "hdfs://ns1/user/songchunlin/model/v4model20180123s")

      // Score each instance: (order_id, apply_time, label, score).
      val predictions = dataInstance.map { case (orderId, applyTime, label, features) =>
        (orderId, applyTime, label, model.predict(features))
      }

      // RDD -> Row RDD with string-typed ids so it matches the schema below.
      val rowRDD = predictions.map { case (orderId, applyTime, label, score) =>
        Row(orderId.toString, applyTime.toString, label.toString, score)
      }

      val schema = StructType(List(
        StructField("order_id", StringType, true),
        StructField("apply_time", StringType, true),
        StructField("label", StringType, true),
        StructField("score", DoubleType, true)
      ))

      // Apply the schema and persist. NOTE: the original also ran take(5) and
      // count() with discarded results — each triggered a full redundant Spark
      // job before the write, so they were removed.
      val scoreDataFrame = hc.createDataFrame(rowRDD, schema)
      scoreDataFrame.write.mode(SaveMode.Overwrite).saveAsTable("lkl_card_score.fqz_score_dataset_03val_v4_predict0123s")
    } finally {
      // Always release cluster resources, even if a stage fails.
      sc.stop()
    }
  }

}

猜你喜欢

转载自blog.csdn.net/hellozhxy/article/details/84256808