什么是softmax
softmax不同于sigmoid函数,用softmax后能够将有正有负的输出化为和为1的正数输出,这些输出相互影响,可以认为是概率分布,这就给很多问题的求解提供了便利。softmax函数在神经网络常常当做多分类神经网络输出的激活函数。softmax是将每个神经元输出通通进行e指数变换,并分别除以这些变换后结果的和,从而得到0~1之间的分值。softmax公式如下:
其中,
是第j个输出神经元激励。
是softmax层前一层第j个神经元的输出。
softmax具体过程可以参考下图,其中的softmax层输出为y(j),softmax层输入为z(j)。注意所有激励输出和为1
很明显,y1,y2,y3的和就为1,是不是很像概率问题。softmax常常用来作为多分类问题分类器的输出。
关于softmax的解释,这里有些阅读资源比较有价值:
英文版:http://neuralnetworksanddeeplearning.com/chap3.html
翻译版:https://hit-scir.gitbooks.io/neural-networks-and-deep-learning-zh_cn/content/chap3/c3s4.html
softmax的损失函数
softmax损失函数如下:
特别注意,一个样本的输出损失就是这么个公式计算得到的,没有累加和。
是正确分类的神经元输出,比如这要分类动物。样本1的实际图片是狗子,但是呢分类器不行输出的的三个神经元分数分别为0.1 、0.2、0.7,他们分别代表 小猫、狗子和小猪。显然小猪分数最高,但是呢我们要盯着狗子神经元的输出看,纳尼才0.2。这个时候这个样本计算得到的损失就是 -ln(0.2)= 1.609。我们训练的目的就是通过调整参数来降低这个loss。
有的地方说softmax的损失函数是交叉熵损失,有的地方用log似然函数,实际上在多分类问题中这些公式是等价相通的。我在另一篇博客比较具体的说明了这个问题,传送门:
https://blog.csdn.net/weixin_39704651/article/details/97392322
softmax反向传播推导以及代码实现
在反向传播中,我们经常用链式法则来串联整个反向传播的过程从而计算参数的梯度值。所以这里单独研究softmax层。我们都知道softmax的公式,根据高中知识也很容易知道它对于输入的求导是多少。温习下softmax公式:
我们假设样本的正确分类在第 i 个神经元。当
时
当 i 不等于 j 时
我们顺着这个推导再前进一点点,当
时
当 i 不等于 j 时
是不是非常简洁,这也是softmax交叉熵(或者说log似然)损失函数结合的一大好处,在反向传播的时候非常容易计算。笔者做cs231n的作业的时候,被这部分内容饶了很久,下面配合着代码对softmax做简要的分析
softmax代码实现如下;
def softmax(x):
x -= np.max(x)
a = np.exp(x)
b = a / np.sum(a, axis=1, keepdims = True)
return b
因为在跑代码的过程中发现经常会出现参数为NAN的情况,cs231n笔记解释是因为指数使得输出非常大,大数值可能导致计算不稳定,也就是计算机不行,数要小些,效果还要一样,因此可以用归一化技巧。即在分式的分子和分母都乘以一个常数C,可以得到数学上等价的公式:
其中
的值可自由选择,不会影响计算结果,通过使用这个技巧可以提高计算中的数值稳定性。通常将
设为
。该技巧简单地说,就是应该将向量
中的数值进行平移,使得最大值为0。也就是上述 x -= np.max(x)
接下里看softmax的loss和反向传播代码,这里给出cs231n作业2第一部分的代码
def softmax_loss(x, y):
shifted_logits = x - np.max(x, axis=1, keepdims=True) # 确保数值稳定性,输入减去最大值
Z = np.sum(np.exp(shifted_logits), axis=1, keepdims=True ) # 求softmax公式分母部分
log_probs = shifted_logits - np.log(Z) # 见下文解释1
loss = -np.sum(log_probs[np.arange(N), y]) / N # 从所有神经元的softmax输出中找到正确输出的神经元的分数求loss
probs = np.exp(log_probs) # 见下文解释2,即softmax输出分值
N = x.shape[0]
dx = probs.copy()
dx[np.arange(N), y] -= 1 # 见下文解释3
dx /= N
return loss, dx
解释1:注意第3行代码,是一个等价公式
解释2: 由上面的公式,当然
。
解释3:上文已经推导了
对
的偏导数,对于
,即对应 dx[np.arange(N), y] -= 1
望与童鞋们一起进步
以上