版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/apollo_miracle/article/details/89188646
k-近邻算法(kNN),它的工作原理是:存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分类的对应关系。输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据(最近邻)的分类标签。一般来说,我们只选择样本数据集中前k个最相似的数据,这就是k-近邻算法中k的出处,通常k是不大于20的整数。 最后,选择k个最相似数据中出现次数最多的分类,作为新数据的分类。
这里首先给出k-近邻算法的伪代码和实际的Python代码,然后详细地解释每行代码的含义。
该函数的功能是使用k-近邻算法将每组数据划分到某个类中,其伪代码如下:
对未知类别属性的数据集中的每个点依次执行以下操作:
(1) 计算已知类别数据集中的点与当前点之间的距离;
(2) 按照距离递增次序排序;
(3) 选取与当前点距离最小的k个点;
(4) 确定前k个点所在类别的出现频率;
(5) 返回前k个点出现频率最高的类别作为当前点的预测分类。
python代码:
from numpy import *
import operator
def create_data_set():
group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
labels = ["A", "A", "B", "B"]
return group, labels
def kNN(inX, data_set, labels, k):
# 动态计算数据集的行数
data_set_size = data_set.shape[0]
print("data_set_size:", data_set_size)
# tile(inX, (data_set_size, 1)) #在行方向上重复inX data_set_size次,列1次
# diff_mat为矩阵相减结果
diff_mat = tile(inX, (data_set_size, 1)) - data_set
print("diff_mat:", diff_mat)
# diff_mat矩阵内每一项求二次方
sqdiff_mat = diff_mat ** 2
print("sqdiff_mat:", sqdiff_mat)
# 矩阵每一行求和
distances_sq = sqdiff_mat.sum(axis=1)
print("distances_sq:", distances_sq)
# 开方
distances = distances_sq ** 0.5
print("distances:", distances)
# 按值的大小对索引升序排列
sorted_dist_index = distances.argsort()
print("sorted_dist_index:", sorted_dist_index)
class_count = {}
for i in range(k):
vote_label = labels[sorted_dist_index[i]]
print("vote_label:", vote_label)
class_count[vote_label] = class_count.get(vote_label, 0) + 1
print("class_count[vote_label]:", class_count[vote_label])
sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
print("sorted_class_count:", sorted_class_count)
return sorted_class_count[0][0]
if __name__ == '__main__':
group, labels = create_data_set()
kNN([0, 0], group, labels, 3)
代码结果:
data_set_size: 4
diff_mat: [[-1. -1.1]
[-1. -1. ]
[ 0. 0. ]
[ 0. -0.1]]
sqdiff_mat: [[1. 1.21]
[1. 1. ]
[0. 0. ]
[0. 0.01]]
distances_sq: [2.21 2. 0. 0.01]
distances: [1.48660687 1.41421356 0. 0.1 ]
sorted_dist_index: [2 3 1 0]
vote_label: B
class_count[vote_label]: 1
vote_label: B
class_count[vote_label]: 2
vote_label: A
class_count[vote_label]: 1
sorted_class_count: [('B', 2), ('A', 1)]