# Keywords (关键词):
# - Spark session initialization
# - transformer definitions
# - VectorAssembler feature concatenation
# - pipeline training
# - model saving and testing
# - saving the results
from pyspark.sql import SparkSession

# Create (or reuse) a Spark session with Hive support enabled.
spark = (
    SparkSession
    .builder
    .enableHiveSupport()
    .getOrCreate()
)
# NOTE: under Python 2.7 the literals must be handled as utf-8; Python 3 has
# no such encoding issue.
_rows = [
    (0, u'我爱中华'),
    (0, u'王者荣耀是我的最爱'),
    (1, u'我讨厌学习'),
    (1, u'学习很糟糕'),
]
sentenceData = spark.createDataFrame(_rows, ['label', 'contentbody'])
# Text preprocessing / feature-extraction transformers.
from pyspark.ml.feature import HashingTF, RegexTokenizer, NGram

# Empty delimiter pattern: splits each sentence into single characters.
regexTokenizer = RegexTokenizer(
    inputCol="contentbody", outputCol="words", pattern="")
# Character bigrams built from the tokenized characters.
ngram = NGram(n=2, inputCol='words', outputCol="ngrams")
# Hash unigrams and bigrams into fixed-size (100-dim) term-frequency vectors.
hashingTF1 = HashingTF(
    inputCol="words", outputCol="rawFeatures1", numFeatures=100)
hashingTF2 = HashingTF(
    inputCol="ngrams", outputCol="rawFeatures2", numFeatures=100)
# Concatenate multiple sparse feature columns into one vector column
# (sparse and dense columns can be mixed as well).
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(
    inputCols=["rawFeatures1", "rawFeatures2"],
    outputCol="features",
)
# Classifier: logistic regression, capped at 10 iterations with L2 penalty.
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(
    maxIter=10,
    regParam=0.01,
)
# Chain every transformer plus the estimator into one pipeline, then fit it.
from pyspark.ml import Pipeline

_stages = [regexTokenizer, ngram, hashingTF1, hashingTF2, assembler, lr]
pipeline = Pipeline(stages=_stages)
model = pipeline.fit(sentenceData)
# Persist the fitted pipeline and reload it.
import os
from pyspark.ml import PipelineModel

# BUG FIX: '~' is not expanded by Spark path handling — expand it explicitly.
modelSavePath = os.path.expanduser('~/Desktop/sparkModel/')
model.save(modelSavePath)
# BUG FIX: pipeline.fit() returns a PipelineModel, so the artifact on disk is
# a PipelineModel; loading it with LogisticRegression.load() would fail.
modelLoad = PipelineModel.load(modelSavePath)
# Alternative to the pipeline: run each stage by hand so every intermediate
# result can be inspected.
regexData = regexTokenizer.transform(sentenceData)
ngramData = ngram.transform(regexData)
hashtf1Data = hashingTF1.transform(ngramData)
hashtf2Data = hashingTF2.transform(hashtf1Data)
assemblerData = assembler.transform(hashtf2Data)
model2 = lr.fit(assemblerData)

# Show each intermediate DataFrame in the same order the stages ran.
for _df in (regexData, ngramData, hashtf1Data, hashtf2Data, assemblerData):
    _df.show()
# Quick sanity test.
# Both characters of "学习" occur in the positive examples, so the first
# sentence should likely be classified positive; "爱" and "的" in the second
# sentence appear in the negatives, so it should be classified negative.
testData = spark.createDataFrame(
    [(0, u'爱学习么?'), (0, u'爱,不存在的')],
    ['label', 'contentbody'],
)
output = model.transform(testData)
output.select(['contentbody', 'prediction']).show()
# Save the predictions as a tab-separated file.
import os

# BUG FIX: outPath was never defined, so the write below raised a NameError.
outPath = os.path.expanduser('~/Desktop/sparkOutput/')
dfout = output.select(['contentbody', 'prediction'])
dfout.write.csv(path=outPath, header=False, sep="\t", mode='overwrite')
# Expected console output of the test-data prediction above (for reference
# only; this bare string literal is a no-op at runtime):
'''
+------------+----------+
| contentbody|prediction|
+------------+----------+
|  爱学习么?|       1.0|
|爱,不存在的|       0.0|
+------------+----------+
'''