The original data is as follows:
gid | score |
---|---|
a1 | 90 80 79 80 |
a2 | 79 89 45 60 |
a3 | 57 56 89 75 |
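from pyspark.sql.functions import udf, col, explode
from pyspark.sql.types import MapType, IntegerType, StringType

def udf_array_to_map(array):
    # Pass null arrays through unchanged
    if array is None:
        return array
    # Map each element's position in the array to the element itself
    return dict((i, v) for i, v in enumerate(array))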
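# col(): returns a Column based on the given column name
# MapType: represents a set of key-value pairs; keyType gives the data type of the keys and valueType the data type of the values.
# The last argument (valueContainsNull) indicates whether the values in the map may be null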
def generate_idx_for_df(df, id_name, col_name, col_schema):
    """
    generate_idx_for_df, explodes rows with array as a column into a new row for each
    element in the array, with 'INTEGER_IDX' indicating its index in the original array.
    :param df: dataframe with array columns
    :param id_name: the id field of df
    :param col_name: the col of df to explode
    :param col_schema: the schema of each element in col_name array
    :return: new df with exploded rows.
    """
    idx_udf = udf(udf_array_to_map, MapType(IntegerType(), col_schema, True))
    return df.withColumn('idx_columns', idx_udf(col(col_name))) \
        .select(id_name, explode('idx_columns').alias('INTEGER_IDX', 'col'))
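The main idea is to use a udf (user-defined function) from pyspark.sql.functions, which is applied to every row of the DataFrame and turns the array into a map keyed by each element's index.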
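To make the index-map idea concrete, here is a minimal sketch of the same conversion in plain Python, with no Spark involved. The name `array_to_map` and the sample list are illustrative only; the list mirrors what row a1 would look like after its `score` string has been split.

```python
def array_to_map(array):
    """Plain-Python copy of the helper: map each position to its element."""
    if array is None:
        # Null arrays are passed through unchanged, as in the helper above.
        return None
    return {i: v for i, v in enumerate(array)}


row_a1 = ["90", "80", "79", "80"]  # the split 'score' value of row a1
indexed = array_to_map(row_a1)
print(indexed)             # {0: '90', 1: '80', 2: '79', 3: '80'}
print(array_to_map(None))  # None
```

The keys are integers and the values are still strings, which is why the udf is declared with `IntegerType()` keys and a caller-supplied element type.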
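Note: the udf's return type must be declared as a map type (MapType). If no return type is given, it defaults to StringType and the subsequent explode fails. In that case the intermediate result and the error look like this: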
gid | s | idx_columns |
---|---|---|
a1 | [90, 80, 79, 80] | {0=90, 1=80, 2=79... |
a2 | [79, 89, 45, 60] | {0=79, 1=89, 2=45... |
a3 | [57, 56, 89, 75] | {0=57, 1=56, 2=89... |
org.apache.spark.sql.AnalysisException: cannot resolve 'explode(idx_columns)' due to data type mismatch: input to function explode should be array or map type, not StringType;
The correct intermediate result should look like this:
gid | s | idx_columns |
---|---|---|
a1 | [90, 80, 79, 80] | Map(0 -> 90, 1 ->... |
a2 | [79, 89, 45, 60] | Map(0 -> 79, 1 ->... |
a3 | [57, 56, 89, 75] | Map(0 -> 57, 1 ->... |
from pyspark.sql.functions import split, explode
df_split = df.withColumn("s", split(df['score'], " ")).select('gid', 's')
df_split.show()
col_schema = StringType()
df_index = generate_idx_for_df(df_split, 'gid', 's', col_schema)
df_index.show()
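As a sanity check on what the pipeline above should produce, the same split, index, and explode logic can be emulated in plain Python. This is only a sketch for reasoning about the expected rows, not a replacement for the Spark code; the function name `explode_with_index` and the tuple layout are assumptions made for illustration.

```python
rows = [
    ("a1", "90 80 79 80"),
    ("a2", "79 89 45 60"),
    ("a3", "57 56 89 75"),
]


def explode_with_index(rows):
    """Emulate split + index + explode for (gid, score) tuples."""
    exploded = []
    for gid, score in rows:
        # split(" ") yields strings, matching col_schema = StringType()
        for idx, value in enumerate(score.split(" ")):
            exploded.append((gid, idx, value))
    return exploded


result = explode_with_index(rows)
for row in result:
    print(row)
```

Each input row with four scores becomes four output rows, so the three sample rows yield twelve `(gid, index, value)` tuples.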
The final result after splitting and exploding is as follows:
gid | INTEGER_IDX | col |
---|---|---|
a1 | 0 | 90 |
a1 | 1 | 80 |
a1 | 2 | 79 |
a1 | 3 | 80 |
a2 | 0 | 79 |
a2 | 1 | 89 |
a2 | 2 | 45 |
a2 | 3 | 60 |
a3 | 0 | 57 |
a3 | 1 | 56 |
a3 | 2 | 89 |
a3 | 3 | 75 |
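For comparison, the same table can probably be obtained without a udf at all. The sketch below uses the built-in `posexplode` function from pyspark.sql.functions, assuming Spark 2.1 or later; it is an alternative technique, not the method described in this post, and the wrapper name `explode_with_posexplode` is hypothetical.

```python
def explode_with_posexplode(df, id_name, col_name):
    """Hypothetical udf-free variant: one row per array element plus its index."""
    # Imported inside the function so that defining it does not require pyspark;
    # calling it does, along with an active SparkSession.
    from pyspark.sql.functions import posexplode

    # posexplode yields two columns per array element: its position and its value.
    return df.select(id_name, posexplode(col_name).alias('INTEGER_IDX', 'col'))
```

With the `df_split` DataFrame from above, a call such as `explode_with_posexplode(df_split, 'gid', 's').show()` should give the same three columns. Because no Python udf is involved, the rows do not have to be serialized to a Python worker and back.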
Reference: https://www.programcreek.com/python/example/98237/pyspark.sql.functions.explode