K 最近邻（KNN）算法的 Scala 实现
工程目录结构
代码
训练数据模型（KNNModel）
package com.knn.model
/**
 * One labelled 2-D sample used by the KNN classifier.
 *
 * Every field is a mutable `var`. `distince` is initialised to 0 and is
 * used as a scratch slot for the distance to the current query point
 * (KNN_Core overwrites it on each classification).
 *
 * @param aA    initial value of coordinate `a`
 * @param bA    initial value of coordinate `b`
 * @param typeA initial class label, stored in `resType`
 */
class KNNModel(aA: Double, bA: Double, typeA: String) {

  /** First coordinate of the sample. */
  var a: Double = aA

  /** Second coordinate of the sample. */
  var b: Double = bA

  /** Class label of the sample. */
  var resType: String = typeA

  /** Distance to the current query point; starts at 0.0. */
  var distince: Double = 0.0
}
核心算法代码
package com.knn.core
import com.knn.model.KNNModel
import scala.collection.immutable.ListMap
/**
 * Core of the k-nearest-neighbours (KNN) classifier.
 *
 * Samples are 2-D points (`a`, `b`) and distances are Euclidean. The class
 * holds no state of its own: the per-query distance is written into each
 * training sample's `distince` field, so one training list must not be
 * classified against from several threads at the same time.
 */
class KNN_Core {

  /**
   * Sorts samples by ascending `distince`.
   *
   * `sortBy` is stable, so samples at exactly equal distance keep their
   * input order.
   *
   * @param knnMOdels samples whose `distince` has already been filled in
   * @return a new list ordered from nearest to farthest
   */
  private def sortByDistince(knnMOdels: List[KNNModel]): List[KNNModel] =
    knnMOdels.sortBy(_.distince)

  /**
   * Stores in every sample's `distince` the Euclidean distance between that
   * sample and the query point `k`.
   *
   * Mutates the passed-in samples; only `k.a` and `k.b` are read from `k`.
   *
   * @param knnMOdels training samples to update
   * @param k         query point
   */
  private def coluaclateDistince(knnMOdels: List[KNNModel], k: KNNModel): Unit =
    knnMOdels.foreach { n =>
      n.distince = Math.sqrt((k.a - n.a) * (k.a - n.a) + (k.b - n.b) * (k.b - n.b))
    }

  /**
   * Majority vote: returns the label that occurs most often in `ks`.
   *
   * Ties are broken in favour of the label that occurs first in `ks`. Because
   * callers pass `ks` ordered nearest-first, the tied label with the closest
   * sample wins. The result is therefore deterministic for any number of
   * distinct labels.
   *
   * @param ks neighbours, expected to be sorted by ascending distance
   * @return the winning label
   * @throws NoSuchElementException if `ks` is empty
   */
  private def findMostValue(ks: List[KNNModel]): String = {
    if (ks.isEmpty)
      throw new NoSuchElementException(
        "no neighbours to vote on: k must be > 0 and the training data must be non-empty")
    val labels = ks.map(_.resType)
    val counts: Map[String, Int] =
      labels.groupBy(identity).map { case (label, hits) => label -> hits.size }
    // `distinct` keeps first-occurrence order and `maxBy` returns the first
    // maximal element, which gives nearest-first tie-breaking.
    labels.distinct.maxBy(label => counts(label))
  }

  /**
   * Classifies `kn` by majority vote among its `k` nearest training samples.
   *
   * Side effect: overwrites `distince` on every element of `kns`.
   *
   * @param kns training samples (labelled)
   * @param kn  query point; its label is ignored, only `a` and `b` are read
   * @param k   number of neighbours; if larger than `kns.size`, all samples vote
   * @return the predicted label
   * @throws NoSuchElementException if `k <= 0` or `kns` is empty
   */
  def reckonRelize(kns: List[KNNModel], kn: KNNModel, k: Int): String = {
    coluaclateDistince(kns, kn)
    findMostValue(sortByDistince(kns).take(k))
  }
}
运行代码
package com.knn
import com.knn.core.KNN_Core
import com.knn.model.KNNModel
/**
 * Demo entry point.
 *
 * Builds a small labelled 2-D training set, classifies the point (4.0, 3.2)
 * with k = 5, and prints the predicted label.
 */
object app {
  def main(args: Array[String]): Unit = {
    // Same elements, in the same order, as the original prepend-built list.
    val knnModels: List[KNNModel] = List(
      new KNNModel(10.0, 12.0, "M"),
      new KNNModel(6.0, 12.0, "C"),
      new KNNModel(5.5, 6.3, "C"),
      new KNNModel(5.4, 6.0, "C"),
      new KNNModel(3.1, 3.0, "B"),
      new KNNModel(3.0, 3.1, "B"),
      new KNNModel(1.1, 1.0, "A"),
      new KNNModel(1.2, 1.2, "A"),
      new KNNModel(1.1, 1.1, "A")
    )
    // The query's label ("A") is not used by the classifier; only a and b are read.
    val knnModle = new KNNModel(4.0, 3.2, "A")
    val kNN_Core = new KNN_Core
    val resType = kNN_Core.reckonRelize(knnModels, knnModle, 5)
    // Interpolate instead of println(a, b), which Scala auto-tuples and prints as "(a,b)".
    println(s"预测结果: $resType")
  }
}
参考资料
K 最近邻算法（K-Nearest Neighbors, KNN）
欧几里得距离（Euclidean distance）