本文参考:深度学习之3——梯度爆炸与梯度消失
梯度消失和梯度爆炸的根源:深度神经网络结构、反向传播算法
目前优化神经网络的方法都是基于反向传播的思想,即根据损失函数计算的误差通过反向传播的方式,指导深度网络权值的更新。
为什么神经网络优化用到梯度下降的优化方法?
深度网络是由许多非线性层(带有激活函数)堆叠而成,每一层非线性层可以视为一个非线性函数 f(x) ,因此整个深度网络可以视为一个复合的非线性多元函数:
我们的目的是希望这个多元函数可以很好的完成输入到输出的映射,假设不同的输入,输出最优解是 g(x) ,那么优化深度网络就是为了寻找到合适的权值,满足损失取得最小值点,比如简单的损失函数平方差(MSE):
在数学中寻找最小值问题,采用梯度下降的方法再合适不过了。
1.深层网络的结构;
在反向传播对激活函数进行求导的时候,如果此部分大于1,那么随着层数的增加,求出的梯度的更新将以指数形式增加,发生梯度爆炸。如果此部分小于1,那么随着层数的增加求出的梯度更新的信息会以指数形式衰减,如果某一次等于0,那么相乘就全为0这就是梯度消失,发生梯度消失。梯度消失/梯度爆炸:二者问题问题都是因为网络太深,网络权值更新不稳定造成的。本质上是因为梯度反向传播中的连乘效应(小于1连续相乘多次)。
反向传播更新参数w的更新过程,对于参数。梯度消失时,越靠近输入层的参数w越是几乎纹丝不动;梯度爆炸时,越是靠近输入层的参数w越是上蹿下跳。稳定和突变都不是我们想要的,我们想要的是参数w朝着误差减小的方向稳步变化。
从深层网络角度来说,不同的层学习的速度差异很大,表现为网络中靠近输出的层学习的情况很好,靠近输入层的学习很慢,甚至即使训练了很久,前基层的权值和刚开始初始化的值差不多,因此梯度消失和梯度爆炸的根本原因在于反向传播算法的不足。
2.激活函数的角度:
如果激活函数的选择不合适,比如使用Sigmoid,梯度消失会很明显。
下图为sigmoid函数(又叫Logistic)和Tanh函数的导数,Logistic导数最大的时候也只有0.25,其余时候远小于0.25,因此如果每层的激活函数都为sigmoid函数的话,很容易导致梯度消失问题,Tanh函数的导数峰值是1那也仅仅在取值为0的时候,其余时候都是小于1,因此通过链式求导之后,Tanh函数也很容易导致梯度消失。
3.梯度消失、爆炸的解决方案
1.预训练和微调
预训练:无监督逐层训练,每次训练一层隐藏点,训练时将上一层隐节点的输出作为输入,而本层隐节点的输出作为下一层隐节点的输入。称为逐层预训练。在预训练完成后还要对整个网络进行微调。
2.梯度剪切、正则
梯度剪切又叫梯度截断,是防止梯度爆炸的一种方法,其思想是设置一个梯度剪切阈值,更新梯度的时候,如果梯度超过这个阈值,那就将其强制限制在这个范围之内。
3.relu、leakyrelu、elu等激活函数
从relu的函数特性我们知道,在小于0的时候梯度为0,大于0的时候梯度恒为1,那么此时就不会再存在梯度消失和梯度爆炸的问题了,因为每层的网络得到的梯度更新速度都一样。
relu的主要贡献:
- 解决了梯度消失和梯度爆炸的问题
- 计算方便计算速度快(梯度恒定为0或1)
- 加速了网络的训练
缺陷:
- 由于负数部分恒为0,导致一些神经元无法激活(可通过设置小学习率部分解决)
- 输出并不是零中心化的
尽管relu也有缺点,但是仍然是目前使用最多的激活函数
leakyrelu就是为了解决relu的0区间带来的影响,在小于0的区间,梯度为很小的数(非零),leakyrelu解决了0区间带来的影响,而且包含了relu所有优点。
elu激活函数也是为了解决relu的0区间带来的影响,但是elu相对于leakyrelu来说,计算要耗时一些(有e的幂计算)。
4.Batch Normalization(批规范化)
详细阅读文章:
关于BN,这里举个简单的例子来说明:
正向传播中:,反向过程中:
反向传播过程中 w 的大小影响了梯度消失和爆炸,BN就是通过对每一层的输出规范为均值和方差一致的方法,消除了 w 带来的放大缩小的影响,进而解决梯度消失和爆炸的问题,或者可以理解为BN将输出从饱和区拉到了非饱和区(比如Sigmoid函数)。
5.残差结构
详细阅读文章:
自从残差网络提出后,几乎所有的深度网络都离不开残差的身影,相比较之前的几层,几十层的深度网络,残差网络可以很轻松的构建几百层而不用担心梯度消失过快的问题,原因在于残差的捷径,我们可以抽象的理解为残差网络就是把里面的连乘变成了连加,这样就解决了某一个权值求导为0的情况的影响,即:局部不会影响全局。
6.LSTM(长短期记忆网络)
在RNN(循环神经网络)网络结构中,由于使用sigmoid或者Tanh函数,所以很容易导致梯度消失的问题,即在相隔很远的时刻时,前者对后者的影响几乎不存在了,LSTM的机制正是为了解决这种长期依赖问题。具体的关于RNN和LSTM的解释,后续我学到了再写一篇文章。