KNN基础代码三

增进点:训练多个 KNN 模型,可视化
库:
matplotlib,numpy,itertools(里面有个 product 可视化),sklearn

数据准备过程

n_points = 100
X1 = np.random.multivariate_normal( [1,50] , [[1,0],[0,10]] , n_points)
X2 = np.random.multivariate_normal( [2,50] , [[1,0],[0,10]] , n_points)

生成随机样本,个数100
np.random.multivariate_normal( mean , cov [, size] )
参数:(平均值协方差,个数)
X1,X2 是数组,一个大数组,里面是 100个 1行*2列 的列表 ( 100行 ✖️2 列)

X = np.concatenate( [X1,X2] )

np.concatenate( [ 数组1 , 数组2 ] ) ,把 数组2 追加到 数组1 后面,组合成1个数组,这里X个数是200 ,即 100+100

y = np.array([0]*n_points + [1]*n_points)

强行组了100个0,100个1

print (X.shape, y.shape)

X.shape = (200, 2),y.shape = (200,)

KNN 模型训练过程

clfs = []
neighbors = [1,3,5,9,11,13,15,17,19]
for i in range ( len (neighbors) ):
clfs.append( KNeighborsClassifier (n_neighbors=neighbors[I]). fit(X,y))

上上节KNN基础代码一,有用过。
KNeighborsClassifier 这个是 KNN 核心调用

可视化结果

KNN基础代码一,整体代码没写

import matplotlib.pyplot as plt
import numpy as np
from itertools import product
from sklearn.neighbors import KNeighborsClassifier

#生成一些随机样本
n_points = 100
X1 = np.random.multivariate_normal([1,50], [[1,0],[0,10]], n_points)
X2 = np.random.multivariate_normal([2,50], [[1,0],[0,10]], n_points)
X = np.concatenate([X1,X2])
y = np.array([0]*n_points + [1]*n_points)
print (X.shape, y.shape)

#KNN模型的训练过程
clfs = []
neighbors = [1,3,5,9,11,13,15,17,19]
for i in range(len(neighbors)):
clfs.append(KNeighborsClassifier(n_neighbors=neighbors[i]).fit(X,y))

#可视化结果
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),np.arange(y_min, y_max, 0.1))
f, axarr = plt.subplots(3, 3, sharex=‘col’, sharey=‘row’, figsize=(15, 12))
for idx, clf, tt in zip(product([0, 1, 2], [0, 1, 2]),clfs,[‘KNN (k=%d)’ % k for k in neighbors]):
       Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
       Z = Z.reshape(xx.shape)
       axarr[idx[0], idx[1]].contourf(xx, yy, Z, alpha=0.4)
       axarr[idx[0], idx[1]].scatter(X[:, 0], X[:, 1], c=y,s=20, edgecolor=‘k’)
       axarr[idx[0], idx[1]].set_title(tt)
plt.show()

发布了53 篇原创文章 · 获赞 4 · 访问量 8788

猜你喜欢

转载自blog.csdn.net/lee__7/article/details/102783853
kNN