一:UDAF含义
UDAF:User Defined Aggregate Function。用户自定义聚合函数
对比UDF:
UDF,其实更多的是针对单行输入,返回一个输出
UDAF,则可以针对多行输入,进行聚合计算,返回一个输出
二:关于UDAF的一个误区
我们可能下意识的认为UDAF是需要和group by一起使用的,实际上UDAF可以跟group by一起使用,也可以不跟group by一起使用,这个其实比较好理解,联想到mysql中的max、min等函数,可以:
1
select max(foo) from foobar group by bar;
表示根据bar字段分组,然后求每个分组的最大值,这时候的分组有很多个,使用这个函数对每个分组进行处理,也可以:
1
select max(foo) from foobar;
这种情况可以将整张表看做是一个分组,然后在这个分组(实际上就是一整张表)中求最大值。所以聚合函数实际上是对分组做处理,而不关心分组中记录的具体数量。
三:UDAF中:update,merge,evaluate方法的含义
update:各个分组的值内部聚合
merge:各个节点的同一分组的值聚合
evaluate:聚合各个分组的缓存值
四:自定义UDAF实战
定义:
/**
* @author Administrator
*/
class StringCount extends UserDefinedAggregateFunction {
// inputSchema,指的是,输入数据的类型
def inputSchema: StructType = {
StructType(Array(StructField("str", StringType, true)))
}
// bufferSchema,指的是,中间进行聚合时,所处理的数据的类型
def bufferSchema: StructType = {
StructType(Array(StructField("count", IntegerType, true)))
}
// dataType,指的是,函数返回值的类型
def dataType: DataType = {
IntegerType
}
def deterministic: Boolean = {
true
}
// 为每个分组的数据执行初始化操作
def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(0) = 0
}
/**
* 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
* 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
* 大聚和发生在reduce端.
* 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
* update的结果写入buffer中,每个分组中的每一行数据都要进行update操作
*/
def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getAs[Int](0) + 1
}
/**
* 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
* 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
* 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
* 也可以是一个节点里面的多个executor合并 reduce端大聚合
* merge后的结果写如buffer1中
*/
def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
}
// 最后,指的是,一个分组的聚合值,如何通过中间的缓存聚合值,最后返回一个最终的聚合值
def evaluate(buffer: Row): Any = {
buffer.getAs[Int](0)
}
使用:
object UDAF {
def main(args: Array[String]): Unit = {
val conf = new SparkConf()
.setMaster("local")
.setAppName("UDAF")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
// 构造模拟数据
val names = Array("Leo", "Marry", "Jack", "Tom", "Tom", "Tom", "Leo")
val namesRDD = sc.parallelize(names, 5)
val namesRowRDD = namesRDD.map { name => Row(name) }
val structType = StructType(Array(StructField("name", StringType, true)))
val namesDF = sqlContext.createDataFrame(namesRowRDD, structType)
// 注册一张names表
namesDF.registerTempTable("names")
// 定义和注册自定义函数
// 定义函数:自己写匿名函数
// 注册函数:SQLContext.udf.register()
sqlContext.udf.register("strCount", new StringCount)
// 使用自定义函数
sqlContext.sql("select name,strCount(name) from names group by name")
.collect()
.foreach(println)
}
/* 结果:
* [Jack,1]
[Tom,3]
[Leo,2]
[Marry,1]
*/