HashingTF maps terms to vector indices by hashing, so the word behind an index cannot be recovered; to keep the word-index mapping, the TF step has to be replaced with CountVectorizer. See the code below.
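To illustrate the difference, here is a minimal sketch (assuming an active SparkSession named spark, as the main code below also does); the tiny DataFrame and column names are illustrative only:

from pyspark.ml.feature import HashingTF, CountVectorizer
df = spark.createDataFrame([(0, ["a", "b", "a"])], ["id", "words"])
# HashingTF: indices are hash buckets, so there is no way back to the word
HashingTF(inputCol="words", outputCol="tf", numFeatures=16).transform(df).show(truncate=False)
# CountVectorizer: the fitted model exposes vocabulary; index i maps to vocabulary[i]
cvm = CountVectorizer(inputCol="words", outputCol="tf").fit(df)
print(cvm.vocabulary)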
from pyspark.ml.feature import CountVectorizer, IDF
sentenceData = spark.createDataFrame([
    (1, "Hi I heard about Spark Spark".split(" ")),
    (2, "I wish Java could use case classes Spark".split(" ")),
    (3, "Logistic regression regression models are neat".split(" "))
], ["id", "sentence"])
cv = CountVectorizer(inputCol="sentence", outputCol="words", vocabSize=30, minDF=1.0).fit(sentenceData)
featurizedData = cv.transform(sentenceData)
# Store the index-to-word mapping from the fitted vocabulary
dic = {str(i): w for i, w in enumerate(cv.vocabulary)}
print("word-index mapping:", dic)
def _prc_row(row):
    # Map each index of the sparse TF-IDF vector back to its word
    f = row.features
    kvs = {dic[str(i)]: v for i, v in zip(f.indices.tolist(), f.values.tolist())}
    return row.id, kvs
idf = IDF(inputCol="words", outputCol="features")
idfModel = idf.fit(featurizedData)
rescaledData = idfModel.transform(featurizedData)
# Register the (id, {word: weight}) rows as a temp view, then flatten the map
rescaledData.rdd.map(_prc_row).toDF(['id', 'kvs']).createOrReplaceTempView("user_tags")
spark.sql("select id, k, v from user_tags lateral view explode(kvs) as k, v").show()
Output:
word-index mapping: {'0': 'Spark', '1': 'I', '2': 'regression', '3': 'are', '4': 'heard', '5': 'classes', '6': 'Java', '7': 'Logistic', '8': 'neat', '9': 'case', '10': 'models', '11': 'about', '12': 'could', '13': 'wish', '14': 'use', '15': 'Hi'}
+---+----------+-------------------+
| id| k| v|
+---+----------+-------------------+
| 1| about| 0.6931471805599453|
| 1| Hi| 0.6931471805599453|
| 1| I|0.28768207245178085|
| 1| Spark| 0.5753641449035617|
| 1| heard| 0.6931471805599453|
| 2| wish| 0.6931471805599453|
| 2| Java| 0.6931471805599453|
| 2| use| 0.6931471805599453|
| 2| could| 0.6931471805599453|
| 2| classes| 0.6931471805599453|
| 2| I|0.28768207245178085|
| 2| Spark|0.28768207245178085|
| 2| case| 0.6931471805599453|
| 3| models| 0.6931471805599453|
| 3| neat| 0.6931471805599453|
| 3| are| 0.6931471805599453|
| 3| Logistic| 0.6931471805599453|
| 3|regression| 1.3862943611198906|
+---+----------+-------------------+
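These weights match Spark MLlib's smoothed IDF formula, idf(t) = ln((N + 1) / (df(t) + 1)), where N is the number of documents and df(t) is the number of documents containing t; the stored value is term frequency times idf. A quick check in plain Python:

import math
N = 3  # number of documents
print(math.log((N + 1) / (1 + 1)))      # df=1 -> 0.6931..., e.g. "Hi", "heard"
print(math.log((N + 1) / (2 + 1)))      # df=2 -> 0.2877..., e.g. "I", "Spark"
print(2 * math.log((N + 1) / (2 + 1)))  # "Spark" occurs twice in doc 1 -> 0.5754
print(2 * math.log((N + 1) / (1 + 1)))  # "regression" occurs twice in doc 3 -> 1.3863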
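For reference, the same map flattening can also be done with the DataFrame API instead of the lateral view SQL; a sketch, with tags as a purely illustrative name for the intermediate DataFrame:

from pyspark.sql.functions import explode
tags = rescaledData.rdd.map(_prc_row).toDF(['id', 'kvs'])
# explode on a MapType column yields one (key, value) row per map entry
tags.select('id', explode('kvs').alias('k', 'v')).show()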