引言
支持向量机(SVM,Support Vector Machine)是一种分类算法,其基本思想是在样本空间中找到一个超平面,在将不同类别的样本分开的前提下,使超平面离距自己最近的样本尽可能远。
如上图所示二维空间中,支持向量机算法的目标就是找到右边的黑色实线所代表的超平面,在成功分类样本的前提下,使得两条虚线之间的间隔最大,以获得最好的泛化能力。落在虚线上的样本点被称为支持向量。
算法原理
超平面模型
我们需要找到一个超平面来划分样本,那么我们可以先用以下线性方程描述超平面:
其中 是法向量,决定了超平面的方向;b是位移项,决定了超平面与原点之间的距离。
点到超平面距离公式
由于算法涉及到超平面与样本点的距离问题,现在我们推导样本空间中任意点 到超平面 之间的距离。
上图中的点 是样本空间中任意点, 与 是超平面上的点, 是超平面的法向量。点 到平面的距离可以转化为点 到平面任意一点 的向量在法向量 上的投影。
根据向量点积的几何定义:
可得出 在 上的投影:
同理我们可以得出向量 在法向量 上的投影为:
又由于点 是平面上点,满足 ,因此上式可改为:
目标函数
如果超平面能够正确地将样本分类,那么决策方程 可写为:
即当 是正例时,样本点在超平面上方;当 是负例是,样本点在超平面下方。
如上方程组仍有些复杂,因此我们可以将 乘以 简化方程:
为了方便后续计算,在这里我们可以通过对 的放缩变换使得方程右边为1:
结合 ,可得:
由于 恒大于0,因此原式 中的绝对值脱掉了。
我们的优化目标是让超平面距离最近的样本点越远越好:
由于 ,因此我们只需考虑:
是一个求解极大值的问题,可以等价于:
至此,我们得到了支持向量机的基本形:
即求出在约束条件 下,目标函数 的极小值点。
带约束条件的极值问题求解
上面的式子 中的目标是很典型的条件极值问题,我们可以采用拉格朗日乘子法来解决。
构造拉格朗日函数:
令 对 和 的偏导为0可得:
将 代入 ,可得:
至此我们得到了 的对偶问题:
再将减号两边对调,把问题转化为求极小值问题,另外别忘了KKT条件:
上式中的 、 都是已知量,代入求得 后,求出 和 即可得到超平面:
手算简单案例
如上图所示,红色的圆代表正例,黑色的叉叉代表负例,现在我们要根据支持向量机的原理手算超平面。
将样本数据代入至 :
根据约束条件:
有:
代入消元可得:
分别对 和 求偏导,令偏导等于0:
解得:
该结果并不满足 中的约束 ,那么解应该在边界上。
首先试试 ,代回 可得:
然后试试 ,代回 可得:
然后分别把以上两组 值代回 ,发现第二组计算出的最终结果最小。
因此可以确定:
与 不为0,因此 与 是在最大间隔边界上的支持向量。
最后我们把 的值带入 求解 :
将支持向量
或者
带入超平面模型,求得
:
因此求得的超平面为:
图像效果如图所示:
红色的实线是超平面,黑色的虚线是最大间隔边界。容易看出,超平面的位置只与支持向量有关,与其他样本点无关。
核函数
之前的例子展现了SVM的一般计算步骤,但是像该例这样线性可分的样本可以看做是特例,更多情况下的样本集是非线性可分的。这时我们就需要用某种变化,将不可分变为可分,类似于下图:
常用核函数:
- 多项式核:
- 高斯核:
- 线性核:
手写SVM算法
实现的有些烂,后面换种思路。。
import numpy as np
import pandas as pd
import sympy as sp
from itertools import combinations
import sys
import matplotlib.pyplot as plt
class SVM:
def __init__(self,pdData):
datas = pdData.values
flag,a=self.step1(datas)
if not flag:
a=self.step2(datas)
self.step3(datas,a)
def step1(self,datas):
m=datas.shape[0]
self.objective_function=sp.symbols('tmp')
constraint=sp.symbols('tmp')
for i in range(1,m+1):
for j in range(1,m+1):
a_i=sp.symbols('a'+str(i))
a_j=sp.symbols('a'+str(j))
y_i=datas[i-1,datas.shape[1]-1]
y_j=datas[j-1,datas.shape[1]-1]
x_i=datas[i-1,0:datas.shape[1]-1]
x_j=datas[j-1,0:datas.shape[1]-1]
self.objective_function=self.objective_function+(1/2)*(a_i*a_j*y_i*y_j*np.dot(x_i.T,x_j))
self.objective_function=self.objective_function-sp.symbols('a'+str(i))
constraint=constraint+sp.symbols('a'+str(i))*datas[i-1,datas.shape[1]-1]
self.objective_function=self.objective_function-sp.symbols('tmp')
constraint=constraint-sp.symbols('tmp')
print('代入样本数据:'+str(self.objective_function))
print('有约束条件:'+str(constraint))
self.a1=sp.solve(constraint,sp.symbols('a1'))[0]
print('a1='+str(self.a1))
self.objective_function=self.objective_function.subs({'a1':self.a1})
print('代入换元:'+str(self.objective_function))
# 对a分别求偏导
da={}
for i in range(m-1):
da_i=sp.diff(self.objective_function,sp.symbols('a'+str(i+2)))
print('对a'+str(i+2)+'求偏导:'+str(da_i))
da[sp.symbols('a'+str(i+2))]=da_i
print('联立求解')
result=sp.solve(list(da.values()),list(da))
if result ==[]:
print('a无解')
return False,[]
result[sp.symbols('a1')]=self.a1.subs(result)
print(result)
# 判断是否符合约束
bol=True
for key in result:
if result[key]<0:
bol=False
break
if not bol:
print('不满足约束a>=0,解应在边界上')
return False,[]
print('a的解满足条件')
return True,result
def step2(self,datas):
m=datas.shape[0]
li=[]
alist=[]
for i in range(1,m):
alist.append(sp.symbols('a'+str(i+1)))
for i in range(1,m-1):
arrs=list(combinations(alist,i))
for arr in arrs:
print('令',end='')
subs={}
for a_i in arr:
print(a_i,end='')
subs[a_i]=0
print('等于0')
tmp=self.objective_function.subs(subs)
print('原式变为:'+str(tmp))
dda={}
for i in range(m-1):
da_i=sp.diff(tmp,sp.symbols('a'+str(i+2)))
print('对a'+str(i+2)+'求偏导:'+str(da_i))
dda[sp.symbols('a'+str(i+2))]=da_i
print('联立求解')
dic=sp.solve(list(dda.values()),list(dda))
if dic ==[]:
print('无解')
continue
dic.update(subs)
dic[sp.symbols('a1')]=self.a1.subs(dic)
if (np.array(list(dic.values()))==0).all():
continue
print(dic)
li.append(dic)
_index=-1
_min_value=sys.maxsize
print('原式:'+str(self.objective_function))
for i in range(len(li)):
print(li[i])
bol = True
for key in li[i]:
if (not type(li[i][key])==float and not type(li[i][key])==int) or li[i][key]<0:
bol=False
break
if not bol:
continue
result=self.objective_function.subs(li[i])
if result<_min_value:
_min_value=result
_index=i
a=li[_index]
print('计算a的结果:'+str(a))
return a
def step3(self,datas,a):
# 求解w
self.W=0
vec_index=-1
for i in range(len(a)):
a_i=a[sp.symbols('a'+str(i+1))]
y_i=datas[i,datas.shape[1]-1]
x_i=datas[i,0:datas.shape[1]-1]
self.W=self.W+a_i*y_i*x_i.T
if not a_i == 0 and vec_index == -1:
vec_index=i
# 求解b
x=datas[vec_index,0:datas.shape[1]-1]
y=datas[vec_index,datas.shape[1]-1]
self.b=y-np.dot(self.W.T,x)
print('最优w参数:'+str(self.W))
print('最优b参数:'+str(self.b))
def __model(self,W,b,X):
return np.dot(W,X)+b
def model(self,X):
exp=sp.symbols('tmp')
for i in range(len(self.W)):
exp=exp+self.W[i]*sp.symbols('x'+str(i+1))
exp=exp-sp.symbols('tmp')
exp=exp+self.b
exp=sp.solve(exp,sp.symbols('x2'))[0]
result=[]
for x in X:
result.append(exp.subs({'x1':x}))
return result
def classify(self,X):
result=[]
for item in X:
if self.__model(self.W,self.b,item) >0:
result.append(1)
elif self.__model(self.W,self.b,item) <0:
result.append(-1)
elif self.__model(self.W,self.b,item) ==0:
result.append(0)
return result
数据集:
c1 | c2 | class |
---|---|---|
34.62365962451697 | 78.0246928153624 | -1 |
30.28671076822607 | 43.89499752400101 | -1 |
35.84740876993872 | 72.90219802708364 | -1 |
60.18259938620976 | 86.30855209546826 | 1 |
79.0327360507101 | 75.3443764369103 | 1 |
45.08327747668339 | 56.3163717815305 | -1 |
61.10666453684766 | 96.51142588489624 | 1 |
75.02474556738889 | 46.55401354116538 | 1 |
76.09878670226257 | 87.42056971926803 | 1 |
84.43281996120035 | 43.53339331072109 | 1 |
测试代码:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from svm import SVM
def gaussian(x1,x2,sigma):
return np.exp(-((x1-x2)**2)/(1/2*sigma**2))
if __name__ == "__main__":
pdData=pd.read_csv("LogiReg_data.csv")
p=pdData[pdData['class']==1].values # 正例
n=pdData[pdData['class']==-1].values # 负例
plt.scatter(p[:,0],p[:,1]) # 画出正例散点图
plt.scatter(n[:,0],n[:,1]) # 画出负例散点图
svm=SVM(pdData) # 求解SVM
x=np.arange(0,100,0.01)
plt.plot(x,svm.model(x)) # 画出超平面
plt.show()
运行结果: