还是图谱推荐项目,最初的设计是通过邻近算法(kNN)处理推荐,但是实现过后又被放弃了,原因是领导的决策变了……
先介绍一下邻近算法:
kNN算法的核心思想是如果一个样本在特征空间中的k个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。(百度百科)
我自己的理解大致分下面几个步骤:
1、从测试集中取数据,分别和已经存在的点计算欧氏距离
2、对欧氏距离从小到大进行排序,取前k个邻近的点
3、计算这k个点在各个社区中出现的频率,频率最大的社区就是目标点所在的社区
下面是java代码
点
/**
 * A sample in a six-dimensional feature space, optionally carrying a class label.
 *
 * NOTE(review): the meaning of the individual features a..f is not visible in
 * this listing. AtlasRecommendKnn.oudistance treats a value of -1 on its first
 * argument as "attribute missing" - confirm that convention with the data source.
 */
public class Point {
// Identifier of the point.
private long id;
// Feature values a..f (the six coordinates of the point).
private double a;
private double b;
private double c;
private double d;
private double e;
private double f;
// Class label of the point; stays null when the label-less constructor is used.
private String type;
/**
 * Creates a point without a class label ({@code type} is left null).
 *
 * @param id identifier of the point
 * @param a  first feature value
 * @param b  second feature value
 * @param c  third feature value
 * @param d  fourth feature value
 * @param e  fifth feature value
 * @param f  sixth feature value
 */
public Point(long id, double a, double b, double c, double d, double e, double f) {
this.id = id;
this.a = a;
this.b = b;
this.c = c;
this.d = d;
this.e = e;
this.f = f;
}
/**
 * Creates a point with a class label.
 *
 * @param id   identifier of the point
 * @param a    first feature value
 * @param b    second feature value
 * @param c    third feature value
 * @param d    fourth feature value
 * @param e    fifth feature value
 * @param f    sixth feature value
 * @param type class label assigned to the point
 */
public Point(long id, double a, double b, double c, double d, double e, double f,String type) {
this.id = id;
this.a = a;
this.b = b;
this.c = c;
this.d = d;
this.e = e;
this.f = f;
this.type = type;
}
// get和set方法省略
}
距离
/**
 * Mutable holder for one distance measurement between two points, referenced by id.
 *
 * <p>In the usage sketch inside {@code AtlasRecommendKnn.main}, {@code id} is the
 * id of a known point and {@code nid} the id of the point being classified;
 * {@code AtlasRecommendKnn.getNumberOfType} looks up labelled points by {@code id}.
 *
 * <p>The spelling "disatance" is part of the public accessor names and is kept
 * as-is so existing callers continue to compile.
 */
public class Distance {

    /** Id of the first point of the pair. */
    private long id;

    /** Id of the second point of the pair. */
    private long nid;

    /** Distance value stored for the pair. */
    private double disatance;

    /**
     * Creates a distance record.
     *
     * @param id        id of the first point
     * @param nid       id of the second point
     * @param disatance distance between the two points
     */
    public Distance(long id, long nid, double disatance) {
        this.disatance = disatance;
        this.nid = nid;
        this.id = id;
    }

    /** @return id of the first point */
    public long getId() {
        return this.id;
    }

    /** @return id of the second point */
    public long getNid() {
        return this.nid;
    }

    /** @return the stored distance */
    public double getDisatance() {
        return this.disatance;
    }

    /** @param id new id of the first point */
    public void setId(long id) {
        this.id = id;
    }

    /** @param nid new id of the second point */
    public void setNid(long nid) {
        this.nid = nid;
    }

    /** @param disatance new distance value */
    public void setDisatance(double disatance) {
        this.disatance = disatance;
    }
}
实现主类
/**
 * Static helpers for a k-nearest-neighbour (kNN) classifier over {@link Point}s.
 *
 * <p>Intended pipeline: build the labelled data set ({@link #creatDataSet}),
 * compute the distance from every labelled point to the query point
 * ({@link #oudistance}), sort ascending, count the labels of the k nearest
 * points ({@link #getNumberOfType}), convert the counts to frequencies
 * ({@link #computeP}) and pick the most frequent label ({@link #maxP}).
 */
public class AtlasRecommendKnn {

    /** Sentinel feature value that {@link #oudistance} treats as "attribute missing". */
    private static final double MISSING_VALUE = -1;

    /**
     * Entry point. The body is intentionally empty; the commented-out sketch
     * below shows how the helpers are meant to be chained.
     *
     * <p>NOTE(review): the sketch is out of date with the rest of this listing.
     * It calls a three-argument {@code Point} constructor and a no-argument
     * {@code creatDataSet()}, and uses a {@code CompareClass} comparator that is
     * not part of this listing.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        // 1. Load all known (labelled) points.
        // List<Point> dataList = creatDataSet();
        // 2. Create the unknown point.
        // Point x = new Point(5, 1.2, 1.2);
        // 3. Compute the Euclidean distance from every known point to the unknown
        //    point and keep the results sorted by distance.
        // CompareClass compare = new CompareClass();
        // Set<Distance> distanceSet = new TreeSet<Distance>(compare);
        // for (Point point : dataList) {
        //     distanceSet.add(new Distance(point.getId(), x.getId(), oudistance(point, x)));
        // }
        // 4. Choose how many nearest points to look at.
        // double k = 5;
        // 5. Compute how often each class occurs among those k points.
        // 5.1 Count the points per class.
        // List<Distance> distanceList = new ArrayList<Distance>(distanceSet);
        // Map<String, Integer> map = getNumberOfType(distanceList, dataList, k);
        // 5.2 Turn the counts into frequencies and take the most frequent class.
        // Map<String, Double> p = computeP(map, k);
        // x.setType(maxP(p));
        // System.out.println("未知点的类型为:"+x.getType());
    }

    /**
     * Computes the Euclidean distance between two points over the features a..f.
     *
     * <p>A feature whose value on {@code point1} equals -1 is treated as missing
     * and does not contribute to the distance. Only {@code point1} is checked
     * for the sentinel, as in the original implementation.
     *
     * @param point1 the point whose missing attributes are skipped
     * @param point2 the point to compare against
     * @return square root of the summed squared differences of the present features
     */
    public static double oudistance(Point point1, Point point2) {
        double sumOfSquares = 0.0;
        sumOfSquares += squaredDifference(point1.getA(), point2.getA());
        sumOfSquares += squaredDifference(point1.getB(), point2.getB());
        sumOfSquares += squaredDifference(point1.getC(), point2.getC());
        sumOfSquares += squaredDifference(point1.getD(), point2.getD());
        sumOfSquares += squaredDifference(point1.getE(), point2.getE());
        sumOfSquares += squaredDifference(point1.getF(), point2.getF());
        return Math.sqrt(sumOfSquares);
    }

    /** Returns (target - other)^2, or 0 when {@code target} holds the missing-value sentinel. */
    private static double squaredDifference(double target, double other) {
        if (target == MISSING_VALUE) {
            return 0.0;
        }
        // Euclidean distance is built from coordinate differences, not sums.
        double diff = target - other;
        return diff * diff;
    }

    /**
     * Finds the key with the largest frequency.
     *
     * @param map frequency per class label
     * @return the label whose value is the largest value strictly greater than 0,
     *         or {@code null} when the map is empty or no value exceeds 0
     */
    public static String maxP(Map<String, Double> map) {
        String bestKey = null;
        double bestValue = 0.0;
        for (Map.Entry<String, Double> entry : map.entrySet()) {
            if (entry.getValue() > bestValue) {
                bestKey = entry.getKey();
                bestValue = entry.getValue();
            }
        }
        return bestKey;
    }

    /**
     * Converts per-class counts into frequencies by dividing each count by {@code k}.
     *
     * @param map number of points per class label
     * @param k   divisor applied to every count (the number of neighbours considered)
     * @return a new map holding {@code count / k} for every label in {@code map}
     */
    public static Map<String, Double> computeP(Map<String, Integer> map,
            double k) {
        Map<String, Double> frequencies = new HashMap<String, Double>();
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            frequencies.put(entry.getKey(), entry.getValue() / k);
        }
        return frequencies;
    }

    /**
     * Counts, per class label, the points referenced by the first {@code k}
     * entries of {@code listDistance}, printing each selected entry to stdout.
     *
     * <p>The list is consumed in its given order, so the caller must pass it
     * sorted by ascending distance. When several points in {@code listPoint}
     * share an id, each of them is counted.
     *
     * @param listDistance distances ordered from nearest to farthest
     * @param listPoint    labelled points used to resolve an id to its label
     * @param k            number of leading entries to consider; nothing is
     *                     counted when {@code k <= 0}
     * @return number of selected points per class label
     */
    public static Map<String, Integer> getNumberOfType(
            List<Distance> listDistance, List<Point> listPoint, double k) {
        Map<String, Integer> counts = new HashMap<String, Integer>();
        int taken = 0;
        System.out.println("选取的k个点,由近及远依次为:");
        for (Distance distance : listDistance) {
            // Check the limit before processing so that k <= 0 selects nothing.
            if (taken >= k) {
                break;
            }
            System.out.println("id为" + distance.getId() + ",距离为:"
                    + distance.getDisatance());
            long id = distance.getId();
            // Resolve the id to its class label and bump that label's counter.
            for (Point point : listPoint) {
                if (point.getId() == id) {
                    String type = point.getType();
                    Integer previous = counts.get(type);
                    counts.put(type, previous == null ? 1 : previous + 1);
                }
            }
            taken++;
        }
        return counts;
    }

    /**
     * Builds the labelled data set from a map of raw feature vectors.
     *
     * <p>Each key has the form {@code "<id>-<type>"}: the part before the first
     * '-' is parsed as the numeric point id and the part after it is used as the
     * class label. Each value supplies the features a..f at indexes 0..5.
     *
     * @param data feature vectors keyed by {@code "<id>-<type>"}
     * @return one {@link Point} per map entry
     * @throws NumberFormatException     if the id part of a key is not a valid long
     * @throws IndexOutOfBoundsException if a key has no '-' separator or a value
     *                                   list has fewer than six elements
     */
    public static ArrayList<Point> creatDataSet(Map<String, List<Double>> data){
        ArrayList<Point> dataList = new ArrayList<Point>();
        for (Map.Entry<String, List<Double>> entry : data.entrySet()) {
            String[] keyParts = entry.getKey().split("-");
            List<Double> features = entry.getValue();
            Point point = new Point(Long.parseLong(keyParts[0]),
                    features.get(0), features.get(1), features.get(2),
                    features.get(3), features.get(4), features.get(5),
                    keyParts[1]);
            // The original built the point but never stored it, so the result was always empty.
            dataList.add(point);
        }
        return dataList;
    }
}