K近邻算法

本文用于练习K近邻算法的代码实现!!!

算法思想:
通过计算每个测试样本到训练样本的距离,取和测试样本最近的k个训练样本的标签,哪个类的训练样本

最多,则测试样本就属于此类。

优点:思想简单;不需要训练过程;适合多分类问题

缺点:当训练样本较多时,计算复杂度较高;

步骤:
1.计算测试样本到各训练样本的欧式距离
2.对距离进行排序,获得最近k个样本的标签

3.比较k个样本标签出现最多的即为测试样本标签

scikit-learn机器学习算法库实现。

from sklearn import neighbors

#自定义二维列表训练集
train = [[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]]
label = [-1,-1,1,1]

#获得分类器
knn = neighbors.KNeighborsClassifier(n_neighbors=2)

knn.fit(train, label)

print( knn.predict([[1.0,0.9]]) )
#返回测试数据各元素的概率
print( knn.predict_proba([[1.0,0.9]]) )

Python实现

import numpy as np

#自定义训练数据集并保存txt文件
Train = np.array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
Label = np.array([[-1],[-1],[1],[1]])
TrainSet = np.concatenate((Train,Label),axis=1)
print( TrainSet.shape )
np.savetxt("KNNTrainSet.txt",TrainSet)

#自定义测试数据集并保存txt文件
TestSet = np.array([[1.0,0.9],[0.1,0.2]])
np.savetxt("KNNTestSet.txt",TestSet)

#提取训练数据和测试数据
a = np.loadtxt('KNNTrainSet.txt',usecols=(0,1),unpack=True)
train=a.T
print( train.shape )
b = np.loadtxt('KNNTrainSet.txt',usecols=(2),unpack=True)
label=b.T
c = np.loadtxt('KNNTestSet.txt',usecols=(0,1),unpack=True)
test=c.T

#构建KNN分类器
def Classifier(t,x,y,k):
    #获得训练样本数
    trainSize = x.shape[0]
    #测试数据与每个训练数据的差值
    diff = np.tile(t,(trainSize,1)) - x
    sqdiff = diff ** 2
    sumdiff = sqdiff.sum(axis=1)
    Distance  = sumdiff ** 0.5
    print('欧式距离:',Distance)
    sorteDistance = Distance.argsort()
    print('距离升序对应的下标数组:',sorteDistance)
    #建立存放标签的矩阵
    KLabel = np.zeros(k)
    class1=0
    class2=0
    print(KLabel)
    for i in range(k):
        nearLabel = y[sorteDistance[i]]
        print('最近样本的标签:',nearLabel)
        KLabel[i] = nearLabel
        #得到k个标签中各个类的数量
        if(nearLabel==-1):
            class1=class1+1
        if(nearLabel==1):
            class2=class2+1
    print("各类的数量:",class1," ",class2)
    if(class1>class2):
        result=-1
    else:
        result=1
    return result

#调用分类器,对测试数据集进行分类
testnum = test.shape[0]
for i in range(testnum):
    value=Classifier(test[i],train,label,2)
    print("样本 {0} 的分类结果为 {1}".format(i,value))

参考资料:

1.Scikit-learn

2.《机器学习实战》


猜你喜欢

转载自blog.csdn.net/attitude_yu/article/details/80227169