KNN python code

几分钟写了个KNN Python代码,在编译器上可以直接跑:


"""
program: KNN (k-nearest neighbours) algorithm
description:
1. calculate the distance between the test data and every single training sample
2. sort the distances
3. select the k nearest points by distance
4. count the label frequency among those k points
5. return the label with the highest frequency

"""
from mlxtend.data import iris_data
import numpy as np



class knn_csy(object):
    """Minimal k-nearest-neighbours classifier using Euclidean distance.

    The training samples and their labels are stored as given (lists or
    numpy arrays); nothing is precomputed. Each prediction scans every
    stored sample.
    """

    def __init__(self, dataset, label):
        """Store the training data.

        :param dataset: sequence of samples, each a sequence of numbers
        :param label: sequence of labels, one per sample in ``dataset``
        """
        self.dataset = dataset
        self.label = label

    def distance(self, dataset_i, testdata):
        """Return the Euclidean distance between two points.

        Both arguments are converted to float arrays first, so plain
        Python lists are accepted as well as numpy arrays.
        """
        diff = (np.asarray(dataset_i, dtype=float)
                - np.asarray(testdata, dtype=float))
        return np.sqrt(np.sum(diff ** 2))

    def calculate_dis(self, testdata, k=10, updateflage=0):
        """Predict the label of ``testdata`` by majority vote of its k nearest samples.

        :param testdata: the point to classify; must have as many
            components as a training sample
        :param k: number of nearest neighbours that vote, default 10
        :param updateflage: if truthy, add ``testdata`` and its predicted
            label to the stored training data after predicting
        :return: the most frequent label among the k nearest samples; when
            several labels are equally frequent the smallest one wins
        :raises ValueError: if ``k`` is not between 1 and the number of
            stored samples, or if ``testdata`` has the wrong length
        """
        n_samples = len(self.dataset)
        if k < 1 or k > n_samples:
            raise ValueError(
                "k must be between 1 and %d, got %r" % (n_samples, k))
        if len(testdata) != len(self.dataset[0]):
            raise ValueError("wrong input array of testdata")

        distances = [self.distance(sample, testdata)
                     for sample in self.dataset]
        # Sorting (distance, label) pairs orders by distance first and by
        # label on equal distances.
        nearest = sorted(zip(distances, self.label))[:k]
        votes = [neighbour_label for _, neighbour_label in nearest]

        # np.unique returns the distinct labels in sorted order, so argmax
        # (first maximum) resolves ties towards the smallest label. Unlike
        # np.bincount this also works for negative or non-integer labels.
        values, counts = np.unique(votes, return_counts=True)
        predicted = values[np.argmax(counts)]

        if updateflage:
            # numpy arrays have no append(); build extended arrays instead.
            if isinstance(self.dataset, np.ndarray):
                self.dataset = np.vstack(
                    [self.dataset, np.asarray(testdata)])
            else:
                self.dataset.append(testdata)
            if isinstance(self.label, np.ndarray):
                self.label = np.append(self.label, predicted)
            else:
                self.label.append(predicted)
        return predicted

if __name__ == '__main__':
    # Demo: classify one hand-written sample against the iris data set
    # using its 3 nearest neighbours.
    dataset, label = iris_data()
    myknn = knn_csy(dataset, label)
    testdata = [2, 1, 1, 2]
    predicted = myknn.calculate_dis(testdata, 3)
    # print(...) with a single argument behaves the same on Python 2 and 3;
    # the bare `print x` statement is a syntax error on Python 3.
    print(predicted)









猜你喜欢

转载自blog.csdn.net/lisarer/article/details/78688403