在机器学习中,我们时常会碰到需要给属性增加字段的情况。譬如有x、y两个属性,当结果倾向于线性时,我们可以很简单的通过线性回归得到模型。但很多时候,线性(在数学上称为多元一次方程),线性是拟合不了结果的。
往往,我们就需要在给定的几个属性上,通过增加属性来尝试能否拟合。那么原本只有两列,x、y,我们增加2次方的属性后,就会变成x、y、x^2、x*y、y^2,变成了5个属性,根据以往经验,我们知道通过这5个属性是能拟合出曲线。
2次方时,我们还能很简单的写出来所有的组合形式,但是当5次方时,原本有4列时,我们该增加多少列,增加的列该怎么计算呢。这就有点麻烦了,譬如(x+y+z)^3展开后就是x^3+y^3+z^3+3xy^2+3xz^2+3x^2y+3yz^2+3x^2z+3y^2z+6xyz. 去掉系数后,就是我们需要追加的所有列了。我们这篇就是做一个程序,来通过给定的m列,n次方,来给出所有的组合形式。
譬如m为2,n也为2,那么我们给出结果组合:[{0,2}, {1,1}, {2,0}],代表追加3列,第一列是x^0 * y^2,第二列是x^1 * y^1,第三列是x^2 * y^0.
通过观察我们发现,我们需要做的是求这样的方程的所有解:X1+X2+X3+……+Xm = N。其中0<=X<=n。
那么解法就是,我们可以定义一个int[m],该数组共有m个元素,每个元素的取值范围在0到n之间,并且该数组的所有元素的和等于n即可。
直接看程序:
/** * @author wuweifeng wrote on 2018/6/4. */ public class LineAdder { private static int lines = 3; private static int power = 5; private static int[] resultArray; public static void main(String[] args) { resultArray = new int[lines]; deal(0); } public static void deal(int m) { for (int i = 0; i <= power; i++) { resultArray[m] = i; if (m == lines - 1) { //如果找到一个解 if (check()) { print(); return; } } else { deal(m + 1); } } } /** * 判断是否符合结果 * * @return 是否符合 */ private static boolean check() { int total = 0; for (int one : resultArray) { total += one; } return power == total; } private static void print() { for (int one : resultArray) { System.out.print(one); } System.out.print("\n"); } }
结果是:
005 014 023 032 041 050 104 113 122 131 140 203 212 221 230 302 311 320 401 410 500这就是有3列,并且希望求出5次方时的所有组合的答案。
下面我们将它优化一下,让他能处理文本,能处理一行一行的数据,直接把列追加在文本上。
直接上代码:
package ploy; import java.util.ArrayList; import java.util.List; /** * @author wuweifeng wrote on 2018/6/4. */ public class LineAdder { private int lines = 3; private int power = 5; private List<int[]> resultList = new ArrayList<>(); private int[] resultArray; public List<int[]> lineAdd(int lines, int power) { resultArray = new int[lines]; this.lines = lines; this.power = power; deal(0); return resultList; } private void deal(int m) { for (int i = 0; i <= power; i++) { resultArray[m] = i; if (m == lines - 1) { //如果找到一个解 if (check()) { print(); return; } } else { deal(m + 1); } } } /** * 判断是否符合结果 * * @return 是否符合 */ private boolean check() { int total = 0; for (int one : resultArray) { total += one; } return power == total; } private void print() { for (int one : resultArray) { System.out.print(one); } System.out.print("\n"); int[] temp = new int[resultArray.length]; System.arraycopy(resultArray, 0, temp, 0, resultArray.length); resultList.add(temp); } }
package ploy; import java.io.*; import java.util.List; /** * @author wuweifeng wrote on 2018/6/5. */ public class TextDeal { public static void main(String[] args) throws IOException { new TextDeal().linePower("/Users/wuwf/Downloads/ml_data/1逻辑回归入门/data11.csv", "/Users/wuwf/Downloads/ml_data/1逻辑回归入门/data_new.csv", 2, 0,1); } /** * @param filePath * 文件的路径 * @param outputPath * 输出文件的路径 * @param power * 要做几次方 * @param lineNums * 都有哪几列,需要power,不填默认所有列。从第0列开始 */ public void linePower(String filePath, String outputPath, Integer power, Integer... lineNums) throws IOException { BufferedReader reader = buildReader(filePath); BufferedWriter writer = buildWriter(outputPath); addCSVHeader(reader, writer, power, lineNums); } private Integer[] getLineNums(String[] lines, Integer... lineNums) { //为null,则是所有列 if (lineNums == null) { lineNums = new Integer[lines.length]; for (int i = 0; i < lines.length; i++) { lineNums[i] = i; } } return lineNums; } private List<int[]> getAddList(int power, Integer... lineNums) { LineAdder lineAdder = new LineAdder(); //计算共需增加多少列 return lineAdder.lineAdd(lineNums.length, power); } /** * 给header里增加相应的列名,都在第一行 */ private void addCSVHeader(BufferedReader reader, BufferedWriter writer, Integer power, Integer... lineNums) throws IOException { //读取第一行 String header = reader.readLine(); //所有的列名 String[] lines = header.split(","); lineNums = getLineNums(lines, lineNums); //计算共需增加多少列 List<int[]> list = getAddList(power, lineNums); String[] addLines = new String[list.size()]; String[] needLines = new String[lineNums.length]; for (int i = 0; i < lineNums.length; i++) { needLines[i] = lines[lineNums[i]]; } //设置每一列的名字 for (int i = 0; i < list.size(); i++) { int[] array = list.get(i); String s = ""; for (int j = 0; j < array.length; j++) { s += needLines[j] + array[j]; } addLines[i] = s; } for (String addLine : addLines) { header += "," + addLine; } //将新增的列,写入header文件 writer.write(header); writer.newLine(); writer.flush(); String oneLine; while ((oneLine = reader.readLine()) != null) { addLines = new String[list.size()]; lines = oneLine.split(","); needLines = new String[lineNums.length]; for (int i = 0; i < lineNums.length; i++) { needLines[i] = lines[lineNums[i]]; } //设置每一列的值 for (int i = 0; i < list.size(); i++) { int[] array = list.get(i); double s = 1; for (int j = 0; j < array.length; j++) { //譬如a,b,对应02时,该列就是a的0次方乘以b的2次方 s *= Math.pow(Double.valueOf(needLines[j]), array[j]); } addLines[i] = s + ""; } for (String addLine : addLines) { oneLine += "," + addLine; } writer.write(oneLine); //写入相关文件 writer.newLine(); } //将新增的列,写入header文件 writer.flush(); //关闭流 reader.close(); writer.close(); } private BufferedReader buildReader(String filePath) { try { return new BufferedReader(new FileReader(new File(filePath))); } catch (FileNotFoundException e) { e.printStackTrace(); return null; } } private BufferedWriter buildWriter(String outputPath) { //写入相应的文件 try { return new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputPath), "utf-8")); } catch (UnsupportedEncodingException | FileNotFoundException e) { e.printStackTrace(); return null; } } }
看效果:
假如csv文件是这样的
a,b
1,2
2,3
4,5
运行后,结果是
a,b,a0b2,a1b1,a2b0
1,2,4.0,2.0,1.0
2,3,9.0,6.0,4.0
4,5,25.0,20.0,16.0
可以看到已经完成了做2次方的展开。
这个类,可以完成任意次方的模拟及计算。