package com.haolidong.kNN;
import java.util.Comparator;
/**
 * Orders {@link Distances} entries by ascending distance, so that after
 * sorting the entry with the smallest distance comes first.
 *
 * <p>The comparison is delegated to {@link Double#compare(double, double)},
 * which defines a total order: NaN sorts after every other value instead of
 * comparing as "equal" to everything. A hand-written {@code >}/{@code <}
 * chain returns 0 for any NaN operand, which breaks the {@link Comparator}
 * contract and can make the sort fail.
 *
 * @author haolidong
 */
public class ComparatorImpl implements Comparator<Distances> {

    /**
     * Compares two entries by their stored distance.
     *
     * @param arg0 first entry
     * @param arg1 second entry
     * @return a negative value, zero, or a positive value when the distance
     *         of {@code arg0} is smaller than, equal to, or greater than the
     *         distance of {@code arg1}
     */
    @Override
    public int compare(Distances arg0, Distances arg1) {
        return Double.compare(arg0.getDistances(), arg1.getDistances());
    }
}
package com.haolidong.kNN;
/**
 * Holds one kNN distance value together with an integer index.
 *
 * @author haolidong
 */
public class Distances {

    /** The stored distance value. */
    double distances;

    /** The stored index value. */
    int sortedDistIndicies;

    /** Creates an entry whose distance is 0.0 and whose index is 0. */
    public Distances() {
        this(0.0, 0);
    }

    /**
     * Creates an entry with the given values.
     *
     * @param distances          distance to store
     * @param sortedDistIndicies index to store
     */
    public Distances(double distances, int sortedDistIndicies) {
        this.distances = distances;
        this.sortedDistIndicies = sortedDistIndicies;
    }

    /** @return the stored distance */
    public double getDistances() {
        return this.distances;
    }

    /** @param distances new distance to store */
    public void setDistances(double distances) {
        this.distances = distances;
    }

    /** @return the stored index */
    public int getSortedDistIndicies() {
        return this.sortedDistIndicies;
    }

    /** @param sortedDistIndicies new index to store */
    public void setSortedDistIndicies(int sortedDistIndicies) {
        this.sortedDistIndicies = sortedDistIndicies;
    }
}
package com.haolidong.kNN;
import java.util.ArrayList;
/**
 * Container for a matrix of values and a list of labels.
 *
 * @author haolidong
 */
public class ReturnML {

    /** The matrix: a list of rows, each row a list of doubles. */
    public ArrayList<ArrayList<Double>> AR = new ArrayList<>();

    /** The list of labels. */
    public ArrayList<String> AS = new ArrayList<>();

    /** Creates a container whose matrix and label list are both empty. */
    public ReturnML() {
        // Nothing to do: both lists are created by the field initialisers.
    }
}
package com.haolidong.kNN;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map.Entry;
import java.util.Set;
/**
 * k-nearest-neighbour classifier with three demo drivers: a tiny hard-coded
 * example, the dating data set, and handwritten-digit recognition.
 *
 * @author haolidong
 */
public class KNN {

    /** Shared sample matrix and labels; appended to by {@link #file2matrix} and rescaled in place by {@link #autoNorm}. */
    public static ReturnML rml = new ReturnML();

    /** Tab-separated dating data set used by {@link #test()} and {@link #datingClassTest()}. */
    private static final String DATING_FILE = "I:\\10yue1\\machinelearninginaction\\Ch02\\datingTestSet2.txt";

    /** Directory holding the digit images that are classified. */
    private static final String TEST_DIGITS_DIR = "I:\\machinelearninginaction\\Ch02\\testDigits\\";

    /** Directory holding the digit images used as training samples. */
    private static final String TRAINING_DIGITS_DIR = "I:\\machinelearninginaction\\Ch02\\trainingDigits\\";

    /**
     * Runs the three demos in order: simple classification, the dating test
     * and the handwriting test.
     *
     * @param args unused
     * @throws IOException if a digit directory cannot be listed
     */
    public static void main(String[] args) throws IOException {
        testClassify();
        datingClassTest();
        handwritingClassTest();
    }

    /**
     * Loads the dating file into {@link #rml}, normalises it and prints every
     * row followed by its label.
     */
    public static void test() {
        // Start from an empty matrix so a previous load cannot leak into this run.
        rml.AR.clear();
        rml.AS.clear();
        file2matrix(DATING_FILE);
        autoNorm();
        for (int i = 0; i < rml.AR.size(); i++) {
            System.out.print(i + ": ");
            ArrayList<Double> row = rml.AR.get(i);
            for (int j = 0; j < row.size(); j++) {
                System.out.print(row.get(j) + " ");
            }
            System.out.println(rml.AS.get(i));
        }
    }

    /**
     * Classifies the point (0, 0) against four hard-coded samples with k = 3
     * and prints the result.
     *
     * @return the predicted label
     */
    public static String testClassify() {
        ArrayList<ArrayList<Double>> group = new ArrayList<ArrayList<Double>>();
        group.add(vectorOf(1.0, 1.1));
        group.add(vectorOf(1.0, 1.0));
        group.add(vectorOf(0.0, 0.0));
        group.add(vectorOf(0.0, 0.1));

        ArrayList<String> labels = new ArrayList<String>();
        labels.add("A");
        labels.add("A");
        labels.add("B");
        labels.add("B");

        String lab = classify(vectorOf(0.0, 0.0), group, labels, 3);
        System.out.println(lab);
        return lab;
    }

    /** Builds a list holding the given values in order. */
    private static ArrayList<Double> vectorOf(double... values) {
        ArrayList<Double> vector = new ArrayList<Double>(values.length);
        for (double value : values) {
            vector.add(value);
        }
        return vector;
    }

    /**
     * Core kNN classification: computes the Euclidean distance from
     * {@code inX} to every row of {@code dataSet}, takes the {@code k}
     * nearest rows and returns the label with the most votes among them.
     *
     * <p>{@code dataSet} is only read, never modified. When {@code k} is
     * larger than the number of samples, all samples vote.
     *
     * @param inX     feature vector to classify
     * @param dataSet training matrix, one row per sample
     * @param labels  label of each training row, in the same order
     * @param k       number of nearest neighbours that vote
     * @return the label with the highest vote count
     * @throws IllegalArgumentException if {@code dataSet} is empty or
     *                                  {@code k} is not positive
     */
    public static String classify(ArrayList<Double> inX, ArrayList<ArrayList<Double>> dataSet,
            ArrayList<String> labels, int k) {
        if (dataSet.isEmpty()) {
            throw new IllegalArgumentException("dataSet must contain at least one sample");
        }
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive, got " + k);
        }
        // Never ask for more neighbours than there are samples.
        int voters = Math.min(k, dataSet.size());

        ArrayList<Distances> dis = new ArrayList<Distances>(dataSet.size());
        for (int i = 0; i < dataSet.size(); i++) {
            ArrayList<Double> row = dataSet.get(i);
            double sumOfSquares = 0.0;
            for (int j = 0; j < row.size(); j++) {
                double diff = inX.get(j) - row.get(j);
                sumOfSquares = sumOfSquares + diff * diff;
            }
            dis.add(new Distances(Math.sqrt(sumOfSquares), i));
        }
        Collections.sort(dis, new ComparatorImpl());

        HashMap<String, Integer> classCount = new HashMap<String, Integer>();
        for (int i = 0; i < voters; i++) {
            String voteLabel = labels.get(dis.get(i).getSortedDistIndicies());
            Integer votes = classCount.get(voteLabel);
            classCount.put(voteLabel, votes == null ? 1 : votes + 1);
        }
        // sortMap puts the highest count first; that entry is the prediction.
        return sortMap(classCount).entrySet().iterator().next().getKey();
    }

    /**
     * Returns a copy of {@code oldMap} whose iteration order is by descending
     * value. Entries with equal values keep the relative order in which
     * {@code oldMap} iterates them.
     *
     * @param oldMap label-to-count map to sort; not modified
     * @return a new insertion-ordered map sorted by descending count
     */
    public static HashMap<String, Integer> sortMap(HashMap<String, Integer> oldMap) {
        ArrayList<Entry<String, Integer>> entries =
                new ArrayList<Entry<String, Integer>>(oldMap.entrySet());
        Collections.sort(entries, new Comparator<Entry<String, Integer>>() {
            @Override
            public int compare(Entry<String, Integer> left, Entry<String, Integer> right) {
                // Integer.compare avoids the overflow that subtraction can produce.
                return Integer.compare(right.getValue(), left.getValue());
            }
        });
        HashMap<String, Integer> newMap = new LinkedHashMap<String, Integer>();
        for (Entry<String, Integer> entry : entries) {
            newMap.put(entry.getKey(), entry.getValue());
        }
        return newMap;
    }

    /**
     * Reads a tab-separated text file and appends its content to
     * {@link #rml}: on every line all columns but the last are parsed as
     * doubles into a new row of {@code rml.AR}, and the last column is added
     * to {@code rml.AS} as the label. Blank lines are skipped.
     *
     * <p>An {@link IOException} is reported with a stack trace and ends the
     * load; rows read before the failure stay in {@link #rml}.
     *
     * @param fileName path of the file to read
     */
    public static void file2matrix(String fileName) {
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.trim().isEmpty()) {
                    // A blank line carries no sample; adding it would create a zero-feature row.
                    continue;
                }
                String[] fields = line.split("\t");
                ArrayList<Double> features = new ArrayList<Double>();
                for (int i = 0; i < fields.length - 1; i++) {
                    features.add(Double.parseDouble(fields[i]));
                }
                rml.AR.add(features);
                rml.AS.add(fields[fields.length - 1]);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Min-max normalises {@code rml.AR} in place, column by column:
     * {@code x -> (x - min) / (max - min)}. A column whose values are all
     * equal is mapped to 0.0 instead of dividing by zero. Does nothing when
     * the matrix is empty.
     */
    public static void autoNorm() {
        if (rml.AR.isEmpty()) {
            return;
        }
        // Seed the per-column extremes with the first row.
        ArrayList<Double> min = new ArrayList<Double>(rml.AR.get(0));
        ArrayList<Double> max = new ArrayList<Double>(rml.AR.get(0));
        for (int i = 0; i < rml.AR.size(); i++) {
            ArrayList<Double> row = rml.AR.get(i);
            for (int j = 0; j < row.size(); j++) {
                double value = row.get(j);
                if (value > max.get(j)) {
                    max.set(j, value);
                }
                if (value < min.get(j)) {
                    min.set(j, value);
                }
            }
        }
        for (int i = 0; i < rml.AR.size(); i++) {
            ArrayList<Double> row = rml.AR.get(i);
            for (int j = 0; j < row.size(); j++) {
                double range = max.get(j) - min.get(j);
                row.set(j, range == 0.0 ? 0.0 : (row.get(j) - min.get(j)) / range);
            }
        }
    }

    /**
     * Dating demo: loads and normalises the dating file, uses the first half
     * of the rows as test vectors and the second half as training data,
     * classifies each test vector with k = 3 and prints every prediction and
     * the overall error rate. Labels are parsed as integers.
     */
    public static void datingClassTest() {
        double hoRatio = 0.50;
        // Start from an empty matrix so a previous load cannot leak into this run.
        rml.AR.clear();
        rml.AS.clear();
        file2matrix(DATING_FILE);
        if (rml.AR.isEmpty()) {
            System.err.println("no samples loaded from " + DATING_FILE);
            return;
        }
        autoNorm();

        int m = rml.AR.size();
        int numTestVecs = (int) (m * hoRatio);
        // Rows [numTestVecs, m) are the training set; classify only reads them.
        ArrayList<ArrayList<Double>> group =
                new ArrayList<ArrayList<Double>>(rml.AR.subList(numTestVecs, m));
        ArrayList<String> labels = new ArrayList<String>(rml.AS.subList(numTestVecs, m));

        int errorCount1 = 0;
        for (int i = 0; i < numTestVecs; i++) {
            int s1 = Integer.parseInt(classify(rml.AR.get(i), group, labels, 3).trim());
            int s2 = Integer.parseInt(rml.AS.get(i).trim());
            System.out.println("the classifier came back with: " + s1 + " the real answer is: " + s2);
            if (s1 != s2) {
                errorCount1++;
            }
        }
        System.out.println("the total error rate is:" + 1.0 * errorCount1 / numTestVecs);
    }

    /**
     * Reads a text image file and flattens it into one vector: every
     * character of every line is parsed as an integer digit and appended.
     *
     * <p>An {@link IOException} is reported with a stack trace; the values
     * read so far are returned.
     *
     * @param file text file to read
     * @return the values in reading order
     * @throws NumberFormatException if a character is not a digit
     */
    public static ArrayList<Double> img2vector(File file) {
        ArrayList<Double> pixels = new ArrayList<Double>();
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                for (int i = 0; i < line.length(); i++) {
                    pixels.add((double) Integer.parseInt(String.valueOf(line.charAt(i))));
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return pixels;
    }

    /**
     * Handwriting demo: loads the training and test digit directories,
     * classifies every test image with k = 3 and prints each prediction, the
     * number of errors and the error rate over the images actually tested.
     *
     * @throws IOException if either digit directory cannot be listed
     */
    public static void handwritingClassTest() throws IOException {
        ArrayList<ArrayList<Double>> trainingMat = new ArrayList<ArrayList<Double>>();
        ArrayList<String> trainingLabel = new ArrayList<String>();
        loadDigits(new File(TRAINING_DIGITS_DIR), trainingMat, trainingLabel);

        ArrayList<ArrayList<Double>> vectorUnderTest = new ArrayList<ArrayList<Double>>();
        ArrayList<String> testLabel = new ArrayList<String>();
        loadDigits(new File(TEST_DIGITS_DIR), vectorUnderTest, testLabel);

        int errorCount = 0;
        // Iterate over the vectors that were loaded, not over the raw directory
        // listing, which may also contain sub-directories.
        for (int i = 0; i < vectorUnderTest.size(); i++) {
            String classifyLabel = classify(vectorUnderTest.get(i), trainingMat, trainingLabel, 3);
            System.out.println("the classifier came back with:" + classifyLabel
                    + ", the real answer is:" + testLabel.get(i));
            if (!classifyLabel.equals(testLabel.get(i))) {
                errorCount++;
            }
        }
        System.out.println("the total number of errors is:" + errorCount);
        System.out.println("the total error rate is: " + 1.0 * errorCount / vectorUnderTest.size());
    }

    /**
     * Appends one vector and one label per regular file of {@code dir}. The
     * label is the part of the file name before its first underscore; files
     * without an underscore in their name are skipped.
     */
    private static void loadDigits(File dir, ArrayList<ArrayList<Double>> vectors,
            ArrayList<String> labels) throws IOException {
        File[] entries = dir.listFiles();
        if (entries == null) {
            // listFiles returns null when dir is not a directory or cannot be read.
            throw new IOException("cannot list digit directory: " + dir);
        }
        for (File entry : entries) {
            if (!entry.isFile()) {
                continue;
            }
            // Use the bare file name so the result does not depend on the
            // platform's path separator or on underscores in directory names.
            String name = entry.getName();
            int underscore = name.indexOf('_');
            if (underscore < 0) {
                continue;
            }
            vectors.add(img2vector(entry));
            labels.add(name.substring(0, underscore));
        }
    }
}