(1) Algorithm description
Logistic regression is a regression technique for problems where the dependent variable is categorical. It is most commonly applied to binary (binomial) classification, but it can also handle multiclass problems; in practice it is a classification method.
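At its core, binary logistic regression passes a linear score through the logistic (sigmoid) function to obtain a class probability; the multiclass training used below generalizes this idea. A minimal sketch of the sigmoid in plain Java (independent of Spark, for intuition only):

```java
public class SigmoidSketch {
    // Logistic (sigmoid) function: maps a linear score z = w·x + b
    // to a probability in (0, 1).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    public static void main(String[] args) {
        // A score of 0 lies exactly on the decision boundary: probability 0.5.
        System.out.println(sigmoid(0.0));   // 0.5
        // Large positive scores approach 1; large negative ones approach 0.
        System.out.println(sigmoid(6.0));
        System.out.println(sigmoid(-6.0));
    }
}
```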
(2) Test data
Each line below is in LIBSVM sparse format: a class label (0, 1, or 2) followed by 1-based index:value feature pairs; omitted indices are implicitly zero.
1 1:-0.222222 2:0.5 3:-0.762712 4:-0.833333
1 1:-0.555556 2:0.25 3:-0.864407 4:-0.916667
1 1:-0.722222 2:-0.166667 3:-0.864407 4:-0.833333
1 1:-0.722222 2:0.166667 3:-0.694915 4:-0.916667
0 1:0.166667 2:-0.416667 3:0.457627 4:0.5
1 1:-0.833333 3:-0.864407 4:-0.916667
2 1:-1.32455e-07 2:-0.166667 3:0.220339 4:0.0833333
2 1:-1.32455e-07 2:-0.333333 3:0.0169491 4:-4.03573e-08
1 1:-0.5 2:0.75 3:-0.830508 4:-1
0 1:0.611111 3:0.694915 4:0.416667
0 1:0.222222 2:-0.166667 3:0.423729 4:0.583333
1 1:-0.722222 2:-0.166667 3:-0.864407 4:-1
1 1:-0.5 2:0.166667 3:-0.864407 4:-0.916667
2 1:-0.222222 2:-0.333333 3:0.0508474 4:-4.03573e-08
2 1:-0.0555556 2:-0.833333 3:0.0169491 4:-0.25
2 1:-0.166667 2:-0.416667 3:-0.0169491 4:-0.0833333
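The code below loads this file with MLUtils.loadLibSVMFile, so no manual parsing is needed. Purely for illustration, a hypothetical hand-rolled parser for one such line (note that the sixth data line above omits feature 2, which therefore defaults to 0.0) might look like:

```java
import java.util.Arrays;

public class LibSvmLineSketch {
    // Parse the feature part of one LIBSVM line into a dense vector.
    // Indices are 1-based; indices absent from the line stay 0.0.
    // Hypothetical helper for illustration, not part of Spark.
    static double[] parseFeatures(String line, int numFeatures) {
        double[] features = new double[numFeatures];
        String[] tokens = line.trim().split("\\s+");
        // tokens[0] is the label; the rest are index:value pairs.
        for (int i = 1; i < tokens.length; i++) {
            String[] kv = tokens[i].split(":");
            features[Integer.parseInt(kv[0]) - 1] = Double.parseDouble(kv[1]);
        }
        return features;
    }

    public static void main(String[] args) {
        // The sixth data line above: feature 2 is omitted, so it parses to 0.0.
        double[] f = parseFeatures("1 1:-0.833333 3:-0.864407 4:-0.916667", 4);
        System.out.println(Arrays.toString(f));
        // [-0.833333, 0.0, -0.864407, -0.916667]
    }
}
```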
(3) Test code
import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.util.MLUtils;

public class JavaMulticlassClassificationMetricsExample {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example");
    SparkContext sc = new SparkContext(conf);
    // $example on$
    String path = "sample_multiclass_classification_data.txt";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

    // Split the initial RDD into two: 60% training data, 40% test data.
    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L);
    JavaRDD<LabeledPoint> training = splits[0].cache();
    JavaRDD<LabeledPoint> test = splits[1];

    // Run the training algorithm to build the model.
    final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
      .setNumClasses(3)
      .run(training.rdd());

    // Compute raw scores on the test set.
    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
      new Function<LabeledPoint, Tuple2<Object, Object>>() {
        public Tuple2<Object, Object> call(LabeledPoint p) {
          Double prediction = model.predict(p.features());
          return new Tuple2<Object, Object>(prediction, p.label());
        }
      }
    );
    System.out.println("--------------------->" + predictionAndLabels.take(10));

    // Get evaluation metrics.
    MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());

    // Confusion matrix
    Matrix confusion = metrics.confusionMatrix();
    System.out.println("Confusion matrix: \n" + confusion);

    // Overall statistics
    System.out.println("Precision = " + metrics.precision());
    System.out.println("Recall = " + metrics.recall());
    System.out.println("F1 Score = " + metrics.fMeasure());

    // Stats by label
    for (int i = 0; i < metrics.labels().length; i++) {
      System.out.format("Class %f precision = %f\n", metrics.labels()[i],
        metrics.precision(metrics.labels()[i]));
      System.out.format("Class %f recall = %f\n", metrics.labels()[i],
        metrics.recall(metrics.labels()[i]));
      System.out.format("Class %f F1 score = %f\n", metrics.labels()[i],
        metrics.fMeasure(metrics.labels()[i]));
    }

    // Weighted stats
    System.out.format("Weighted precision = %f\n", metrics.weightedPrecision());
    System.out.format("Weighted recall = %f\n", metrics.weightedRecall());
    System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure());
    System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate());

    // Save and load the model
    model.save(sc, "target/tmp/LogisticRegressionModel");
    LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc,
      "target/tmp/LogisticRegressionModel");
    // $example off$

    sc.stop();
  }
}
(4) Test results
>[(1.0,1.0), (1.0,1.0), (0.0,0.0), (0.0,0.0), (1.0,1.0), (1.0,1.0), (2.0,2.0), (1.0,1.0), (2.0,2.0), (0.0,0.0)]
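MulticlassMetrics derives all of its statistics from exactly such (prediction, label) pairs. As a sanity check, the overall accuracy of the ten pairs shown above can be computed by hand; every prediction matches its label here, so it comes out to 1.0. A minimal sketch:

```java
public class AccuracySketch {
    // Fraction of (prediction, label) pairs where the prediction equals the label.
    static double accuracy(double[][] pairs) {
        int correct = 0;
        for (double[] p : pairs) {
            if (p[0] == p[1]) correct++;
        }
        return (double) correct / pairs.length;
    }

    public static void main(String[] args) {
        // The ten (prediction, label) pairs printed by the example above.
        double[][] pairs = {
            {1, 1}, {1, 1}, {0, 0}, {0, 0}, {1, 1},
            {1, 1}, {2, 2}, {1, 1}, {2, 2}, {0, 0}
        };
        System.out.println(accuracy(pairs));   // 1.0
    }
}
```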