Java实现NaiveBayes

用Java写了NaiveBayes分类器,只实现了多项式(Multinomial)模型。

核心是利用Spark的分布式计算:按类别汇总每个特征的取值(如词频)并据此计算概率;预测时对每个类别计算后验概率,取概率最大的类别。

pi[] 即公式中的 $\log P(B)$,theta[][] 即 $\log P(A \mid B)$。预测时要找使 $P(B)\,P(A \mid B)$ 最大的类别;取对数后,类别 $i$ 的得分为 $\pi_i + \sum_j x_j\,\theta_{ij}$(其中 $x$ 为待预测数据的特征向量),取得分最大的类别即可。代码中用矩阵与向量的乘积一次算出所有类别的得分。

代码:

import java.io.Serializable;
import java.util.List;
import java.util.function.Consumer;


import org.apache.spark.mllib.classification.NaiveBayes$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;


import scala.Tuple2;


/**
 * Trains a multinomial naive Bayes classifier on a Spark RDD of labeled points.
 *
 * <p>Training runs one distributed aggregation: for each label it counts the points and sums
 * their feature vectors. Additive-smoothed log priors ({@code pi}) and log conditional
 * probabilities ({@code theta}) are then computed on the driver.
 *
 * <p>Feature values are assumed to be non-negative term frequencies. This is not validated here;
 * negative values can yield NaN probabilities.
 */
public class JavaNaiveBayes
{
    /** Additive smoothing parameter added to every count. */
    private final double lambda;
    /** Model type name; only the multinomial model is implemented. */
    private final String modelType;

    /** Creates a trainer with the default smoothing {@code lambda = 1.0}. */
    public JavaNaiveBayes()
    {
        this(1.0);
    }

    /**
     * @param lambda    additive smoothing parameter; must be non-negative
     * @param modelType model type name (stored only; training is always multinomial)
     * @throws IllegalArgumentException if {@code lambda} is negative or NaN
     */
    private JavaNaiveBayes( double lambda, String modelType )
    {
        // Written as !(x >= 0) so that NaN is rejected as well.
        if ( !(lambda >= 0.0) )
        {
            throw new IllegalArgumentException( "Smoothing parameter must be non-negative but got " + lambda );
        }
        this.lambda = lambda;
        this.modelType = modelType;
    }

    /**
     * Creates a multinomial trainer with the given smoothing.
     *
     * @param lambda additive smoothing parameter; must be non-negative
     * @throws IllegalArgumentException if {@code lambda} is negative or NaN
     */
    public JavaNaiveBayes( double lambda )
    {
        this(lambda, NaiveBayes$.MODULE$.Multinomial( ));
    }

    /**
     * Fits the model.
     *
     * @param data labeled training points; all feature vectors must have the same size
     * @return the trained model, with labels in ascending order
     * @throws IllegalArgumentException if {@code data} is empty
     */
    private JavaNaiveBayesModel run(RDD<LabeledPoint> data)
    {
        List<Tuple2<Double, Tuple2<Long, Vector>>> aggregated = aggregateByLabel( data );
        if ( aggregated.isEmpty( ) )
        {
            throw new IllegalArgumentException( "Cannot train a naive Bayes model on an empty dataset" );
        }

        int numLabels = aggregated.size( );
        int numFeatures = aggregated.get( 0 )._2( )._2( ).size( );
        long numDocuments = aggregated.stream( ).mapToLong( t -> t._2( )._1( ) ).sum( );

        double[] labels = new double[numLabels];
        double[] pi = new double[numLabels];
        double[][] theta = new double[numLabels][numFeatures];
        double piLogDenom = Math.log( numDocuments + numLabels * lambda );

        for ( int i = 0; i < numLabels; i++ )
        {
            Tuple2<Double, Tuple2<Long, Vector>> entry = aggregated.get( i );
            long n = entry._2( )._1( );
            Vector sumTermFreqs = entry._2( )._2( );

            labels[i] = entry._1( );
            // log P(label) with additive smoothing.
            pi[i] = Math.log( n + lambda ) - piLogDenom;
            // log P(feature j | label): smoothed share of this label's total feature mass.
            double thetaLogDenom = Math.log( sumVector( sumTermFreqs ) + numFeatures * lambda );
            for ( int j = 0; j < numFeatures; j++ )
            {
                theta[i][j] = Math.log( sumTermFreqs.apply( j ) + lambda ) - thetaLogDenom;
            }
        }
        return new JavaNaiveBayesModel( labels, pi, theta );
    }

    /**
     * Groups points by label and, per label, counts the points and sums their feature vectors.
     *
     * @return one {@code (label, (count, featureSum))} entry per label, sorted by label
     */
    private static List<Tuple2<Double, Tuple2<Long, Vector>>> aggregateByLabel( RDD<LabeledPoint> data )
    {
        return data.toJavaRDD( )
            .mapToPair( p -> new Tuple2<Double, Vector>( p.label( ), p.features( ) ) )
            .combineByKey(
                // Accumulate into a private dense copy. BLAS.axpy only adds into a DenseVector,
                // and adding into the input vector itself would corrupt the caller's data.
                v -> new Tuple2<Long, Vector>( 1L, v.copy( ).toDense( ) ),
                (acc, v) ->
                {
                    BLAS.axpy( 1.0, v, acc._2( ) );
                    return new Tuple2<Long, Vector>( acc._1( ) + 1, acc._2( ) );
                },
                // Both accumulators are private copies, so merging in place is safe.
                (left, right) ->
                {
                    BLAS.axpy( 1.0, left._2( ), right._2( ) );
                    return new Tuple2<Long, Vector>( left._1( ) + right._1( ), right._2( ) );
                } )
            .sortByKey( )
            .collect( );
    }

    /** Returns the sum of all entries of {@code v}. */
    private static double sumVector(Vector v)
    {
        double total = 0.0;
        for ( double d : v.toArray( ) )
        {
            total += d;
        }
        return total;
    }

    /**
     * Trains a multinomial naive Bayes model with the default smoothing {@code lambda = 1.0}.
     *
     * @param input labeled training points
     * @return the trained model
     * @throws IllegalArgumentException if {@code input} is empty
     */
    public static JavaNaiveBayesModel train(RDD<LabeledPoint> input)
    {
        return new JavaNaiveBayes( ).run( input );
    }

    /**
     * Trains a multinomial naive Bayes model with the given smoothing.
     *
     * @param input  labeled training points
     * @param lambda additive smoothing parameter; must be non-negative
     * @return the trained model
     * @throws IllegalArgumentException if {@code input} is empty or {@code lambda} is invalid
     */
    public static JavaNaiveBayesModel train(RDD<LabeledPoint> input, double lambda)
    {
        return new JavaNaiveBayes( lambda ).run( input );
    }
}


import java.util.ArrayList;
import java.util.List;


import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.mllib.classification.ClassificationModel;
import org.apache.spark.mllib.classification.NaiveBayes$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.DenseMatrix;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.rdd.RDD;
import scala.reflect.ClassManifestFactory;


/**
 * Multinomial naive Bayes classification model, as produced by {@code JavaNaiveBayes.train}.
 *
 * <p>For a feature vector {@code x}, label {@code i} is scored as
 * {@code pi[i] + sum_j theta[i][j] * x[j]}, and the label with the highest score is predicted.
 */
public class JavaNaiveBayesModel implements ClassificationModel
{
    /** Class labels; entry {@code i} corresponds to {@code pi[i]} and row {@code theta[i]}. */
    private final double[] labels;
    /** Per-label prior term (log priors when built by JavaNaiveBayes). */
    private final double[] pi;
    /** Per-label, per-feature term (log conditional probabilities when built by JavaNaiveBayes). */
    private final double[][] theta;
    private final String modelType;
    /** {@code pi} wrapped as a dense vector; Vectors.dense wraps the array without copying. */
    private final Vector piVector;
    /** Row-major copy of {@code theta} used for matrix-vector scoring. */
    private final Matrix thetaMatrix;

    /**
     * Creates a multinomial model.
     *
     * @param labels class labels, one per row
     * @param pi     prior term per label; same length as {@code labels}
     * @param theta  feature term per label; {@code labels.length} rows of equal, non-zero length
     * @throws IllegalArgumentException if the arrays are empty or have inconsistent shapes
     */
    public JavaNaiveBayesModel( double[] labels, double[] pi, double[][] theta)
    {
        this(labels, pi, theta, NaiveBayes$.MODULE$.Multinomial( ));
    }

    private JavaNaiveBayesModel( double[] labels, double[] pi, double[][] theta, String modelType )
    {
        validateShapes( labels, pi, theta );
        this.labels = labels;
        this.pi = pi;
        this.theta = theta;
        this.modelType = modelType;

        this.piVector = Vectors.dense( pi );
        // isTransposed = true: the flattened values are laid out row by row.
        this.thetaMatrix = new DenseMatrix( labels.length, theta[0].length, flattenDoubles( theta ), true );
    }

    /**
     * Checks that {@code labels}, {@code pi} and {@code theta} describe a consistent model.
     *
     * @throws IllegalArgumentException on empty or mismatched arrays
     */
    private static void validateShapes( double[] labels, double[] pi, double[][] theta )
    {
        if ( labels.length == 0 )
        {
            throw new IllegalArgumentException( "Model must have at least one label" );
        }
        if ( pi.length != labels.length )
        {
            throw new IllegalArgumentException( "pi has " + pi.length + " entries but there are " + labels.length + " labels" );
        }
        if ( theta.length != labels.length )
        {
            throw new IllegalArgumentException( "theta has " + theta.length + " rows but there are " + labels.length + " labels" );
        }
        int numFeatures = theta[0].length;
        if ( numFeatures == 0 )
        {
            throw new IllegalArgumentException( "theta rows must have at least one feature" );
        }
        for ( int i = 1; i < theta.length; i++ )
        {
            if ( theta[i].length != numFeatures )
            {
                throw new IllegalArgumentException( "theta row " + i + " has " + theta[i].length + " entries, expected " + numFeatures );
            }
        }
    }

    /** Concatenates the rows of a rectangular matrix into one row-major array. */
    private static double[] flattenDoubles( double[][] values )
    {
        int len = values[0].length;
        double[] retValue = new double[values.length * len];
        for ( int i = 0; i < values.length; i++ )
        {
            System.arraycopy( values[i], 0, retValue, i * len, len );
        }
        return retValue;
    }

    /** Returns the class labels (the internal array, not a copy). */
    public double[] labels( )
    {
        return labels;
    }

    /** Returns the model type name. */
    public String getModelType( )
    {
        return modelType;
    }

    /** Returns the per-label prior term (the internal array, not a copy). */
    public double[] pi( )
    {
        return pi;
    }

    /**
     * Returns the per-label feature term (the internal array, not a copy).
     * Predictions use a copy made at construction, so mutating this array does not change them.
     */
    public double[][] theta( )
    {
        return theta;
    }

    /**
     * Predicts a label for every vector in the RDD.
     *
     * @param t feature vectors to classify
     * @return predicted labels as boxed {@code Double}s, in input order
     */
    @Override
    public RDD<Object> predict( RDD<Vector> t )
    {
        // Broadcast the model so tasks read it from the broadcast instead of each closure carrying it.
        Broadcast<JavaNaiveBayesModel> bcModel =
            t.context( ).broadcast( this, ClassManifestFactory.classType( this.getClass( ) ) );

        return t.toJavaRDD( ).mapPartitions( vectors ->
        {
            JavaNaiveBayesModel model = bcModel.getValue( );
            List<Object> predictions = new ArrayList<>( );
            while ( vectors.hasNext( ) )
            {
                predictions.add( model.predict( vectors.next( ) ) );
            }
            return predictions;
        } ).rdd( );
    }

    /**
     * Predicts the label with the highest score for a single feature vector.
     *
     * @param t feature vector; its size must equal the number of features
     * @return the predicted label
     */
    @Override
    public double predict( Vector t )
    {
        return labels[multinomialCalculation( t ).argmax( )];
    }

    /** Java-friendly variant of {@link #predict(RDD)}. */
    @Override
    public JavaRDD<Double> predict( JavaRDD<Vector> t )
    {
        return predict( t.rdd( ) ).toJavaRDD( ).map( s -> (Double) s );
    }

    /** Computes {@code theta * testData + pi}, i.e. the score of every label. */
    private Vector multinomialCalculation( Vector testData )
    {
        DenseVector prob = thetaMatrix.multiply( testData );
        BLAS.axpy(1.0, piVector, prob);
        return prob;
    }
}


猜你喜欢

转载自blog.csdn.net/hhtop112408/article/details/78530119