用户定义函数(User-defined functions, UDFs)是大多数 SQL 环境的关键特性,用于扩展系统的内置功能。UDF 允许开发人员把底层语言的实现抽象出来,从而在更高级的语言(如 SQL)中直接使用新功能。Apache Spark 也不例外,它提供了多种将 UDF 集成到 Spark SQL 工作流中的选项。
本文通过自定义UDF实现WordCount案例:
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
object UDF {
  /**
   * Entry point: builds a small single-column word DataFrame, registers a
   * scalar UDF ("computeLength") and an aggregate UDAF ("wordCount"), then
   * runs two SQL queries against a temp view to demonstrate both.
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .appName("UDF")
      .master("local[2]")
      .getOrCreate()
    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("WARN")

    // Sample data: repeated words that will be counted per group.
    val words = Array("Spark", "Spark", "Hadoop", "Spark", "Hadoop",
                      "Spark", "Spark", "Hadoop", "Spark", "Hadoop")
    val wordRows: RDD[Row] = sc.parallelize(words).map(Row(_))

    // Schema: one nullable string column named "word".
    val schema: StructType = StructType(Array(
      StructField("word", StringType, true)
    ))
    val wordsDF: DataFrame = spark.createDataFrame(wordRows, schema)
    wordsDF.createOrReplaceTempView("bigDataTable")

    // Scalar UDF: length of the input string.
    spark.udf.register("computeLength", (input: String) => input.length)
    // A registered UDF is used in SQL exactly like a built-in function.
    spark.sql("select word,computeLength(word) as length from bigDataTable").show()

    // Aggregate UDAF: counts rows per group (implemented by MyUDAF).
    spark.udf.register("wordCount", new MyUDAF)
    spark.sql("select word,computeLength(word) as length, wordCount(word) as count from bigDataTable group by word").show()

    sc.stop()
    spark.stop()
  }
}
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
/**
 * Aggregate function that counts the number of rows in each group.
 *
 * NOTE(review): UserDefinedAggregateFunction is deprecated since Spark 3.0
 * in favour of org.apache.spark.sql.expressions.Aggregator; it is kept here
 * because the example registers the function by name via spark.udf.register.
 */
class MyUDAF extends UserDefinedAggregateFunction {

  // One string input column; the value itself is ignored — only its presence counts.
  override def inputSchema: StructType =
    StructType(StructField("input", StringType, true) :: Nil)

  // Aggregation buffer: a single running Int counter.
  override def bufferSchema: StructType =
    StructType(StructField("count", IntegerType, true) :: Nil)

  // The final result is an Int.
  override def dataType: DataType = IntegerType

  // Same input always produces the same output.
  override def deterministic: Boolean = true

  // Start each group's counter at zero.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  // Per-row update within a partition (analogous to a combiner in Hadoop
  // MapReduce): increment the counter for every incoming row.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1
  }

  // Merge partial counts computed on different partitions into one total.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
  }

  // Emit the final count for the group.
  override def evaluate(buffer: Row): Any = buffer.getInt(0)
}