一:理论部分
给定一个样本集,每个样本点有两个维度值(X1,X2)和一个类别值,类别只有两类,我们以0和1代表。数据如下所示:
样本 |
X1 |
X2 |
类别 |
1 |
-1.4 |
4.7 |
1 |
2 |
-2.5 |
6.9 |
0 |
... |
... |
... |
... |
机器学习的任务是找一个函数,给定一个数据两个维度的值,该函数能够预测其属于类别1的概率。
假设这个函数的模样如下:
h(x) =sigmoid(z)
z = w0 +w1*X1+w2*X2
问题转化成了,根据现有的样本数据,找出最佳的参数w(w0,w1,w2)的值
为进一步简化问题,我们假设样本集只有上表中的两个。
假设现在手上已经有一个wt,也就是有了一个函数h(x),那么我们可以把样本1和样本2的数据代进去,看看这个函数的预测效果如何,假设样本1的预测值是p1 = 0.8,样本2的预测值是:p2 = 0.4。
函数在样本1上犯的错误为e1=(1-0.8)= 0.2,在样本2上犯的错误为e2=(0-0.4)= -0.4,总的错误E为-0.20(e1+e2)。如下表所示:
样本 |
X1 |
X2 |
类别 |
预测值 |
error |
1 |
-1.4 |
4.7 |
1 |
0.8 |
0.2 |
2 |
-2.5 |
6.9 |
0 |
0.4 |
-0.4 |
... |
... |
... |
... |
... |
... |
现在我们要改进wt的值,使得函数在样本1和2上犯的总错误E减小。
将wt的改进拆开来,就是分别改进它的三个分量的值,我们以w1为例。
对于样本1:
X1*e1=-1.4*0.2= -0.28
-0.28告诉我们什么呢?它告诉我们,样本1的X1和e1是异号的,减小w1的值,能够减小函数在样本1上犯的错误。为什么呢?
w1减小,则X1*w1增大(因为样本1的X1是负的),进而 z = w0 +w1*X1+w2*X2增大,又由于sigmoid函数是单调递增的,则h(x)会增大。当前的h(x)是0.8,增大的话就是在向1靠近,也就是减小了在样本1上犯的错。
对于样本2:
X1*e2=-2.5*-0.4= 1
1告诉我们,样本2的X1和e2是同号的,增大w1的值,能够减小函数在样本2上犯的错误。为什么呢?
w1增大,则X1*w1减小,进而 z = w0 +w1*X1+w2*X2减小,又由于sigmoid函数是单调递增的,则h(x)会减小。当前的h(x)是0.4,减小的话就是在向0靠近,也就是减小了在样本2上犯的错。
现在的问题就是这样的,样本1说,要减小w1的值,这样函数对我的判断就更准确了,样本2说,要增大w1的值,这样函数对我的判断就更准确了。
显然,样本1和样本2都只从自己的角度出发,对改进w1提出了各自不同意见,我们要综合它们的意见,以决定是增大w1还是减小w1,如下:
-0.28+1 = 0.72
最后的结果0.72是正的,说明,增大w1对函数的总体表现更有利。就是说,增大w1后,虽然在样本1上犯的错误会稍稍增大,但在样本2上犯的错误会大大减小,一个是稍稍增大,一个是大大减小,为了函数总体表现,肯定是增大w1的值啦。
那么具体增加多大呢?我们可以用一个专门的参数alpha来控制。
二 Python代码(核心部分)
from numpy import *
def gradientAscent(dataMat,labelMat):
dataMat=mat(dataMat)
m,n=shape(dataMat)
labelMat=mat(labelMat).T
#假设weight =1
weights=ones((n,1))
alpha=0.00001 #学习率
num=500000 #循环次数
for k in range(num):
#整个数据集全部运算:num*m
#计算 z 值
z=dataMat*weights
y=sigmoid(z)
error=labelMat-y
#更新weights
weights=weights+alpha*dataMat.T * error
return weights
三.画出图像
def showPlot(weights):
import matplotlib.pyplot as plt
dataMat,labelMat=loadDataSet()
dataArr=array(dataMat)
n=shape(dataArr)[0]
#正样本
xcord1=[]
ycord1=[]
xcord2=[]
ycord2=[]
#循环数据,存到正负样本中
for i in range(n):
if int(labelMat[i])==1:
xcord1.append(dataArr[i,1])
ycord1.append(dataArr[i,2])
else:
xcord2.append(dataArr[i,1])
ycord2.append(dataArr[i,2])
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(xcord1,ycord1,c='red')
ax.scatter(xcord2,ycord2,c='green')
x=arange(-5,5,0.1)
y=arange(-5,5,0.1)
y=(-weights[0]-weights[1]*x)/weights[2]
y=y.T
ax.plot(x,y)
plt.show()
四.其余代码
def loadDataSet():
dataMat=[]
labelMat=[]
fr=open('dataset/testSet.txt')
for line in fr.readlines():
array=line.strip().split()
#截距 特征
dataMat.append([1.0,float(array[0]),float(array[1])])
labelMat.append(int(array[2]))
return dataMat,labelMat
#方式一的函数图
dataMat,labelMat=loadDataSet()
weights1=gradientAscent(dataMat,labelMat)
showPlot(weights1)
五.数据集
-0.017612 14.053064 0
-1.395634 4.662541 1
-0.752157 6.538620 0
-1.322371 7.152853 0
0.423363 11.054677 0
0.406704 7.067335 1
0.667394 12.741452 0
-2.460150 6.866805 1
0.569411 9.548755 0
-0.026632 10.427743 0
0.850433 6.920334 1
1.347183 13.175500 0
1.176813 3.167020 1
-1.781871 9.097953 0
-0.566606 5.749003 1
0.931635 1.589505 1
-0.024205 6.151823 1
-0.036453 2.690988 1
-0.196949 0.444165 1
1.014459 5.754399 1
1.985298 3.230619 1
-1.693453 -0.557540 1
-0.576525 11.778922 0
-0.346811 -1.678730 1
-2.124484 2.672471 1
1.217916 9.597015 0
-0.733928 9.098687 0
-3.642001 -1.618087 1
0.315985 3.523953 1
1.416614 9.619232 0
-0.386323 3.989286 1
0.556921 8.294984 1
1.224863 11.587360 0
-1.347803 -2.406051 1
1.196604 4.951851 1
0.275221 9.543647 0
0.470575 9.332488 0
-1.889567 9.542662 0
-1.527893 12.150579 0
-1.185247 11.309318 0
-0.445678 3.297303 1
1.042222 6.105155 1
-0.618787 10.320986 0
1.152083 0.548467 1
0.828534 2.676045 1
-1.237728 10.549033 0
-0.683565 -2.166125 1
0.229456 5.921938 1
-0.959885 11.555336 0
0.492911 10.993324 0
0.184992 8.721488 0
-0.355715 10.325976 0
-0.397822 8.058397 0
0.824839 13.730343 0
1.507278 5.027866 1
0.099671 6.835839 1
-0.344008 10.717485 0
1.785928 7.718645 1
-0.918801 11.560217 0
-0.364009 4.747300 1
-0.841722 4.119083 1
0.490426 1.960539 1
-0.007194 9.075792 0
0.356107 12.447863 0
0.342578 12.281162 0
-0.810823 -1.466018 1
2.530777 6.476801 1
1.296683 11.607559 0
0.475487 12.040035 0
-0.783277 11.009725 0
0.074798 11.023650 0
-1.337472 0.468339 1
-0.102781 13.763651 0
-0.147324 2.874846 1
0.518389 9.887035 0
1.015399 7.571882 0
-1.658086 -0.027255 1
1.319944 2.171228 1
2.056216 5.019981 1
-0.851633 4.375691 1
-1.510047 6.061992 0
-1.076637 -3.181888 1
1.821096 10.283990 0
3.010150 8.401766 1
-1.099458 1.688274 1
-0.834872 -1.733869 1
-0.846637 3.849075 1
1.400102 12.628781 0
1.752842 5.468166 1
0.078557 0.059736 1
0.089392 -0.715300 1
1.825662 12.693808 0
0.197445 9.744638 0
0.126117 0.922311 1
-0.679797 1.220530 1
0.677983 2.556666 1
0.761349 10.693862 0
-2.168791 0.143632 1
1.388610 9.341997 0
0.317029 14.739025 0