The previous post on kNN walked through a from-scratch implementation of the algorithm; in this post, let's see how quickly kNN can be built with scikit-learn.
scikit-learn ships with many built-in datasets, so we don't have to fabricate toy data ourselves. Below we'll use the iris dataset and the handwritten-digits dataset.
First, import the required libraries:
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
import numpy as np
Load the iris dataset:
iris = datasets.load_iris()
iris_x = iris.data
iris_y = iris.target
Use train_test_split to split the data into a training set and a test set; here the random seed is set to 666:
x_train, x_test, y_train, y_test = train_test_split(iris_x, iris_y, test_size=0.2, random_state=666)
Next, initialize the kNN classifier and call fit. Because kNN is such a simple algorithm, fit does little more than store the training data; in more complex models, the fit function does far more processing.
knn = KNeighborsClassifier()
knn.fit(x_train, y_train)
Then you can check the classification accuracy with the score method, or call predict and inspect the predictions directly:
y_predict = knn.predict(x_test)
knn.score(x_test, y_test)
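As a quick check: score returns the mean accuracy on the test set, which is the same number you get by comparing the predictions against the true labels with accuracy_score from sklearn.metrics. A self-contained sketch:

```python
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

iris = datasets.load_iris()
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.2, random_state=666)

knn = KNeighborsClassifier()
knn.fit(x_train, y_train)

# score() computes mean accuracy on the test set,
# equivalent to accuracy_score on the predictions
y_predict = knn.predict(x_test)
print(knn.score(x_test, y_test) == accuracy_score(y_test, y_predict))
```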
Full code:
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
import numpy as np
# Iris classification
iris = datasets.load_iris()
iris_x = iris.data
iris_y = iris.target
x_train, x_test, y_train, y_test = train_test_split(iris_x, iris_y, test_size=0.2, random_state=666)
knn = KNeighborsClassifier()
knn.fit(x_train, y_train)
y_predict = knn.predict(x_test)
knn.score(x_test, y_test)
Same recipe, same flavor. Next let's look at handwritten digit recognition: the iris dataset has 150 samples, while the digits dataset has nearly 1,800 usable samples.
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
import numpy as np
# Handwritten digit recognition
digits = datasets.load_digits()
digits_x = digits.data
digits_y = digits.target
x_train, x_test, y_train, y_test = train_test_split(digits_x, digits_y, test_size=0.2, random_state=666)
knn = KNeighborsClassifier()
knn.fit(x_train, y_train)
y_predict = knn.predict(x_test)
knn.score(x_test, y_test)
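The dataset sizes mentioned above can be verified directly; each built-in dataset exposes a data matrix of shape (n_samples, n_features):

```python
from sklearn import datasets

iris = datasets.load_iris()
digits = datasets.load_digits()

# Each row is one sample; each column is one feature
print(iris.data.shape)    # (150, 4)
print(digits.data.shape)  # (1797, 64)
```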
Interested readers can try kNN on other datasets as well.
Next, let's talk about hyperparameter tuning (the "alchemy furnace"). A hyperparameter is a parameter that must be chosen before the algorithm runs; it stands in contrast to a model parameter, which is learned from the data (kNN has no model parameters). Hyperparameter tuning is an important step in machine learning, and choosing good values usually draws on domain knowledge, practical experience, and empirical search.
Search for the best k:
best_score = 0.0
best_k = -1
for k in range(1, 11):
    knn_clf = KNeighborsClassifier(n_neighbors=k)
    knn_clf.fit(x_train, y_train)
    score = knn_clf.score(x_test, y_test)
    if score > best_score:
        best_k = k
        best_score = score
print("best_k = " + str(best_k))
print("best_score = " + str(best_score))
Search for the best k under both "uniform" and "distance" weighting:
best_method = ""
best_score = 0.0
best_k = -1
for method in ["uniform", "distance"]:
    for k in range(1, 11):
        knn_clf = KNeighborsClassifier(n_neighbors=k, weights=method)
        knn_clf.fit(x_train, y_train)
        score = knn_clf.score(x_test, y_test)
        if score > best_score:
            best_k = k
            best_score = score
            best_method = method
print("best_k = " + str(best_k))
print("best_score = " + str(best_score))
print("best_method = " + best_method)
Search for the best p:
best_p = -1
best_score = 0.0
best_k = -1
for k in range(1, 11):
    for p in range(1, 6):
        knn_clf = KNeighborsClassifier(n_neighbors=k, weights="distance", p=p)
        knn_clf.fit(x_train, y_train)
        score = knn_clf.score(x_test, y_test)
        if score > best_score:
            best_k = k
            best_score = score
            best_p = p
print("best_k = " + str(best_k))
print("best_score = " + str(best_score))
print("best_p = " + str(best_p))
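The p above is the power parameter of the Minkowski distance that KNeighborsClassifier uses: p=1 gives Manhattan distance, p=2 gives Euclidean distance. A minimal numpy sketch (minkowski_distance here is an illustrative helper, not a scikit-learn function):

```python
import numpy as np

def minkowski_distance(a, b, p):
    """Minkowski distance; p=1 is Manhattan, p=2 is Euclidean."""
    return np.sum(np.abs(a - b) ** p) ** (1 / p)

a, b = np.array([0.0, 0.0]), np.array([3.0, 4.0])
print(minkowski_distance(a, b, 1))  # 7.0  (Manhattan)
print(minkowski_distance(a, b, 2))  # 5.0  (Euclidean)
```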
The definitions of these parameters can be found in the KNeighborsClassifier documentation. Hyperparameter values usually lie on a continuum, so if the first search above lands on k = 10 (the edge of the range), you should also check whether k between 10 and 20 does better, and so on.
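That boundary check can be sketched as a loop: if the best k found equals the top of the range, slide the window up and search again. Both search_best_k and the sliding-window loop below are illustrative helpers under that assumption, not part of scikit-learn:

```python
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=666)

def search_best_k(k_range):
    """Return (best_k, best_score) over the given candidate values of k."""
    best_k, best_score = -1, 0.0
    for k in k_range:
        knn_clf = KNeighborsClassifier(n_neighbors=k)
        knn_clf.fit(x_train, y_train)
        score = knn_clf.score(x_test, y_test)
        if score > best_score:
            best_k, best_score = k, score
    return best_k, best_score

lo, hi = 1, 11
best_k, best_score = search_best_k(range(lo, hi))
# If the winner sits on the edge of the range, slide the window and retry
while best_k == hi - 1:
    lo, hi = hi - 1, hi + 10
    best_k, best_score = search_best_k(range(lo, hi))
print("best_k = " + str(best_k))
print("best_score = " + str(best_score))
```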
One of scikit-learn's strengths is that it wraps many common operations in ready-made tools. For finding good hyperparameters, scikit-learn provides grid search: you simply pass in a grid of candidate parameters.
from sklearn.model_selection import GridSearchCV

params_grid = [
    {
        'weights': ['uniform'],
        'n_neighbors': [i for i in range(1, 11)]
    },
    {
        'weights': ['distance'],
        'n_neighbors': [i for i in range(1, 11)],
        'p': [i for i in range(1, 6)]
    }
]
knn = KNeighborsClassifier()
grid_search = GridSearchCV(knn, params_grid)
grid_search.fit(x_train, y_train)
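After fitting, GridSearchCV exposes the winning combination and the refit classifier as attributes. A self-contained run on the digits dataset (n_jobs=-1 uses all CPU cores and is optional):

```python
from sklearn import datasets
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV, train_test_split

digits = datasets.load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.2, random_state=666)

params_grid = [
    {'weights': ['uniform'], 'n_neighbors': [i for i in range(1, 11)]},
    {'weights': ['distance'], 'n_neighbors': [i for i in range(1, 11)],
     'p': [i for i in range(1, 6)]},
]
knn = KNeighborsClassifier()
grid_search = GridSearchCV(knn, params_grid, n_jobs=-1)
grid_search.fit(x_train, y_train)

print(grid_search.best_params_)         # winning hyperparameter combination
print(grid_search.best_score_)          # mean cross-validated score of that combo
best_knn = grid_search.best_estimator_  # classifier refit with the best params
print(best_knn.score(x_test, y_test))   # accuracy on the held-out test set
```

Note that best_score_ is a cross-validated score on the training folds, so it will generally differ slightly from the simple test-set accuracy used in the manual loops above.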