方差计算工具类--Java版

计算方差时，如果对精度要求不高，可以直接使用 Apache commons-math3（http://commons.apache.org/proper/commons-math/）提供的 Variance 类。不过 Variance 内部使用 double 进行计算，会有精度损失，所以我自己写了一个计算方差的工具类：它采用 BigDecimal 进行计算，并且可以自行指定精度（即结果保留的小数位数）。该工具类计算的是样本方差（除以 n−1），元素个数少于 2 时返回 0。代码如下：

该工具类使用最基本的方差计算公式进行计算；如果需要以流式方式（即逐个读入数据、无需一次性保存全部数据）计算方差，可以参考以下文章：

https://zhuanlan.zhihu.com/p/48025855

​package com.frank.test.variance;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Objects;

import com.google.common.base.Preconditions;

/**
 * Utility methods for computing the sample variance of numeric arrays with {@link BigDecimal}
 * arithmetic.
 *
 * <p>The variance is evaluated exactly and rounded only once. The exact value is
 * {@code (n * sum(x^2) - sum(x)^2) / (n * (n - 1))}, which equals the unbiased sample variance
 * {@code sum((x - mean)^2) / (n - 1)}. It is rounded to {@code scale} decimal places with
 * {@link RoundingMode#HALF_DOWN}.
 *
 * <p>Trailing zeros are removed from the returned value, whose scale is never negative. A zero
 * variance is returned as {@link BigDecimal#ZERO}. Arrays with fewer than two elements also yield
 * {@link BigDecimal#ZERO}.
 *
 * <p>{@code float} and {@code double} elements are converted through their shortest decimal
 * string ({@link Float#toString(float)} / {@link Double#toString(double)}). So {@code 0.1} counts
 * as exactly 0.1, not its binary approximation. NaN and infinite values cause a
 * {@link NumberFormatException}. {@code char} elements are treated as their numeric UTF-16
 * code-unit values.
 *
 * @author frank
 */
public final class VarianceUtils {
	/** Number of decimal places kept when no scale is given explicitly. */
	private static final int DEFAULT_SCALE = 64;

	private VarianceUtils() {}

	/**
	 * Returns the sample variance of {@code arr}, rounded to 64 decimal places.
	 *
	 * @throws NullPointerException if {@code arr} is null
	 */
	public static BigDecimal variance(byte[] arr) {
		return variance(arr, DEFAULT_SCALE);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to {@code scale} decimal places.
	 *
	 * @param arr the values; must not be null
	 * @param scale number of decimal places to keep; must be positive
	 * @return the variance, or {@link BigDecimal#ZERO} for fewer than two elements
	 * @throws NullPointerException if {@code arr} is null
	 * @throws IllegalArgumentException if {@code scale} is not positive
	 */
	public static BigDecimal variance(byte[] arr, int scale) {
		Objects.requireNonNull(arr);
		checkScale(scale);
		BigDecimal[] values = new BigDecimal[arr.length];
		for (int i = 0; i < arr.length; i++) {
			values[i] = BigDecimal.valueOf(arr[i]);
		}
		return compute(values, scale);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to 64 decimal places.
	 *
	 * @throws NullPointerException if {@code arr} is null
	 */
	public static BigDecimal variance(short[] arr) {
		return variance(arr, DEFAULT_SCALE);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to {@code scale} decimal places.
	 *
	 * @param arr the values; must not be null
	 * @param scale number of decimal places to keep; must be positive
	 * @return the variance, or {@link BigDecimal#ZERO} for fewer than two elements
	 * @throws NullPointerException if {@code arr} is null
	 * @throws IllegalArgumentException if {@code scale} is not positive
	 */
	public static BigDecimal variance(short[] arr, int scale) {
		Objects.requireNonNull(arr);
		checkScale(scale);
		BigDecimal[] values = new BigDecimal[arr.length];
		for (int i = 0; i < arr.length; i++) {
			values[i] = BigDecimal.valueOf(arr[i]);
		}
		return compute(values, scale);
	}

	/**
	 * Returns the sample variance of the numeric code-unit values in {@code arr}, rounded to 64
	 * decimal places.
	 *
	 * @throws NullPointerException if {@code arr} is null
	 */
	public static BigDecimal variance(char[] arr) {
		return variance(arr, DEFAULT_SCALE);
	}

	/**
	 * Returns the sample variance of the numeric code-unit values in {@code arr}, rounded to
	 * {@code scale} decimal places.
	 *
	 * @param arr the values; must not be null
	 * @param scale number of decimal places to keep; must be positive
	 * @return the variance, or {@link BigDecimal#ZERO} for fewer than two elements
	 * @throws NullPointerException if {@code arr} is null
	 * @throws IllegalArgumentException if {@code scale} is not positive
	 */
	public static BigDecimal variance(char[] arr, int scale) {
		Objects.requireNonNull(arr);
		checkScale(scale);
		BigDecimal[] values = new BigDecimal[arr.length];
		for (int i = 0; i < arr.length; i++) {
			// Widen to the numeric value. Parsing String.valueOf(ch) would fail for any non-digit.
			values[i] = BigDecimal.valueOf((long) arr[i]);
		}
		return compute(values, scale);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to 64 decimal places.
	 *
	 * @throws NullPointerException if {@code arr} is null
	 */
	public static BigDecimal variance(int[] arr) {
		return variance(arr, DEFAULT_SCALE);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to {@code scale} decimal places.
	 *
	 * @param arr the values; must not be null
	 * @param scale number of decimal places to keep; must be positive
	 * @return the variance, or {@link BigDecimal#ZERO} for fewer than two elements
	 * @throws NullPointerException if {@code arr} is null
	 * @throws IllegalArgumentException if {@code scale} is not positive
	 */
	public static BigDecimal variance(int[] arr, int scale) {
		Objects.requireNonNull(arr);
		checkScale(scale);
		BigDecimal[] values = new BigDecimal[arr.length];
		for (int i = 0; i < arr.length; i++) {
			values[i] = BigDecimal.valueOf(arr[i]);
		}
		return compute(values, scale);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to 64 decimal places.
	 *
	 * @throws NullPointerException if {@code arr} is null
	 */
	public static BigDecimal variance(long[] arr) {
		return variance(arr, DEFAULT_SCALE);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to {@code scale} decimal places.
	 *
	 * @param arr the values; must not be null
	 * @param scale number of decimal places to keep; must be positive
	 * @return the variance, or {@link BigDecimal#ZERO} for fewer than two elements
	 * @throws NullPointerException if {@code arr} is null
	 * @throws IllegalArgumentException if {@code scale} is not positive
	 */
	public static BigDecimal variance(long[] arr, int scale) {
		Objects.requireNonNull(arr);
		checkScale(scale);
		BigDecimal[] values = new BigDecimal[arr.length];
		for (int i = 0; i < arr.length; i++) {
			values[i] = BigDecimal.valueOf(arr[i]);
		}
		return compute(values, scale);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to 64 decimal places.
	 *
	 * @throws NullPointerException if {@code arr} is null
	 * @throws NumberFormatException if {@code arr} contains NaN or an infinite value
	 */
	public static BigDecimal variance(float[] arr) {
		return variance(arr, DEFAULT_SCALE);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to {@code scale} decimal places.
	 *
	 * @param arr the values; must not be null
	 * @param scale number of decimal places to keep; must be positive
	 * @return the variance, or {@link BigDecimal#ZERO} for fewer than two elements
	 * @throws NullPointerException if {@code arr} is null
	 * @throws IllegalArgumentException if {@code scale} is not positive
	 * @throws NumberFormatException if {@code arr} contains NaN or an infinite value
	 */
	public static BigDecimal variance(float[] arr, int scale) {
		Objects.requireNonNull(arr);
		checkScale(scale);
		BigDecimal[] values = new BigDecimal[arr.length];
		for (int i = 0; i < arr.length; i++) {
			// Parse the shortest decimal string so 0.1f counts as 0.1, not its binary expansion.
			values[i] = new BigDecimal(Float.toString(arr[i]));
		}
		return compute(values, scale);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to 64 decimal places.
	 *
	 * @throws NullPointerException if {@code arr} is null
	 * @throws NumberFormatException if {@code arr} contains NaN or an infinite value
	 */
	public static BigDecimal variance(double[] arr) {
		return variance(arr, DEFAULT_SCALE);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to {@code scale} decimal places.
	 *
	 * @param arr the values; must not be null
	 * @param scale number of decimal places to keep; must be positive
	 * @return the variance, or {@link BigDecimal#ZERO} for fewer than two elements
	 * @throws NullPointerException if {@code arr} is null
	 * @throws IllegalArgumentException if {@code scale} is not positive
	 * @throws NumberFormatException if {@code arr} contains NaN or an infinite value
	 */
	public static BigDecimal variance(double[] arr, int scale) {
		Objects.requireNonNull(arr);
		checkScale(scale);
		BigDecimal[] values = new BigDecimal[arr.length];
		for (int i = 0; i < arr.length; i++) {
			// Parse the shortest decimal string so 0.1 counts as 0.1, not its binary expansion.
			values[i] = new BigDecimal(Double.toString(arr[i]));
		}
		return compute(values, scale);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to 64 decimal places.
	 *
	 * @throws NullPointerException if {@code arr} or any of its elements is null
	 */
	public static BigDecimal variance(BigDecimal[] arr) {
		return variance(arr, DEFAULT_SCALE);
	}

	/**
	 * Returns the sample variance of {@code arr}, rounded to {@code scale} decimal places.
	 *
	 * @param arr the values; neither the array nor its elements may be null
	 * @param scale number of decimal places to keep; must be positive
	 * @return the variance, or {@link BigDecimal#ZERO} for fewer than two elements
	 * @throws NullPointerException if {@code arr} or any of its elements is null
	 * @throws IllegalArgumentException if {@code scale} is not positive
	 */
	public static BigDecimal variance(BigDecimal[] arr, int scale) {
		Objects.requireNonNull(arr);
		checkScale(scale);
		for (BigDecimal value : arr) {
			Objects.requireNonNull(value, "arr must not contain null elements");
		}
		// compute() only reads the array, so no defensive copy is needed.
		return compute(arr, scale);
	}

	/** Rejects non-positive scales. */
	private static void checkScale(int scale) {
		if (scale <= 0) {
			throw new IllegalArgumentException("scale must be positive: " + scale);
		}
	}

	/**
	 * Computes the sample variance of validated, non-null values and rounds it to {@code scale}
	 * decimal places using {@link RoundingMode#HALF_DOWN}.
	 */
	private static BigDecimal compute(BigDecimal[] values, int scale) {
		int n = values.length;
		if (n < 2) {
			return BigDecimal.ZERO;
		}
		BigDecimal sum = BigDecimal.ZERO;
		BigDecimal sumOfSquares = BigDecimal.ZERO;
		for (BigDecimal x : values) {
			sum = sum.add(x);
			sumOfSquares = sumOfSquares.add(x.multiply(x));
		}
		BigDecimal count = BigDecimal.valueOf(n);
		// n * sum((x - mean)^2) == n * sum(x^2) - sum(x)^2. Every term is an exact BigDecimal
		// product, so the final division is the only rounding step (no pre-rounded mean).
		BigDecimal numerator = count.multiply(sumOfSquares).subtract(sum.multiply(sum));
		BigDecimal denominator = count.multiply(BigDecimal.valueOf(n - 1L));
		return stripZeros(numerator.divide(denominator, scale, RoundingMode.HALF_DOWN));
	}

	/**
	 * Removes trailing fractional zeros from {@code value} without producing a negative scale.
	 * Integral results therefore keep the plain form {@code 200} instead of {@code 2E+2}.
	 *
	 * <p>This works on the numeric value, not on {@code toString()}. Trimming the string form would
	 * corrupt values printed in scientific notation (e.g. {@code 2.0E-10}).
	 */
	private static BigDecimal stripZeros(BigDecimal value) {
		if (value.signum() == 0) {
			// Normalize zero explicitly so every zero result is exactly BigDecimal.ZERO.
			return BigDecimal.ZERO;
		}
		BigDecimal stripped = value.stripTrailingZeros();
		return stripped.scale() < 0 ? stripped.setScale(0) : stripped;
	}
}

​

猜你喜欢

转载自blog.csdn.net/frankingly/article/details/83900629