逻辑回归 随机梯度下降 批梯度下降 二分类问题 不带正则项 java

/**
 * 逻辑回归 随机梯度下降 批梯度下降 二分类问题
 * @author Administrator
 *
 */
public class LR {
    List<Sample> samples;
    List<Double> paramers;
    static class Sample{
        Double label;
        List<Double> feature;
    }
    
    public  void loadData(String path,String regex) throws Exception{
        samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        String line = null;
        String splits[] = null;
        Sample sample = null;
        while(null != (line=reader.readLine())){
            splits = line.split(regex);
            sample = new Sample();
            sample.feature = new ArrayList<Double>(splits.length);
            sample.feature.add(new Double(1));                      //偏置
            for(int i=0;i<splits.length-1;i++){
                sample.feature.add(new Double(splits[i]));
            }
            //sample.feature.set(splits.length-1,sample.feature.get(splits.length-1)/21.0);
            sample.label = Double.valueOf(splits[splits.length-1]);
            samples.add(sample);
        }
        reader.close();
    }
    public Double classify(Sample sample,List<Double> params){
        double ret = 0;
        double sum = 0;
        for(int index=0;index<params.size();index++){
            sum += sample.feature.get(index)*params.get(index);
        }
        
        return 1/(1+Math.exp(-sum));
    }
    public void updateParam(Sample sample,List<Double> params,double eta){
        double target = classify(sample,params);
        double diff = target-sample.label;
        double value = 0;
        for(int i=0;i<params.size();i++){
            value = params.get(i)-eta*diff*sample.feature.get(i);
            params.set(i,value);
        }
    }
    public void updateParamByAll(List<Sample> samples,List<Double> params,double eta){
        double sums[] = new double[params.size()];
        for(Sample sample:samples){
            double target = classify(sample,params);
            double diff = target-sample.label;
            for(int i=0;i<params.size();i++){
                sums[i] += diff*sample.feature.get(i);
            }
        }
        double value = 0;
        for(int i=0;i<params.size();i++){
            value = params.get(i)-eta*sums[i];
            params.set(i,value);
        }
    }
    public void train(int iters,double eta){
        int param_len = samples.get(0).feature.size();
        paramers = new ArrayList<Double>(param_len);
        for(int i=0;i<param_len;i++){
            paramers.add(new Double(0));
        }
        for(int i=0;i<iters;i++){
            for(Sample sample:samples){
                updateParam(sample,paramers,eta);
            }
            
            //updateParamByAll(samples,paramers,eta);
        }
        for(Double param:paramers){
            System.out.print(param+"  ");
        }
        System.out.println();
    }
    
    public void test(){
        int count = 0;
        for(Sample sample:samples){
            double value = classify(sample,paramers);
            System.out.println(value+","+sample.label);
            if(sample.label>0){
                if(value>=0.5){
                    count++;
                }
            }else{
                if(value<0.5){
                    count++;
                }
            }
        }
        
        System.out.println("right rate: "+count*1.0/samples.size());
    }
    public static void main(String argv[]) throws Exception{
        LR lr = new LR();
        lr.loadData("F:/contest/iris.csv",",");
        lr.train(1000,0.01);
        lr.test();
    }
}

猜你喜欢

转载自blog.csdn.net/ysh126/article/details/53073761