SparkSQL中有两种自定函数,在我们使用自带的函数时无法满足自己的需求时,可以使用自定义函数,SparkSQL中有两种自定义函数,一种是UDF,另一种是UDAF,和Hive 很类似,但是hive中还有UDTF,一进多出,但是sparkSQL中没有,这是因为spark中用 flatMap这个函数,可以实现和udtf相同的功能
UDF函数是针对的是一进一出
UDAF针对的是多进一出
udf很简单,只需要注册一下,然后写一个函数,就可以在sql查询中使用了
df1.createTempView("user")
//注册
spark.udf.register("lengthStr",(str:String)=>str.length)//自定义函数
//直接在sql中就可以使用啦
val df2 = spark.sql("select lengthStr(name) from user")
udaf相对来说比较复杂一点,需要继承一个 UserDefinedAggregateFunction类,在重写其中的方法,自定义函数求平均值,详细的步骤在下面的代码中
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession, types}
object UDAFavg {
def main(args: Array[String]): Unit = {
val spark = SparkSession.builder().appName("avg").master("local").getOrCreate()
val sc = spark.sparkContext
val sqlContext = spark.sqlContext
val files: RDD[String] = sc.textFile("D:\\read\\teacher.txt")
val rowRDD: RDD[Row] = files.map(row => {
val split = row.split(" ")
Row(split(0), split(1),split(2).toLong)
})
/* rowRDD.foreach(row =>{
println(row.getString(0)+" "+row.getString(1)+row.get(2))
})*/
val structType = StructType(List(StructField("subject",StringType,true),StructField("tname",StringType,true),
StructField("age",LongType,true)))
val df1: DataFrame = spark.createDataFrame(rowRDD,structType)
df1.createTempView("teacher")
//注册函数, 自定义一个函数,实现求平均数
spark.udf.register("TeacherAvg",new UDAFavg)
//df1.show()
spark.sql("select subject,TeacherAvg(age) as avgAGE from teacher group by subject ").show()
}
}
//自定义UDAF函数
class UDAFavg extends UserDefinedAggregateFunction{
//输入数据类型,求平均值,所以数据类型是LongType(StructType中的类型)
override def inputSchema: StructType = {
StructType(List(StructField("age",LongType,true)))}
//中间结果的类型,这里定义了两个中间的类型,因为在求平均值时,首先一个存总的和,一个计算个数,最后的结果是两者相除
override def bufferSchema: StructType = {
StructType(List(StructField("age",LongType),StructField("count",LongType)))}
//输出返回类型
override def dataType: DataType = {LongType}
//是否数据同一性,一般都是true
override def deterministic: Boolean = true
//初始化定义两个中间值
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//类型要和上面定义的位置相对应
buffer(0) = 0L //初始化 总和
buffer(1) = 0L // 个数
}
//进行计算,
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { //input是每次的输入Row类型
buffer(1) = buffer.getAs[Long](1)+ 1 //个数 每次加1
buffer(0) = buffer.getAs[Long](0) + input.getLong(0)
// 把每个传的值进行累加
}
//有可能有多个分区,多个task ,总后把进行合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(1) = buffer1.getAs[Long](1)+ buffer2.getAs[Long](1)//多台机器中的count的值进行相加
buffer1(0) = buffer1.getAs[Long](0) + buffer2.getLong(0)
}
//返回的最终结果
override def evaluate(buffer: Row): Any = {
buffer.getAs[Long](0) / buffer.getAs[Long](1)
}
}