二分类与多分类
多分类问题的交叉熵损失函数推导与二分类交叉熵损失函数推导的思想是一致的。可以先参考 二分类问题的交叉熵损失函数推导可参考 二分类问题的交叉熵损失函数推导。
本文参考deeplearning 一书 6.2.2.3 Softmax Units for Multinoulli Output Distributions
好啦,开始步入正题啦~
首先要明确一点的是,为多分类问题和二分类问题设计的模型 在 最后一层直接输出的东西那里有比较大的不同。
假设模型的输出层是全连接层,我们可以画出二分类问题和多分类问题经过最后一层直接输出的内容是什么,如下图:(图中的多分类是三分类问题)
图中一个蓝色的圆圈代表一个实数值。对于二分类来讲,只需要输出一个实数值,再经过sigmoid函数转化之后就可表示模型预测为 1 的概率。
对于N分类问题,需要N个输出值,第N个输出值经过softmax转化之后可表示模型预测为 第N类 的概率。
经softmax函数转化之后可表示概率
下面推导为什么经softmax函数转化之后可以表示概率,现在开始推导啦~
第一步: 假设输出的那三(拓展到N)个蓝色圆圈组成一个向量 z。z 的表达式可以写为 。其中z是一个向量。
第二步:构造,(注意此时的还不是一个真实的概率分布)。使得,也就是。
为什么要 ?用 不行嘛?之所以用 是因为求交叉熵损失函数的时候,要放到 里边 ,这样在用梯度求解的时候不容易 saturate (原文是这个单词,我现在还不知道翻译成什么比较好。)这是梯度求解时候的问题,现在可不用理解。
第三步:获得概率分布。对 z 进行指数化exponentiate和归一化normalize。可以得到:
。
第四步:至此,由 z 转化成了 模型预测为各个 的概率,这个过程叫softmax函数,也就是我们熟悉的这个式子:
第五步:总结。对 z 进行exponentiate和normalize,就可以得到模型预测为各个 的概率。
经过softmax转化后,交叉熵损失函数推导
经过softmax转化之后,求交叉熵损失函数就比较容易啦。
交叉熵损失函数的公式是 。
(这里的 y 是这个样本的真实标签,假设有1,2,3类, 对应 、、 ,真实标签是 的话, 就是)。
将 代入公式 里,可得交叉熵损失函数:
如果实际的类别对应的是输出值,(假设只有1,2,3类,对应 、、)。那么损失函数就是。
到此我们的推导就结束啦,有不正确的地方欢迎各位大佬留言~