1. 分类模型 与 Loss 函数的定义
分类和回归问题,是监督学习的 2 大分支。不同点在于:分类问题的目标变量是离散的,而回归是连续的数值。本文讨论的是分类模型。
分类模型的例子: 根据年龄、性别、年收入等相互独立的特征,预测一个人的政治倾向(民主党、共和党、其他党派)。为了训练模型,必须先定义衡量模型好与坏的标准。在机器学习中,我们使用 loss / cost,即,当前模型与理想模型的差距。训练的目的,就是不断缩小 loss / cost.
这里为什么不能用 classification error?这里用一个实际的例子来看classification error的不足。使用3级训练数据,computed 一栏是预测结果,targets是预期结果。二者的数字,都可以理解为概率。correct 一栏表示预测是否正确。
模型1
computed | targets | correct?
------------------------------------------------
0.3 0.3 0.4 | 0 0 1 (democrat) | yes
0.3 0.4 0.3 | 0 1 0 (republican) | yes
0.1 0.2 0.7 | 1 0 0 (other) | no
item 1 和 2 以非常微弱的优势判断正确,item 3 则完全错误。
classification error = 1/3 = 0.33。
模型2
computed | targets | correct?
-------------------------------------------------
0.1 0.2 0.7 | 0 0 1 (democrat) | yes
0.1 0.7 0.2 | 0 1 0 (republican) | yes
0.3 0.4 0.3 | 1 0 0 (other) | no
item 1 和 2 的判断非常精准,item 3 判错,但比较轻。
classification error = 1/3 = 0.33。
从例子里可以看出:2 个模型的 classification error 相等,但模型 2 要明显优于模型 1。classification error 很难精确描述模型与理想模型之间的距离。
2. Cross-Entropy 的效果对比
Tensorflow 官网的 MNIST for ML Beginners 中 cross entropy 的计算公式是:
根据公式,第一个模型中三项的 cross-entropy 分别是:
-( (ln(0.3)*0) + (ln(0.3)*0) + (ln(0.4)*1) ) = -ln(0.4)
-( (ln(0.3)*0) + (ln(0.4)*1) + (ln(0.3)*0) ) = -ln(0.4)
-( (ln(0.1)*1) + (ln(0.2)*0) + (ln(0.7)*0) ) = -ln(0.1)
因此,第一个模型的ACE (average cross-entropy error)是
-(ln(0.4) + ln(0.4) + ln(0.1)) / 3 = 1.38
类似的,第二个模型的ACE是:
(ln(0.7) + ln(0.7) + ln(0.3)) / 3 = 0.64
因此,ACE的结果显示,模型2优于模型1。cross-entropy更清晰的描述了模型与理想模型的距离。
3. 为什么不用 Mean Squared Error (平方和)
若使用 MSE(mean squared error),第一个模型第一项的 loss 是
(0.3 - 0)^2 + (0.3 - 0)^2 + (0.4 - 1)^2 = 0.09 + 0.09 + 0.36 = 0.54
因此,第一个模型的 loss 是
(0.54 + 0.54 + 1.34) / 3 = 0.81
类似的,第二个模型的 loss 是
(0.14 + 0.14 + 0.74) / 3 = 0.34
看起来也是蛮不错的。为何不用?
这是因为:分类问题,最后必须是 one hot 形式算出各 label 的概率,然后通过 argmax 选出最终的分类。在计算各个 label 概率的时候,用的是 softmax 函数。
如果用MSE计算 loss,得到的曲线是波动的,有很多局部的极值点。即:非凸优化问题(non-convex)。cross entropy计算loss,则依旧是一个凸优化问题,用梯度下降法求解时,凸优化问题有很好的收敛特性。
小结
对于多分类问题,在 softmax 的加工下,神经网络的实际输出值是一个概率向量,如下图所示:
最好的方式是衡量二者间的概率分布差异,这就是 cross entropy,它的设计初衷,就是要衡量两个概率分布间的差异。但是,为什么分类要使用 softmax 函数呢?它能很好的模拟 max 行为,让“强者更强,弱者愈弱”。这个特性,对于分类来说尤为重要,能让学习的效率更高。在上图中,[4, 1, -2]中,4 和 1 的差距看起来并不大,但经过 softmax 渲染之后,前者的分类概率接近96%,而后者仅4%左右。这正是 softmax 的魅力所在。这样,分类标签可以看作是概率分布(由 one-hot 变换而来),神经网络输出(经过 softmax 加工)也是一个概率分布,现在想衡量二者的差异(即 loss),自然用 cross entropy 最好了。