项目github地址:bitcarmanlee easy-algorithm-interview-and-practice
经常有同学私信或留言询问相关问题,V号bitcarmanlee。github上star的同学,在我能力与时间允许范围内,尽可能帮大家解答相关问题,一起进步。
1.为什么需要窗口函数
在1.4以前,Spark SQL支持两种类型的函数用来计算单个的返回值。第一种是内置函数或者UDF函数,他们将单个行中的值作为输入,并且他们为每个输入行生成单个返回值。另外一种是聚合函数,典型的是SUM, MAX, AVG这种,是对一组行数据进行操作,并且为每个组计算一个返回值。
上面提到的两种函数,实际当中使用非常广泛,但是仍然存在大量无法单独使用这些类型的函数来表达的操作。最常见的一种场景就是,很多时候需要对一组行进行操作,而仍然为每个输入行返回一个值,上面的两种方法就无能为力。例如对于计算移动平均值,计算累计和或访问出现在当前行之前的行的值等,就显得非常困难。幸运的是,在1.4以后的版本,Spark SQL就提供了窗口函数来弥补上面的不足。
窗口函数的核心是“Frame”,或者我们直接称呼其为帧,帧就是一系列的多行数据,或者说许多分组。然后我们可以基本这些分组来满足上面普通函数无法完成的功能。为了看清楚其具体的应用,我们直接看例子。Talk is cheap, Show me the code.
2.构造数据集
为了方便测试,我们首先构造数据集
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions._
def test() = {
val sparkConf = new SparkConf().setMaster("local[2]")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
val data = Array(("lili", "ml", 90),
("lucy", "ml", 85),
("cherry", "ml", 80),
("terry", "ml", 85),
("tracy", "cs", 82),
("tony", "cs", 86),
("tom", "cs", 75))
val schemas = Seq("name", "subject", "score")
val df = spark.createDataFrame(data).toDF(schemas: _*)
df.show()
}
将上面的test方法本地run起来以后,输出如下
+------+-------+-----+
| name|subject|score|
+------+-------+-----+
| lili| ml| 90|
| lucy| ml| 85|
|cherry| ml| 80|
| terry| ml| 85|
| tracy| cs| 82|
| tony| cs| 86|
| tom| cs| 75|
+------+-------+-----+
数据构造完毕
3.分组查看排名
经常用到的一个场景是:需要查看每个专业学生的排名,这就是一个典型的分组问题,就是窗口函数大显身手的时候。
一个窗口需要定义三个部分:
1.分组问题,如何将行分组?在选取窗口数据时,只对组内数据生效
2.排序问题,按何种方式进行排序?选取窗口数据时,会首先按指定方式排序
3.帧(frame)选取,以当前行为基准,如何选取周围行?
对照上面的三个部分,窗口函数的语法一般为:
window_func(args) OVER ( [PARTITION BY col_name, col_name, ...] [ORDER BY col_name, col_name, ...] [ROWS | RANGE BETWEEN (CURRENT ROW | (UNBOUNDED |[num]) PRECEDING) AND (CURRENT ROW | ( UNBOUNDED | [num]) FOLLOWING)] )
其中
window_func就是窗口函数
over表示这是个窗口函数
partition by对应的就是分组,即按照什么列分组
order by对应的是排序,按什么列排序
rows则对应的帧选取。
spark中的window_func包括下面三类:
1.排名函数(ranking function) 包括rank,dense_rank, row_number,percent_rank, ntile等,后面我们结合例子来看。
2.分析函数 (analytic functions) 包括cume_dist,lag等。
3.聚合函数(aggregate functions),就是我们常用的max, min, sum, avg等。
回到上面的需求,查看每个专业学生的排名
def test() = {
val sparkConf = new SparkConf().setMaster("local[2]")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
val sqlContext = spark.sqlContext
val data = Array(("lili", "ml", 90),
("lucy", "ml", 85),
("cherry", "ml", 80),
("terry", "ml", 85),
("tracy", "cs", 82),
("tony", "cs", 86),
("tom", "cs", 75))
val schemas = Seq("name", "subject", "score")
val df = spark.createDataFrame(data).toDF(schemas: _*)
df.createOrReplaceTempView("person_subject_score")
val sqltext = "select name, subject, score, rank() over (partition by subject order by score desc) as rank from person_subject_score";
val ret = sqlContext.sql(sqltext)
ret.show()
}
上面的代码run起来,结果如下
+------+-------+-----+----+
| name|subject|score|rank|
+------+-------+-----+----+
| tony| cs| 86| 1|
| tracy| cs| 82| 2|
| tom| cs| 75| 3|
| lili| ml| 90| 1|
| lucy| ml| 85| 2|
| terry| ml| 85| 2|
|cherry| ml| 80| 4|
+------+-------+-----+----+
重点看下窗口部分:
rank() over (partition by subject order by score desc) as rank
rank()函数表示取每行在分组中的排名,partition by subject表示按subject分组,order by score desc表示按分数排序并且逆序,这样就可以得到每个学生在本专业中的排名!
row_number, dense_rank也都是排序有关的窗口函数,下面我们通过实例看看他们的区别:
def test() = {
val sparkConf = new SparkConf().setMaster("local[2]")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
val sqlContext = spark.sqlContext
val data = Array(("lili", "ml", 90),
("lucy", "ml", 85),
("cherry", "ml", 80),
("terry", "ml", 85),
("tracy", "cs", 82),
("tony", "cs", 86),
("tom", "cs", 75))
val schemas = Seq("name", "subject", "score")
val df = spark.createDataFrame(data).toDF(schemas: _*)
df.createOrReplaceTempView("person_subject_score")
val sqltext = "select name, subject, score, rank() over (partition by subject order by score desc) as rank from person_subject_score";
val ret = sqlContext.sql(sqltext)
ret.show()
val sqltext2 = "select name, subject, score, row_number() over (partition by subject order by score desc) as row_number from person_subject_score";
val ret2 = sqlContext.sql(sqltext2)
ret2.show()
val sqltext3 = "select name, subject, score, dense_rank() over (partition by subject order by score desc) as dense_rank from person_subject_score";
val ret3 = sqlContext.sql(sqltext3)
ret3.show()
}
+------+-------+-----+----+
| name|subject|score|rank|
+------+-------+-----+----+
| tony| cs| 86| 1|
| tracy| cs| 82| 2|
| tom| cs| 75| 3|
| lili| ml| 90| 1|
| lucy| ml| 85| 2|
| terry| ml| 85| 2|
|cherry| ml| 80| 4|
+------+-------+-----+----+
+------+-------+-----+----------+
| name|subject|score|row_number|
+------+-------+-----+----------+
| tony| cs| 86| 1|
| tracy| cs| 82| 2|
| tom| cs| 75| 3|
| lili| ml| 90| 1|
| lucy| ml| 85| 2|
| terry| ml| 85| 3|
|cherry| ml| 80| 4|
+------+-------+-----+----------+
+------+-------+-----+----------+
| name|subject|score|dense_rank|
+------+-------+-----+----------+
| tony| cs| 86| 1|
| tracy| cs| 82| 2|
| tom| cs| 75| 3|
| lili| ml| 90| 1|
| lucy| ml| 85| 2|
| terry| ml| 85| 2|
|cherry| ml| 80| 3|
+------+-------+-----+----------+
通过上面的例子不难看出这三者的区别:
rank生成不连续的序号,上面的例子是1,2,2,4这种
dense_rank生成连续的序号,上面的例子是1,2,2,3这种
row_number顾名思义,生成的是行号,上面的例子是1,2,3,4这种。
不用去死抠函数的定义,看上面的例子就明白了!
4.查看分位数
下面再看个实例,我们想查看某个人在该专业的分位数,该怎么办?
这个时候就可以用到cume_dist函数了。
该函数的计算方式为:组内小于等于当前行值的行数/组内总行数
还是看代码
val sqltext5 = "select name, subject, score, cume_dist() over (partition by subject order by score desc) as cumedist from person_subject_score";
val ret5 = sqlContext.sql(sqltext5)
ret5.show()
结合前面的数据初始化代码与上面的sql逻辑,最后的结果如下:
+------+-------+-----+------------------+
| name|subject|score| cumedist|
+------+-------+-----+------------------+
| tony| cs| 86|0.3333333333333333|
| tracy| cs| 82|0.6666666666666666|
| tom| cs| 75| 1.0|
| lili| ml| 90| 0.25|
| lucy| ml| 85| 0.75|
| terry| ml| 85| 0.75|
|cherry| ml| 80| 1.0|
+------+-------+-----+------------------+
可以看到完美满足上面的需求。
5.使用DataFrame的API完成窗口查询
上面的例子使用的是SqlContext的API,在DataFrame中,也有对应的API可以完成查询,具体方式也很简单,使用DataFrame API在支持的函数调用over()方法即可,例如rank().over(…)
拿前面的需求为例,如果我们想查看学生在专业的排名,使用DataFrame的API如下:
def test() = {
val sparkConf = new SparkConf().setMaster("local[2]")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
val data = Array(("lili", "ml", 90),
("lucy", "ml", 85),
("cherry", "ml", 80),
("terry", "ml", 85),
("tracy", "cs", 82),
("tony", "cs", 86),
("tom", "cs", 75))
val schemas = Seq("name", "subject", "score")
val df = spark.createDataFrame(data).toDF(schemas: _*)
df.createOrReplaceTempView("person_subject_score")
val window = Window.partitionBy("subject").orderBy(col("score").desc)
val df2 = df.withColumn("rank", rank().over(window))
df2.show()
}
输出结果如下:
+------+-------+-----+----+
| name|subject|score|rank|
+------+-------+-----+----+
| tony| cs| 86| 1|
| tracy| cs| 82| 2|
| tom| cs| 75| 3|
| lili| ml| 90| 1|
| lucy| ml| 85| 2|
| terry| ml| 85| 2|
|cherry| ml| 80| 4|
+------+-------+-----+----+