import os

from pyspark.ml.classification import NaiveBayes
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql import SparkSession
# Train and evaluate a multinomial Naive Bayes classifier on Spark's bundled
# LIBSVM sample data, printing sample predictions and the test-set accuracy.

# Build (or reuse, via getOrCreate) a Spark session for this example.
spark = SparkSession \
    .builder \
    .appName("dataFrame") \
    .getOrCreate()

# Location of the LIBSVM-format sample data. The default is the path from the
# original author's Spark 2.2.0 install; set SAMPLE_LIBSVM_PATH to point at
# $SPARK_HOME/data/mllib/sample_libsvm_data.txt on other machines.
data_path = os.environ.get(
    "SAMPLE_LIBSVM_PATH",
    "/home/luogan/lg/softinstall/spark-2.2.0-bin-hadoop2.7/data/mllib/sample_libsvm_data.txt",
)

try:
    # Load the data; the libsvm reader yields "label" and "features" columns.
    data = spark.read.format("libsvm").load(data_path)

    # Split 60/40 into train and test; the fixed seed keeps the split reproducible.
    splits = data.randomSplit([0.6, 0.4], seed=1234)
    train, test = splits

    # Multinomial Naive Bayes with additive smoothing of 1.0.
    # NOTE: Spark's multinomial model expects non-negative feature values.
    nb = NaiveBayes(smoothing=1.0, modelType="multinomial")

    # Fit the model on the training split.
    model = nb.fit(train)

    # Score every test row; show() prints only the first 20 rows.
    predictions = model.transform(test)
    predictions.show()

    # Compute accuracy on the test set by comparing "prediction" to "label".
    evaluator = MulticlassClassificationEvaluator(
        labelCol="label", predictionCol="prediction", metricName="accuracy")
    accuracy = evaluator.evaluate(predictions)
    print("Test set accuracy = " + str(accuracy))
finally:
    # Release the session's resources even if loading or training fails.
    spark.stop()
# Example output from a run of the script above:
#
# +-----+--------------------+--------------------+-----------+----------+
# |label|            features|       rawPrediction|probability|prediction|
# +-----+--------------------+--------------------+-----------+----------+
# |  0.0|(692,[95,96,97,12...|[-174115.98587057...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[98,99,100,1...|[-178402.52307196...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[100,101,102...|[-100905.88974016...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[123,124,125...|[-244784.29791241...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[123,124,125...|[-196900.88506109...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[124,125,126...|[-238164.45338794...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[124,125,126...|[-184206.87833381...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[127,128,129...|[-214174.52863813...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[127,128,129...|[-182844.62193963...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[128,129,130...|[-246557.10990301...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[152,153,154...|[-208282.08496711...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[152,153,154...|[-243457.69885665...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[153,154,155...|[-260933.50931276...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[154,155,156...|[-220274.72552901...|  [1.0,0.0]|       0.0|
# |  0.0|(692,[181,182,183...|[-154830.07125175...|  [1.0,0.0]|       0.0|
# |  1.0|(692,[99,100,101,...|[-145978.24563975...|  [0.0,1.0]|       1.0|
# |  1.0|(692,[100,101,102...|[-147916.32657832...|  [0.0,1.0]|       1.0|
# |  1.0|(692,[123,124,125...|[-139663.27471685...|  [0.0,1.0]|       1.0|
# |  1.0|(692,[124,125,126...|[-129013.44238751...|  [0.0,1.0]|       1.0|
# |  1.0|(692,[125,126,127...|[-81829.799906049...|  [0.0,1.0]|       1.0|
# +-----+--------------------+--------------------+-----------+----------+
# only showing top 20 rows
# Test set accuracy = 1.0