邻近算法(kNN)

还是图谱推荐项目,最初的设计是通过邻近算法(kNN)处理推荐,但是实现之后又被放弃了,原因是领导的决策变了……
先介绍一下邻近算法:
kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。(百度百科)
我自己的理解大致分为下面几个步骤:
1、从测试集中取数据,分别和已经存在的点计算欧氏距离
2、对欧氏距离从小到大进行排序,取前k个最邻近的点
3、计算这k个点在各个社区(即代码中的分类 type)中出现的频率,频率最大的社区就是目标点所在的社区
下面是Java代码

/**
 * A sample described by six numeric features ({@code a}..{@code f}) and an
 * optional class label.
 *
 * <p>The distance code in this file treats a feature value of {@code -1} on
 * the first operand as "attribute missing" and skips that dimension.
 */
public class Point {
    /** Identifier of the sample. */
    private long id;
    /** Feature values; {@code -1} is used by the distance code as "missing". */
    private double a;
    private double b;
    private double c;
    private double d;
    private double e;
    private double f;
    /** Class label; {@code null} until the point has been classified. */
    private String type;

    /**
     * Creates an unlabeled point (its {@code type} is {@code null}).
     *
     * @param id identifier of the sample
     * @param a  first feature value
     * @param b  second feature value
     * @param c  third feature value
     * @param d  fourth feature value
     * @param e  fifth feature value
     * @param f  sixth feature value
     */
    public Point(long id, double a, double b, double c, double d, double e, double f) {
        this(id, a, b, c, d, e, f, null);
    }

    /**
     * Creates a labeled point.
     *
     * @param id   identifier of the sample
     * @param a    first feature value
     * @param b    second feature value
     * @param c    third feature value
     * @param d    fourth feature value
     * @param e    fifth feature value
     * @param f    sixth feature value
     * @param type class label of the sample, may be {@code null}
     */
    public Point(long id, double a, double b, double c, double d, double e, double f, String type) {
        this.id = id;
        this.a = a;
        this.b = b;
        this.c = c;
        this.d = d;
        this.e = e;
        this.f = f;
        this.type = type;
    }

    /** @return identifier of the sample */
    public long getId() {
        return id;
    }

    /** @param id identifier of the sample */
    public void setId(long id) {
        this.id = id;
    }

    /** @return first feature value */
    public double getA() {
        return a;
    }

    /** @param a first feature value */
    public void setA(double a) {
        this.a = a;
    }

    /** @return second feature value */
    public double getB() {
        return b;
    }

    /** @param b second feature value */
    public void setB(double b) {
        this.b = b;
    }

    /** @return third feature value */
    public double getC() {
        return c;
    }

    /** @param c third feature value */
    public void setC(double c) {
        this.c = c;
    }

    /** @return fourth feature value */
    public double getD() {
        return d;
    }

    /** @param d fourth feature value */
    public void setD(double d) {
        this.d = d;
    }

    /** @return fifth feature value */
    public double getE() {
        return e;
    }

    /** @param e fifth feature value */
    public void setE(double e) {
        this.e = e;
    }

    /** @return sixth feature value */
    public double getF() {
        return f;
    }

    /** @param f sixth feature value */
    public void setF(double f) {
        this.f = f;
    }

    /** @return class label, or {@code null} if the point is unlabeled */
    public String getType() {
        return type;
    }

    /** @param type class label to assign to this point */
    public void setType(String type) {
        this.type = type;
    }
}

距离

/**
 * Distance between two points, identified by their ids.
 *
 * <p>In the sample flow sketched in {@code AtlasRecommendKnn.main}, {@code id}
 * is the id of a known point and {@code nid} is the id of the point being
 * classified.
 *
 * <p>The accessors {@link #getDisatance()} and {@link #setDisatance(double)}
 * carry a historical misspelling and are retained so that existing callers
 * keep compiling; new code should use {@link #getDistance()} and
 * {@link #setDistance(double)}.
 */
public class Distance {
    /** Id of the first point (the known point in the sample flow). */
    private long id;
    /** Id of the second point (the query point in the sample flow). */
    private long nid;
    /** Distance between the two points. */
    private double distance;

    /**
     * @param id       id of the first point
     * @param nid      id of the second point
     * @param distance distance between the two points
     */
    public Distance(long id, long nid, double distance) {
        this.id = id;
        this.nid = nid;
        this.distance = distance;
    }

    /** @return id of the first point */
    public long getId() {
        return id;
    }

    /** @param id id of the first point */
    public void setId(long id) {
        this.id = id;
    }

    /** @return id of the second point */
    public long getNid() {
        return nid;
    }

    /** @param nid id of the second point */
    public void setNid(long nid) {
        this.nid = nid;
    }

    /** @return distance between the two points */
    public double getDistance() {
        return distance;
    }

    /** @param distance distance between the two points */
    public void setDistance(double distance) {
        this.distance = distance;
    }

    /**
     * Misspelled alias of {@link #getDistance()}, kept for existing callers.
     *
     * @return distance between the two points
     */
    public double getDisatance() {
        return getDistance();
    }

    /**
     * Misspelled alias of {@link #setDistance(double)}, kept for existing callers.
     *
     * @param disatance distance between the two points
     */
    public void setDisatance(double disatance) {
        setDistance(disatance);
    }
}

实现主类

/**
 * Helper methods for a k-nearest-neighbour (kNN) classifier over {@link Point}
 * samples: compute distances to the known points, take the k closest, count
 * how often each class label occurs among them and pick the most frequent one.
 */
public class AtlasRecommendKnn {
    /** Feature value that marks an attribute as missing. */
    private static final double MISSING = -1;

    /**
     * Sketch of the intended classification flow. All statements are commented
     * out, so running this method does nothing.
     *
     * <p>NOTE(review): the commented sample predates the current signatures
     * ({@code Point} now takes six features, {@code creatDataSet} takes a map)
     * and uses a {@code CompareClass} comparator that is not part of this
     * file, so it must be adapted before it can be re-enabled.
     *
     * @param args unused
     */
    public static void main(String[] args) {

        // 1. Load all known points
//        List<Point> dataList = creatDataSet();
        // 2. Define the unknown point
//        Point x = new Point(5, 1.2, 1.2);
        // 3. Compute the Euclidean distance from every known point to the
        //    unknown point and sort the known points by that distance
//        CompareClass compare = new CompareClass();
//        Set<Distance> distanceSet = new TreeSet<Distance>(compare);
//        for (Point point : dataList) {
//            distanceSet.add(new Distance(point.getId(), x.getId(), oudistance(point,
//                    x)));
//        }
        // 4. Choose the k nearest points
//        double k = 5;

        // 5. Compute how often each class occurs among the k points
        // 5.1 Count the points per class
//        List<Distance> distanceList= new ArrayList<Distance>(distanceSet);
//        Map<String, Integer> map = getNumberOfType(distanceList, dataList, k);

        // 5.2 Turn the counts into frequencies and pick the most frequent class
//        Map<String, Double> p = computeP(map, k);
//        x.setType(maxP(p));
//        System.out.println("未知点的类型为:"+x.getType());
    }

    /**
     * Computes the Euclidean distance between two points over the features
     * {@code a}..{@code f}.
     *
     * <p>A dimension is skipped when its value on {@code point1} is {@code -1}
     * (attribute missing). Only {@code point1} is checked for this marker.
     *
     * @param point1 point whose missing attributes are skipped
     * @param point2 point to measure the distance to
     * @return square root of the summed squared differences of the compared
     *         features; {@code 0.0} if every feature of {@code point1} is missing
     */
    public static double oudistance(Point point1, Point point2) {
        double sum = 0.0;
        sum += squaredDifference(point1.getA(), point2.getA());
        sum += squaredDifference(point1.getB(), point2.getB());
        sum += squaredDifference(point1.getC(), point2.getC());
        sum += squaredDifference(point1.getD(), point2.getD());
        sum += squaredDifference(point1.getE(), point2.getE());
        sum += squaredDifference(point1.getF(), point2.getF());
        return Math.sqrt(sum);
    }

    /**
     * Returns the squared difference of one feature, or {@code 0.0} when the
     * first value is the missing-attribute marker.
     */
    private static double squaredDifference(double first, double second) {
        if (first == MISSING) {
            return 0.0;
        }
        // Euclidean distance is built from coordinate differences, not sums.
        double diff = first - second;
        return diff * diff;
    }

    /**
     * Finds the key with the largest value.
     *
     * @param map frequency per class label
     * @return the key whose value is largest, or {@code null} when the map is
     *         empty or no value is greater than {@code 0.0}
     */
    public static String maxP(Map<String, Double> map) {
        String bestKey = null;
        double bestValue = 0.0;
        for (Map.Entry<String, Double> entry : map.entrySet()) {
            if (entry.getValue() > bestValue) {
                bestKey = entry.getKey();
                bestValue = entry.getValue();
            }
        }
        return bestKey;
    }

    /**
     * Converts per-class counts into frequencies by dividing each count by
     * {@code k}.
     *
     * @param map number of selected points per class label
     * @param k   number of neighbours the counts were taken from
     * @return a new map from class label to {@code count / k}
     */
    public static Map<String, Double> computeP(Map<String, Integer> map,
                                               double k) {
        Map<String, Double> frequencies = new HashMap<String, Double>();
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            frequencies.put(entry.getKey(), entry.getValue() / k);
        }
        return frequencies;
    }

    /**
     * Counts, for the first {@code k} entries of {@code listDistance}, how many
     * of the referenced points belong to each class. Each visited entry is
     * also printed to standard output.
     *
     * @param listDistance distances in the order they should be visited; the
     *                     caller is expected to pass them sorted nearest first
     * @param listPoint    known points used to resolve an id to its class label
     * @param k            maximum number of entries to visit; nothing is
     *                     counted when {@code k <= 0}
     * @return a new map from class label to the number of matching points
     */
    public static Map<String, Integer> getNumberOfType(
            List<Distance> listDistance, List<Point> listPoint, double k) {
        Map<String, Integer> counts = new HashMap<String, Integer>();
        int visited = 0;
        System.out.println("选取的k个点,由近及远依次为:");
        for (Distance distance : listDistance) {
            // Check the limit before counting so that k <= 0 selects nothing.
            if (visited >= k) {
                break;
            }
            System.out.println("id为" + distance.getId() + ",距离为:"
                    + distance.getDisatance());
            long id = distance.getId();
            // Resolve the id to its class label. Every point carrying this id
            // is counted, so duplicate ids in listPoint contribute repeatedly.
            for (Point point : listPoint) {
                if (point.getId() == id) {
                    Integer current = counts.get(point.getType());
                    counts.put(point.getType(), current == null ? 1 : current + 1);
                }
            }
            visited++;
        }
        return counts;
    }

    /**
     * Builds the list of known points from raw data.
     *
     * @param data map whose key has the form {@code "<id>-<type>"} and whose
     *             value holds at least six feature values in the order
     *             {@code a}..{@code f}
     * @return one {@link Point} per map entry, in the map's iteration order
     * @throws NumberFormatException          if the id part of a key is not a
     *                                        valid {@code long}
     * @throws ArrayIndexOutOfBoundsException if a key contains no {@code '-'}
     * @throws IndexOutOfBoundsException      if a value has fewer than six
     *                                        elements
     */
    public static ArrayList<Point> creatDataSet(Map<String, List<Double>> data){

        ArrayList<Point> dataList = new ArrayList<Point>();
        for (Map.Entry<String, List<Double>> entry : data.entrySet()) {
            String[] keyParts = entry.getKey().split("-");
            List<Double> features = entry.getValue();
            Point point = new Point(Long.parseLong(keyParts[0]),
                    features.get(0), features.get(1), features.get(2),
                    features.get(3), features.get(4), features.get(5),
                    keyParts[1]);
            // The point must be stored, otherwise the data set stays empty.
            dataList.add(point);
        }
        return dataList;
    }
}

猜你喜欢

转载自blog.csdn.net/love_zy0216/article/details/88866231