K近邻(K-Nearest Neighbor,KNN)算法是一种基本分类与回归方法,也是最简单的机器学习方法之一,这里只对K近邻算法的分类问题做总结。
K近邻算法简单、直观,它的工作原理是:给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最近邻的\(k\)个实例,这\(k\)个实例的多数属于某个类,就把该输入实例分为这个类。
1 K近邻模型的基本要素
K近邻模型的三个基本要素——距离度量、\(k\)值的选择和分类决策规则
1.1 距离度量
特征空间中两个实例点的距离反映两个实例点的相似程度。\(k\)近邻模型的特征空间一般是\(n\)维实数向量空间\(R^n\),使用的距离是欧式距离,但也可以是其他距离。
在特征空间中取两个特征\(x_i\),\(x_j\),它们都是\(n\)维向量。\(x_i\),\(x_j\)的\(L_p\)距离定义为
其中\(p≥1\)。当\(p=2\)时,称为欧式距离
当\(p=1\)时,称为曼哈顿距离
1.2 \(k\)值的选择
1.3 分类决策规则
2 K近邻算法的代数方式描述
输入:训练数据集
其中,
为实例的特征向量,
为实例的类别,
;实例特征向量\(x\);
输出:实例\(x\)所属的类\(y\).
(1)根据给定的距离度量方式,在训练集T中找到与\(x\)最邻近的\(k\)个点,涵盖着\(k\)个点的\(x\)的领域记作
;
(2)在
中根据分类决策规则决定\(x\)的类别\(y\):
其中\(I\)为指示函数,即当
时\(I\)为1,否则\(i\)为0
3 K近邻算法的代码实现
3.1 准备:使用Python导入数据
from numpy import* #NumPy科学计算包
import operator #运算符模块
#创建数据集和标签
def createDataSet():
group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels
group, labels = createDataSet()
print(group, labels)
#输出
[[ 1. 1.1]
[ 1. 1. ]
[ 0. 0. ]
[ 0. 0.1]] ['A', 'A', 'B', 'B']
3.2 实施KNN算法
对未知类别属性的数据集中的每个点依次执行以下操作:
(1) 计算已知类别数据集中的点与当前点之间的距离;
(2) 按照距离递增次序排序;
(3) 选取与当前点距离最小的k个点;
(4) 确定前k个点所在类别的出现频率;
(5) 返回前k个点出现频率最高的类别作为当前点的预测分类。
def classify0(inX, dataSet, labels, k):
#计算距离
dataSetSize = dataSet.shape[0] #获取训练数据的行数
diffMat = tile(inX, (dataSetSize, 1)) - dataSet # 现将测试数据的行升维使测试数据和训练数据维度相同 再相减得到向量
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1) #得到平方后的每个数组内元素的和
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort() #将距离按升序排列 argsort()该函数返回的是数组元素的下标
#选择距离最小的k个点 确定前k个距离最小元素所在的主要分类
classCount = {} #为了直观的看到不同数据类别的出现次数,设置一个空字典 最终以元组列表存放数据
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]] #获取数据类型
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 #统计每个数据类型的出现次数
sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1), reverse=True) #将字典中的元素按出现次数降序排列 即itemgetter对元组第二个元素排序
return sortedClassCount[0][0] #返回出现次数最多的数据类型
classify0([0,0], group, labels, 3) #B
4 K近邻算法示例
4.1 在约会网站上使用k-近邻算法
4.1.1 准备数据:从文本文件中解析数据
#将文本记录转换为Numpy的解析程序
def file2matrix(filename):
fr = open(filename)
arrayOLines = fr.readlines()
numberOfLines = len(arrayOLines) #文件行数
returnMat = zeros((numberOfLines, 3)) #创建以0填充的NumPy矩阵
classLabelVector = []
index = 0
for line in arrayOLines: #解析文件数据到列表 循环处理文件中的每行数据
line = line.strip() #截取掉所有的回车字符
listFromLine = line.split('\t') #用tab字符\t将上一步得到的整行数据分割成一个元素列表
returnMat[index,:] = listFromLine[0:3] #选取前3个元素 将它们存储到特征矩阵中
classLabelVector.append(int(listFromLine[-1])) #将列表的最后一列存储到向量classLabelVector中 其中指明列表中存储的元素值为整型
index += 1
return returnMat, classLabelVector
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
datingDataMat
#输出
[[ 4.09200000e+04 8.32697600e+00 9.53952000e-01]
[ 1.44880000e+04 7.15346900e+00 1.67390400e+00]
[ 2.60520000e+04 1.44187100e+00 8.05124000e-01]
...,
[ 2.65750000e+04 1.06501020e+01 8.66627000e-01]
[ 4.81110000e+04 9.13452800e+00 7.28045000e-01]
[ 4.37570000e+04 7.88260100e+00 1.33244600e+00]]
datingLabels[0:20]
#输出
[3, 2, 1, 1, 1, 1, 3, 3, 1, 3, 1, 1, 2, 1, 1, 1, 1, 1, 2, 3]
4.1.2 分析数据:使用Matplotlib创建散点图
import matplotlib
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus']=False
fig = plt.figure()
ax = fig.add_subplot(111)
#ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2]) #datingDataMat矩阵的第二 第三列数据
ax.scatter(datingDataMat[:,1], datingDataMat[:, 2], 15.0*array(datingLabels), 15.0*array(datingLabels)) #彩色
plt.xlabel('玩视频游戏所耗时间百分比')
plt.ylabel('每周消费的冰淇淋公升数')
plt.show()
4.1.3 准备数据: 归一化数值
#归一化特征值
def autoNorm(dataSet):
minVals = dataSet.min(0) #选取每列最小值
maxVals = dataSet.max(0) #选取每列最大值
ranges = maxVals - minVals #计算可能的取值范围
normDataSet = zeros(shape(dataSet)) #创建返回矩阵
m = dataSet.shape[0] #注:特征矩阵1000*3而minVals ranges值都是1*3 使用NumPy中的tile()将变量内容复制成输入矩阵同样大小的矩阵
normDataSet = dataSet - tile(minVals, (m,1)) #当前值减最小值
normDataSet = normDataSet/tile(ranges, (m,1)) #除以取值范围 得归一化特征值
return normDataSet, ranges, minVals
normMat, ranges, minVals = autoNorm(datingDataMat)
normMat, ranges, minVals
#输出
normMat:
[[ 0.44832535 0.39805139 0.56233353]
[ 0.15873259 0.34195467 0.98724416]
[ 0.28542943 0.06892523 0.47449629]
...,
[ 0.29115949 0.50910294 0.51079493]
[ 0.52711097 0.43665451 0.4290048 ]
[ 0.47940793 0.3768091 0.78571804]]
ranges:
[ 9.12730000e+04 2.09193490e+01 1.69436100e+00]
minVals:
[ 0. 0. 0.001156]
4.1.4 测试算法:作为完整程序验证分类器
#分类器针对约会网站的测试
def datingClassTest():
hoRatio = 0.10
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m*hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:], normMat[numTestVecs:m,:],
datingLabels[numTestVecs:m], 3)
print("the classifier came back with: %d, the real answer is: %d" %(classifierResult, datingLabels[i]))
if(classifierResult != datingLabels[i]):
errorCount += 1.0
print("the total error rate is: %f"%(errorCount/float(numTestVecs)))
datingClassTest()
#输出
the classifier came back with: 3, the real answer is: 3
the classifier came back with: 1, the real answer is: 1
the classifier came back with: 1, the real answer is: 1
the classifier came back with: 3, the real answer is: 3
the classifier came back with: 3, the real answer is: 3
the classifier came back with: 1, the real answer is: 1
the classifier came back with: 2, the real answer is: 2
the classifier came back with: 3, the real answer is: 3
the classifier came back with: 3, the real answer is: 1
the total error rate is: 0.030000
#分类器处理约会数据集的错误率是3%
4.1.5 使用算法:构建完整可用系统系统
#约会网站预测函数
def classifyPerson():
resultList = ['not at all', 'in small doses', 'in large doses']
percentTats = float(input("percentage of time spent playing video games?"))
ffMiles = float(input("frequent flier miles earned per years?"))
iceCream = float(input("liters of ice cream consumed per year?"))
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = array([ffMiles, percentTats, iceCream])
classifierResult = classify0((inArr-minVals)/ranges, normMat, datingLabels, 3)
print("You will probably like this person:", resultList[classifierResult-1])
classifyPerson()
#输出
percentage of time spent playing video games?10
frequent flier miles earned per years?10000
liters of ice cream consumed per year?0.5
You will probably like this person: in small doses
4.2 使用k-近邻算法的手写识别系统
构造的系统只能识别数字0到9。需要识别的数字已经使用图形处理软件,处理成具有相同的色彩和大小①:宽高是32像素×32像素的黑白图像。尽管采用文本格式存储图像不能有效地利用内
存空间,但是为了方便理解,将图像转换为文本格式。
4.2.1 准备数据:将图像转换为测试向量
trainingDigits中包含了大约2000个例子,每个数字大约有200个样本;testDigits中包含了大约900个测试数据。trainingDigits中的数据训练分类器,使用testDigits中的数据测试分类器
的效果。两组数据没有重叠。
为了使用前面两个例子的分类器,必须将图像格式化处理为一个向量。把一个32×32的二进制图像矩阵转换为1×1024的向量,这样前两节使用的分类器就可以处理数字图像信息了。
将图像转换为向量:该函数创建1×1024的NumPy数组,然后打开给定的文件,循环读出文件的前32行,并将每行的头32个字符值存储在NumPy数组中,最后返回数组。
def img2vector(filename):
f = open(filename)
returnVect = zeros((1,1024))
for i in range(32):
line = f.readline()
for j in range(32):
returnVect[0,i*32+j] = int(line[j])
return returnVect
testVector = img2vector('testDigits/0_13.txt')
testVector[0,0:31]
#输出
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1.
0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
testVector[0,32:63]
#输出
[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1.
1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
4.2.2 测试算法:使用k近邻算法识别手写数字
def handwritingClassTest():
fileList = os.listdir('trainingDigits') #获取目录内容
m = len(fileList)
traingMat = zeros((m, 1024))
hwlabels = []
for i in range(m): #从文件名解析分类数字
fileName = fileList[i]
prefix = fileName.split('.')[0]
number = int(prefix.split('_')[0])
hwlabels.append(number)
traingMat[i,:] = img2vector('trainingDigits/%s' %fileName)
testFileList = os.listdir('testDigits')
m = len(testFileList)
errorNum = 0.0
for i in range(m):
testFileName = testFileList[i]
prefix = testFileList[i].split('.')[0]
realNumber = int(prefix.split('_')[0])
testMat = img2vector('testDigits/%s' %testFileName)
testResult = classify0(testMat, traingMat, hwlabels, 3)
if testResult != realNumber:
errorNum += 1
print('The classifier came back with: %d, the real answer is: %d' %(testResult, realNumber))
print("\nthe total number of errors is: %d" % errorNum)
print('\nthe total error rate is %f' %(errorNum/float(m)))
handwritingClassTest()
#输出
The classifier came back with: 0, the real answer is: 0
The classifier came back with: 0, the real answer is: 0
.
.
The classifier came back with: 7, the real answer is: 7
The classifier came back with: 7, the real answer is: 7
The classifier came back with: 8, the real answer is: 8
The classifier came back with: 8, the real answer is: 8
The classifier came back with: 8, the real answer is: 8
.
.
The classifier came back with: 9, the real answer is: 9
The classifier came back with: 9, the real answer is: 9
the total number of errors is: 10
the total error rate is 0.010571
#K近邻算法识别手写数字数据集 错误率1.01%