KNN(最邻近值算法) scala实现

最邻近值算法实现

工程目录结构

这里写图片描述

代码

训练模型

package com.knn.model

/**
  * 训练数据模型
  *
  * @param aA 数据a
  * @param bA 数据b
  * @param typeA 类型
  */
class KNNModel(aA:Double,bA:Double,typeA:String) {
  var a:Double = aA
  var b:Double = bA
  var resType: String = typeA
  //距离
  var distince:Double = 0
}

核心算法代码

package com.knn.core

import com.knn.model.KNNModel

import scala.collection.immutable.ListMap

/**
  * 最邻近算法核心算法
  */
class KNN_Core {
//  val knnModel = new KNNModel(null,null,null,null,null);

  /**
    * 对训练数据进行升序排序(根据距离来进行排序)
    * @param knnMOdels
    * @return
    */
  private def sortByDistince(knnMOdels:List[KNNModel]):List[KNNModel] ={
    //进行升序排序
    return knnMOdels.sortBy(knn => knn.distince)
  }

  /**
    * 使用欧几里得度量计算出距离
    * @param knnMOdels
    * @param k
    */
  private def coluaclateDistince(knnMOdels:List[KNNModel],k: KNNModel):Unit =
    knnMOdels.foreach(n=>{
      n.distince = Math.sqrt((k.a-n.a)*(k.a-n.a)+(k.b-n.b)*(k.b-n.b))
    })

  /**
    * 获取距离目标点附近(指定集合大小的范围内存在最多的数据)
    * @param ks
    * @return
    */
  private def findMostValue(ks:List[KNNModel]):String ={
    //找出训练集中在规定数量中存在最多的类
    var resType = ""
    var typeCountMap:Map[String,Int] = Map()
    //进行计数
    ks.toStream.foreach(k=>{
      if (typeCountMap.contains(k.resType)){
        typeCountMap+= (k.resType -> (typeCountMap(k.resType)+1))
      }else{
        typeCountMap+=(k.resType -> 1)
      }
    })
    //获取最多数量类型(根据键值进行排序)
    resType = ListMap(typeCountMap.toSeq.sortWith(_._2 >_._2):_*).take(1).keySet.head
    return resType
  }

  def reckonRelize(kns:List[KNNModel],kn:KNNModel,k: Int):String={
    //计算距离
    coluaclateDistince(kns,kn)
    //根据距离排序
    var knsSort = sortByDistince(kns)
    //获取前k个数据
    var knss = knsSort.take(k)
    //获取k个数据中数量最多的类型
    return findMostValue(knss)
  }
}

运行代码

package com.knn

import com.knn.core.KNN_Core
import com.knn.model.KNNModel

/**
  * 分割类
  */
object app {
  def main(args: Array[String]): Unit = {
    //数据准备
    var knnModels:List[KNNModel] = List()
    knnModels = knnModels.::(new KNNModel(1.1, 1.1, "A"))
    knnModels = knnModels.::(new KNNModel(1.2, 1.2, "A"))
    knnModels = knnModels.::(new KNNModel(1.1, 1.0, "A"))
    knnModels = knnModels.::(new KNNModel(3.0, 3.1, "B"))
    knnModels = knnModels.::(new KNNModel(3.1, 3.0, "B"))
    knnModels = knnModels.::(new KNNModel(5.4, 6.0, "C"))
    knnModels = knnModels.::(new KNNModel(5.5, 6.3, "C"))
    knnModels = knnModels.::(new KNNModel(6.0, 12.0, "C"))
    knnModels = knnModels.::(new KNNModel(10.0, 12.0, "M"))
    //待预测数据
    var knnModle = new KNNModel(4.0, 3.2, "A")

    var kNN_Core = new KNN_Core
    //算法实现
    var resType = kNN_Core.reckonRelize(knnModels,knnModle,5)
    println("预测结果",resType)
  }
}

参考资料

KNN
欧几里得度量

猜你喜欢

转载自blog.csdn.net/xuzz94/article/details/79694476