UDAF的特点就是:N:1,目的就是为了做聚合(group by)
UserDefinedAggregateFunction是用户自定义聚合函数要继承的抽象类,传参---->initialize初始化、update(RDD分区内部的合并)、merge(分区之间总结果之间的合并)
class MyUDAF extends UserDefinedAggregateFunction{
override def inputSchema: StructType = {
DataTypes.createStructType(Array(DataTypes.createStructField("type",StringType,true)))
}
// 聚合操作时,中间更新,所处理的数据的类型
override def bufferSchema: StructType = {
DataTypes.createStructType(Array(DataTypes.createStructField("type",IntegerType,true)))
}
// 最终函数返回值的类型
override def dataType: DataType = {
DataTypes.IntegerType
}
//多次运行 相同的输入总是相同的输出,确保一致性
override def deterministic: Boolean = {
true
}
/**
* 为每个分组的数据执行初始化值
* 两个部分的初始化:
* 1.在map端每个RDD分区内,在RDD每个分区内 按照group by 的字段分组,每个分组都有个初始化的值:a=0
* 2.在reduce 端给每个group by 的分组做初始值:a=0
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0)=0
}
/**每个组,有新的值进来的时候,进行分组对应的聚合值的计算
* input就是输入的数据,传进来name,放到Row中
* buffer就是上一步的初始值
*拿到0号位,来一个就+1
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0)=buffer.getAs[Int](0)+1 //初始值+1
}
/**
* 最后merger的时候,在各个节点上的聚合值,要进行merge,也就是合并
*buffer1拿某一组的初始值:a=0 buffer2就是a1组的值:2 ---> 0+2=2
* 下一波,buffer1就是2,buffer2就是a2组的值:1 --->2+1=3
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0)=buffer1.getAs[Int](0)+buffer2.getAs[Int](0)
}
// 最后返回一个最终的聚合值要和dataType的类型一一对应
//buffer:每个组的总和
override def evaluate(buffer: Row): Any = {
buffer.getAs[Int](0)
}
}
object UDAF {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder()
.appName("udaf")
.master("local")
.getOrCreate()
val list = List[String]("zhangsan","zhangsan","lisi","wangwu","wangwu","lisi","wangwu","lisi")
import spark.implicits._
val frame = list.toDF("name")
frame.createOrReplaceTempView("person")
//注册UDAF
spark.udf.register("nameCount",new MyUDAF)
spark.sql("select name,nameCount(name) as count from person group by name ").show()
}
}