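/**
 * Logistic regression for binary classification, trained with either
 * stochastic gradient descent or batch gradient descent.
 * @author Administrator
 *
 */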
public class LR {
    /** Samples read by {@link #loadData(String, String)}; null until data is loaded. */
    List<Sample> samples;
    /** Weights produced by {@link #train(int, double)}; index 0 multiplies the bias term. */
    List<Double> paramers;

    /** One labelled row of the input file. */
    static class Sample {
        /** Value parsed from the last column of the row. */
        Double label;
        /** Constant bias term 1.0 followed by the remaining columns of the row. */
        List<Double> feature;
    }

    /**
     * Reads a delimited text file into {@link #samples}, replacing any previous data.
     * Every column except the last becomes a feature (after a leading bias term of 1.0);
     * the last column becomes the label. Blank lines are skipped.
     *
     * @param path  path of the file to read
     * @param regex regular expression used to split each line into columns
     * @throws Exception if the file cannot be read or a column is not a valid number
     */
    public void loadData(String path, String regex) throws Exception {
        samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                // A blank line (e.g. a trailing newline) carries no data; parsing it would fail.
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] splits = line.split(regex);
                Sample sample = new Sample();
                sample.feature = new ArrayList<Double>(splits.length);
                sample.feature.add(Double.valueOf(1.0)); // bias term
                for (int i = 0; i < splits.length - 1; i++) {
                    sample.feature.add(Double.valueOf(splits[i]));
                }
                sample.label = Double.valueOf(splits[splits.length - 1]);
                samples.add(sample);
            }
        } finally {
            // Close the file even when reading or parsing fails.
            reader.close();
        }
    }

    /**
     * Applies the logistic (sigmoid) function to the dot product of the sample's
     * features and the given weights.
     *
     * @param sample sample whose feature list has at least {@code params.size()} entries
     * @param params weights, one per feature
     * @return {@code 1 / (1 + exp(-w.x))}, a value between 0 and 1
     */
    public Double classify(Sample sample, List<Double> params) {
        double sum = 0;
        for (int index = 0; index < params.size(); index++) {
            sum += sample.feature.get(index) * params.get(index);
        }
        return 1 / (1 + Math.exp(-sum));
    }

    /**
     * Performs one stochastic gradient descent step on {@code params} in place,
     * using a single sample.
     *
     * @param sample sample that drives the update
     * @param params weights to update in place
     * @param eta    learning rate
     */
    public void updateParam(Sample sample, List<Double> params, double eta) {
        // The prediction is computed once, before any weight is modified.
        double diff = classify(sample, params) - sample.label;
        for (int i = 0; i < params.size(); i++) {
            params.set(i, params.get(i) - eta * diff * sample.feature.get(i));
        }
    }

    /**
     * Performs one batch gradient descent step on {@code params} in place,
     * accumulating the (unaveraged) gradient over every sample first.
     *
     * @param samples samples that drive the update
     * @param params  weights to update in place
     * @param eta     learning rate
     */
    public void updateParamByAll(List<Sample> samples, List<Double> params, double eta) {
        double[] sums = new double[params.size()];
        for (Sample sample : samples) {
            double diff = classify(sample, params) - sample.label;
            for (int i = 0; i < sums.length; i++) {
                sums[i] += diff * sample.feature.get(i);
            }
        }
        for (int i = 0; i < sums.length; i++) {
            params.set(i, params.get(i) - eta * sums[i]);
        }
    }

    /**
     * Initialises all weights to zero, runs {@code iters} passes of stochastic
     * gradient descent over {@link #samples}, and prints the resulting weights
     * on one line.
     *
     * @param iters number of passes over the samples
     * @param eta   learning rate
     * @throws IllegalStateException if no samples have been loaded
     */
    public void train(int iters, double eta) {
        requireSamples();
        // The first sample determines how many weights the model has.
        int paramLen = samples.get(0).feature.size();
        paramers = new ArrayList<Double>(paramLen);
        for (int i = 0; i < paramLen; i++) {
            paramers.add(Double.valueOf(0.0));
        }
        for (int i = 0; i < iters; i++) {
            // For batch gradient descent, replace this inner loop with a single
            // call to updateParamByAll(samples, paramers, eta).
            for (Sample sample : samples) {
                updateParam(sample, paramers, eta);
            }
        }
        StringBuilder out = new StringBuilder();
        for (Double param : paramers) {
            out.append(param).append(" ");
        }
        System.out.println(out);
    }

    /**
     * Scores every loaded sample with the trained weights, prints
     * {@code prediction,label} per sample and finally the fraction classified
     * correctly. A sample counts as positive when its label is greater than 0,
     * and is predicted positive when the score is at least 0.5.
     *
     * @throws IllegalStateException if no samples are loaded or the model is untrained
     */
    public void test() {
        requireSamples();
        if (paramers == null) {
            throw new IllegalStateException("model is not trained; call train() first");
        }
        int count = 0;
        for (Sample sample : samples) {
            double value = classify(sample, paramers);
            System.out.println(value + "," + sample.label);
            boolean actualPositive = sample.label > 0;
            boolean predictedPositive = value >= 0.5;
            if (actualPositive == predictedPositive) {
                count++;
            }
        }
        System.out.println("right rate: " + count * 1.0 / samples.size());
    }

    /** Fails fast with a clear message when there is nothing to train or test on. */
    private void requireSamples() {
        if (samples == null || samples.isEmpty()) {
            throw new IllegalStateException("no samples loaded; call loadData() first");
        }
    }

    /**
     * Loads a comma-separated data file, trains for 1000 passes with learning
     * rate 0.01 and prints the accuracy on the same data.
     *
     * @param argv optional: {@code argv[0]} overrides the default data file path
     * @throws Exception if the data file cannot be read or parsed
     */
    public static void main(String argv[]) throws Exception {
        String path = argv.length > 0 ? argv[0] : "F:/contest/iris.csv";
        LR lr = new LR();
        lr.loadData(path, ",");
        lr.train(1000, 0.01);
        lr.test();
    }
}
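/**
 * Logistic regression for binary classification, trained with either
 * stochastic gradient descent or batch gradient descent.
 * @author Administrator
 *
 */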
public class LR {
    /** Samples read by {@link #loadData(String, String)}; null until data is loaded. */
    List<Sample> samples;
    /** Weights produced by {@link #train(int, double)}; index 0 multiplies the bias term. */
    List<Double> paramers;

    /** One labelled row of the input file. */
    static class Sample {
        /** Value parsed from the last column of the row. */
        Double label;
        /** Constant bias term 1.0 followed by the remaining columns of the row. */
        List<Double> feature;
    }

    /**
     * Reads a delimited text file into {@link #samples}, replacing any previous data.
     * Every column except the last becomes a feature (after a leading bias term of 1.0);
     * the last column becomes the label. Blank lines are skipped.
     *
     * @param path  path of the file to read
     * @param regex regular expression used to split each line into columns
     * @throws Exception if the file cannot be read or a column is not a valid number
     */
    public void loadData(String path, String regex) throws Exception {
        samples = new ArrayList<Sample>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                // A blank line (e.g. a trailing newline) carries no data; parsing it would fail.
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] splits = line.split(regex);
                Sample sample = new Sample();
                sample.feature = new ArrayList<Double>(splits.length);
                sample.feature.add(Double.valueOf(1.0)); // bias term
                for (int i = 0; i < splits.length - 1; i++) {
                    sample.feature.add(Double.valueOf(splits[i]));
                }
                sample.label = Double.valueOf(splits[splits.length - 1]);
                samples.add(sample);
            }
        } finally {
            // Close the file even when reading or parsing fails.
            reader.close();
        }
    }

    /**
     * Applies the logistic (sigmoid) function to the dot product of the sample's
     * features and the given weights.
     *
     * @param sample sample whose feature list has at least {@code params.size()} entries
     * @param params weights, one per feature
     * @return {@code 1 / (1 + exp(-w.x))}, a value between 0 and 1
     */
    public Double classify(Sample sample, List<Double> params) {
        double sum = 0;
        for (int index = 0; index < params.size(); index++) {
            sum += sample.feature.get(index) * params.get(index);
        }
        return 1 / (1 + Math.exp(-sum));
    }

    /**
     * Performs one stochastic gradient descent step on {@code params} in place,
     * using a single sample.
     *
     * @param sample sample that drives the update
     * @param params weights to update in place
     * @param eta    learning rate
     */
    public void updateParam(Sample sample, List<Double> params, double eta) {
        // The prediction is computed once, before any weight is modified.
        double diff = classify(sample, params) - sample.label;
        for (int i = 0; i < params.size(); i++) {
            params.set(i, params.get(i) - eta * diff * sample.feature.get(i));
        }
    }

    /**
     * Performs one batch gradient descent step on {@code params} in place,
     * accumulating the (unaveraged) gradient over every sample first.
     *
     * @param samples samples that drive the update
     * @param params  weights to update in place
     * @param eta     learning rate
     */
    public void updateParamByAll(List<Sample> samples, List<Double> params, double eta) {
        double[] sums = new double[params.size()];
        for (Sample sample : samples) {
            double diff = classify(sample, params) - sample.label;
            for (int i = 0; i < sums.length; i++) {
                sums[i] += diff * sample.feature.get(i);
            }
        }
        for (int i = 0; i < sums.length; i++) {
            params.set(i, params.get(i) - eta * sums[i]);
        }
    }

    /**
     * Initialises all weights to zero, runs {@code iters} passes of stochastic
     * gradient descent over {@link #samples}, and prints the resulting weights
     * on one line.
     *
     * @param iters number of passes over the samples
     * @param eta   learning rate
     * @throws IllegalStateException if no samples have been loaded
     */
    public void train(int iters, double eta) {
        requireSamples();
        // The first sample determines how many weights the model has.
        int paramLen = samples.get(0).feature.size();
        paramers = new ArrayList<Double>(paramLen);
        for (int i = 0; i < paramLen; i++) {
            paramers.add(Double.valueOf(0.0));
        }
        for (int i = 0; i < iters; i++) {
            // For batch gradient descent, replace this inner loop with a single
            // call to updateParamByAll(samples, paramers, eta).
            for (Sample sample : samples) {
                updateParam(sample, paramers, eta);
            }
        }
        StringBuilder out = new StringBuilder();
        for (Double param : paramers) {
            out.append(param).append(" ");
        }
        System.out.println(out);
    }

    /**
     * Scores every loaded sample with the trained weights, prints
     * {@code prediction,label} per sample and finally the fraction classified
     * correctly. A sample counts as positive when its label is greater than 0,
     * and is predicted positive when the score is at least 0.5.
     *
     * @throws IllegalStateException if no samples are loaded or the model is untrained
     */
    public void test() {
        requireSamples();
        if (paramers == null) {
            throw new IllegalStateException("model is not trained; call train() first");
        }
        int count = 0;
        for (Sample sample : samples) {
            double value = classify(sample, paramers);
            System.out.println(value + "," + sample.label);
            boolean actualPositive = sample.label > 0;
            boolean predictedPositive = value >= 0.5;
            if (actualPositive == predictedPositive) {
                count++;
            }
        }
        System.out.println("right rate: " + count * 1.0 / samples.size());
    }

    /** Fails fast with a clear message when there is nothing to train or test on. */
    private void requireSamples() {
        if (samples == null || samples.isEmpty()) {
            throw new IllegalStateException("no samples loaded; call loadData() first");
        }
    }

    /**
     * Loads a comma-separated data file, trains for 1000 passes with learning
     * rate 0.01 and prints the accuracy on the same data.
     *
     * @param argv optional: {@code argv[0]} overrides the default data file path
     * @throws Exception if the data file cannot be read or parsed
     */
    public static void main(String argv[]) throws Exception {
        String path = argv.length > 0 ? argv[0] : "F:/contest/iris.csv";
        LR lr = new LR();
        lr.loadData(path, ",");
        lr.train(1000, 0.01);
        lr.test();
    }
}