高斯核函数是SVM中使用最多的一种核函数,对比高斯函数x-u,高斯核函数中表征的是两个向量(x,y)之间的关系,高斯函数又被称为RBF核和径向基核函数。在多项式核函数中,我们知道多项式核函数是将数据点添加多项式项,再将这些有了多项式项的特征点进行点乘,就形成了多项式核函数,对于高斯核函数也是一样,首先将原来的数据点映射成一种新的特征向量,然后得到新的特征向量点乘的结果,对高斯核函数来说,本质就是将每一个样本点映射到一个无穷维的特征空间,这就表明高斯核函数对于样本数据的变形是非常复杂的,但是经过变形,再去点乘,得到的结果却是非常简明的,就是核函数中的式子,这样也表明了核函数的威力。
回忆一下多项式特征,就是依靠升维使得原本线性不可分的数据变得线性可分。比如下图中,有一维数据线性不可分,但是添加多项式特征后就可以在二维空间中线性可分。
本质上,高斯核也在做类似的事情,为了方便可视化,对核函数进行改变,对于y值不取样本点,取固定的值,取两个固定的点l1,l2(landmark)。高斯核函数对于一维数据升维成二维点,这样我们就将一维的样本点映射到了二维空间,具体取值下图所示,并通过程序模拟是怎样通过这样一个高斯核函数将一维线性不可分的数据变得线性可分的。
import numpy as np import matplotlib.pyplot as plt #准备数据 #x一维向量[-4 -3 -2 -1 0 1 2 3 4];y[0 0 1 1 1 1 1 0 0] x=np.arange(-4,5,1) print(x) y=np.array((x>=-2)&(x<=2),dtype=int) print(y) plt.scatter(x[y==0],[0]*len(x[y==0])) plt.scatter(x[y==1],[0]*len(x[y==1])) plt.show()
由图可以看出,显然线性不可分,接着,通过高斯核函数映射到二维
#定义高斯核函数 def guess(x,l): gamma=1.0 return np.exp(-gamma*(x-l)**2) #landmark l1,l2=-1,1 #x_new,len(x)行,两列,第一列(np.exp(-gamma*(x-l1)**2)),第二列np.exp(-gamma*(x-l2)**2) x_new=np.empty((len(x),2)) for i,data in enumerate(x): x_new[i,0]=guess(data,l1) x_new[i,1]=guess(data,l2) plt.scatter(x_new[y==0,0],x_new[y==0,1]) plt.scatter(x_new[y==1,0],x_new[y==1,1]) plt.show()
对于这样一个二维数据,显然线性可分,如下图所示。
但在实际中,高斯核函数每个数据对每一个数据点都是landmark,成为新的高维数据中对应的每个元素。m*n的数据映射成了m*m的数据。高斯核函数的计算开销比较大,所以训练时间一般较长,但尽管如此,还是在有些领域非常适合使用高斯核函数,初始样本数据维度本身就比较高,但是样本数据数量不多,当M<N时,比较划算,例如自然语言处理。
import numpy as np import matplotlib.pyplot as plt from sklearn import datasets from sklearn.preprocessing import StandardScaler from sklearn.svm import SVC #使用管道将StandardScaler和SVC连在一起 from sklearn.pipeline import Pipeline x,y=datasets.make_moons(noise=0.15,random_state=666) print(x) print(y) plt.scatter(x[y==0,0],x[y==0,1]) plt.scatter(x[y==1,0],x[y==1,1]) plt.show() def RBFKernelSVC(gamma,C): return Pipeline([ ('std_scaler',StandardScaler()), #采用高斯核函数rbf # gamma越大,高斯图形越窄,模型复杂度越高,容易导致过拟合 # gamma越小,高斯图形越宽,模型复杂度越低,容易导致欠拟合 ('svc',SVC(kernel='rbf',gamma=gamma,C=C)) ]) svc1=RBFKernelSVC(0.1,1) svc1.fit(x,y) svc2=RBFKernelSVC(1,1) svc2.fit(x,y) svc3=RBFKernelSVC(10,1) svc3.fit(x,y) svc4=RBFKernelSVC(100,1) svc4.fit(x,y) svc5=RBFKernelSVC(0.1,5) svc5.fit(x,y) svc6=RBFKernelSVC(1,5) svc6.fit(x,y) svc7=RBFKernelSVC(10,5) svc7.fit(x,y) svc8=RBFKernelSVC(100,5) svc8.fit(x,y) def plot_decision_boundary(model,axis): x0,x1=np.meshgrid( np.linspace(axis[0],axis[1],int((axis[1]-axis[0])*100)), np.linspace(axis[2], axis[3], int((axis[3]-axis[2])*100)) ) x_new=np.c_[x0.ravel(),x1.ravel()] y_predict=model.predict(x_new).reshape(x0.shape) from matplotlib.colors import ListedColormap # 自定义colormap custom_cmap=ListedColormap(['#EF9A9A','#FFF59D','#90CAF9']) plt.contourf(x0,x1,y_predict,linewidth=5,cmap=custom_cmap) flg=plt.figure() #flg.subplots_adjust(left=0.15,bottom=0.1,top=0.9,right=0.95,hspace=0.35,wspace=0.25) plt.subplot(2, 4, 1), plt.title('gamma=0.1,C=1') plot_decision_boundary(svc1,axis=[-1.5,2.5,-1.0,1.5]) plt.scatter(x[y==0,0],x[y==0,1]) plt.scatter(x[y==1,0],x[y==1,1]) plt.subplot(2, 4, 2), plt.title('gamma=1,C=1') plot_decision_boundary(svc2,axis=[-1.5,2.5,-1.0,1.5]) plt.scatter(x[y==0,0],x[y==0,1]) plt.scatter(x[y==1,0],x[y==1,1]) plt.subplot(2, 4, 3), plt.title('gamma=10,C=1') plot_decision_boundary(svc3,axis=[-1.5,2.5,-1.0,1.5]) plt.scatter(x[y==0,0],x[y==0,1]) plt.scatter(x[y==1,0],x[y==1,1]) plt.subplot(2, 4, 4), plt.title('gamma=100,C=1') plot_decision_boundary(svc4,axis=[-1.5,2.5,-1.0,1.5]) plt.scatter(x[y==0,0],x[y==0,1]) plt.scatter(x[y==1,0],x[y==1,1]) plt.subplot(2, 4, 5), plt.title('gamma=0.1,C=5') plot_decision_boundary(svc5,axis=[-1.5,2.5,-1.0,1.5]) plt.scatter(x[y==0,0],x[y==0,1]) plt.scatter(x[y==1,0],x[y==1,1]) plt.subplot(2, 4,6), plt.title('gamma=1,C=5') plot_decision_boundary(svc6,axis=[-1.5,2.5,-1.0,1.5]) plt.scatter(x[y==0,0],x[y==0,1]) plt.scatter(x[y==1,0],x[y==1,1]) plt.subplot(2, 4, 7), plt.title('gamma=10,C=5') plot_decision_boundary(svc7,axis=[-1.5,2.5,-1.0,1.5]) plt.scatter(x[y==0,0],x[y==0,1]) plt.scatter(x[y==1,0],x[y==1,1]) plt.subplot(2, 4, 8), plt.title('gamma=100,C=5') plot_decision_boundary(svc8,axis=[-1.5,2.5,-1.0,1.5]) plt.scatter(x[y==0,0],x[y==0,1]) plt.scatter(x[y==1,0],x[y==1,1]) plt.show()
gamma越大,高斯图形越窄,模型复杂度越高,容易导致过拟合
gamma越小,高斯图形越宽,模型复杂度越低,容易导致欠拟合