from numpy import *
from os import listdir
import operator
# k-nearest neighbours (kNN) algorithm
def classify0(inX, group, label, k):
    """Classify ``inX`` by majority vote among its ``k`` nearest neighbours.

    Distances are Euclidean, measured between ``inX`` and every row of
    ``group``.

    Args:
        inX: Feature vector to classify. Anything that broadcasts against
            the rows of ``group`` works (a flat sequence or a 1 x n array).
        group: Training samples, one sample per row (array-like, m x n).
        label: Sequence of labels; ``label[i]`` is the label of ``group[i]``.
        k: Number of nearest neighbours that vote. A value larger than the
            number of training samples is clamped to that number.

    Returns:
        The label that received the most votes. On a tie, the tied label
        that was encountered first while walking the neighbours from
        nearest to farthest wins (the sort is stable and dicts keep
        insertion order).

    Raises:
        IndexError: If no neighbour votes (``k`` < 1 or ``group`` is empty).
    """
    group = asarray(group)
    m = group.shape[0]
    # Broadcasting subtracts every row of ``group`` from ``inX`` without
    # materialising m copies of ``inX`` first.
    inVector = asarray(inX) - group
    distances = (inVector ** 2).sum(axis=1) ** 0.5
    disIndex = distances.argsort()
    # Never ask for more neighbours than there are samples; indexing
    # ``disIndex`` past its end would raise IndexError.
    if k > m:
        k = m
    labelCount = {}
    for i in range(k):
        labelX = label[disIndex[i]]
        labelCount[labelX] = labelCount.get(labelX, 0) + 1
    sortLabelCount = sorted(labelCount.items(),
                            key=operator.itemgetter(1), reverse=True)
    return sortLabelCount[0][0]
# Convert an image into a 1 x 1024 matrix; the images here are stored as text files
def img2vector(filename):
    """Read a 32 x 32 text image and flatten it into a 1 x 1024 row vector.

    The first 32 characters of each of the first 32 lines of the file are
    converted with ``int`` and stored row by row, so the character at line
    ``i``, column ``j`` lands in position ``32 * i + j``.

    Args:
        filename: Path of the text file to read.

    Returns:
        A ``(1, 1024)`` array holding the converted characters.

    Raises:
        IndexError: If one of the first 32 lines has fewer than 32 characters.
        ValueError: If a character cannot be converted with ``int``.
    """
    returnVector = zeros((1, 1024))
    # ``with`` guarantees the handle is closed, even if parsing fails; this
    # function is called once per image, so leaked handles add up quickly.
    with open(filename) as fr:
        for i in range(32):
            fileLine = fr.readline()
            for j in range(32):
                returnVector[0, 32 * i + j] = int(fileLine[j])
    return returnVector
def handwritingClassTest(
        trainingDir='D:/BaiduNetdiskDownload/machinelearninginaction/Ch02/digits/trainingDigits',
        testDir='D:/BaiduNetdiskDownload/machinelearninginaction/Ch02/digits/testDigits',
        k=3):
    """Run the kNN classifier over a directory of test images and report errors.

    Every file in ``trainingDir`` becomes one training sample. Each file in
    ``testDir`` is then classified with ``classify0`` and compared against
    the label taken from its own file name. The label of a file is the part
    of its name before the first ``'_'`` (after dropping everything from the
    first ``'.'`` onwards).

    Args:
        trainingDir: Directory holding the training images.
        testDir: Directory holding the test images.
        k: Number of neighbours passed to ``classify0``.

    Returns:
        None. The per-file results, the error count and the error rate are
        printed.
    """
    hwLabels = []
    trainingFileList = listdir(trainingDir)
    m = len(trainingFileList)
    hwVector = zeros((m, 1024))
    for i, fr in enumerate(trainingFileList):
        frName = fr.split('.')[0]
        frNameIndex = frName.split('_')[0]
        hwLabels.append(frNameIndex)  # build the label list
        # build the training matrix, one image per row
        hwVector[i, :] = img2vector('%s/%s' % (trainingDir, fr))
    # start testing
    testFileList = listdir(testDir)
    errorCount = 0
    mTest = len(testFileList)
    for fr in testFileList:
        frName = fr.split('.')[0]
        frNameLabel = frName.split('_')[0]
        inX = img2vector('%s/%s' % (testDir, fr))
        returnLabel = classify0(inX, hwVector, hwLabels, k)
        # The true label comes from the file name, the prediction from the
        # classifier; print them in the order the message announces.
        print('the real result is %s,the test result is %s' % (frNameLabel, returnLabel))
        if returnLabel != frNameLabel:
            errorCount += 1
    print('the total number of errors is %d ' % errorCount)
    # An empty test directory would otherwise divide by zero.
    errorRate = errorCount / float(mTest) if mTest else 0.0
    print('the total error rate is %f' % errorRate)
# Run the handwriting test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    handwritingClassTest()
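# k-nearest neighbours in practice, part 2: recognising handwritten digits
# Reposted from blog.csdn.net/lwycc2333/article/details/81558301
# (Navigation labels left over from the web page: "You may also like",
# "Today's recommendations", "Weekly ranking".)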