学习BatchNormalization

BatchNormalization

深度学习的理论基础

统计学习领域有个很重要的假设:独立同分布假设,即假设训练数据和测试数据是满足相同分布的。而神经网络的训练实际上就是在拟合训练数据的分布。如果不满足独立同分布假设,那么训练得到的模型只适用于当前数据的分布模式,泛化能力肯定不好。

机器学习方法在运用之前,会先对数据做归一化处理,这样的话模型会认为训练数据和测试数据满足相同分布(即均值为0,方差为1的标准正态),这样一来模型的泛化能力会得到提高(物理世界中大多部分都符合正态分布)。其次,如果不做归一化,使用mini-batch梯度下降法训练的时候,每批训练数据的分布不相同,那么网络就要在每次迭代的时候去适应不同的分布,这样会大大降低网络的训练速度。综合以上两点,所以需要对数据做归一化预处理。如果是mini-batch梯度下降法,每个batch都可以计算出均值和方差,最终记录的均值和方差是所有batches均值和方差的期望,这就是我们今天的主角:BatchNormalization。当然也有其它更复杂的记录方式,如pytorch使用的滑动平均。

网络模型的发展

深度学习,顾名思义,深度是提升效果的重点之一。从VGG到Alex,再到ResNet,DenseNet等等,网络模型设计的越来越深,体量越来越大,模型的拟合能力也越来越强,解决复杂任务的能力也逐渐提升。这几年深度学习在图像分类、目标检测、语义分割等任务上都取得了SOTA,并且各种方法层出不穷,已经霸占了各大顶会。

深度学习训练法则的弊端

深度学习的整体架构,无非就是输入数据,网络模型前向计算,输出预测,送入loss与目标值计算损失,根据损失利用链式求导法则进行反向传播,更新参数,最终模型学习到数据分布,理想情况下在测试集上也表现出不错的泛化能力,至此任务完成。当然以上是监督式的学习方法,还要半监督、无监督的学习方法。
随着网络层数的不断加深,这种链式求导的反向传播法则会暴露出一定的弊端,比如梯度爆炸和梯度消失,因为链式求导法则是级乘形式的。这里我们以全连接层为例,如图1所示。
在这里插入图片描述
                     图1. 全连接层神经网络
                     
如果我们想要更新参数 θ \theta θ,那么需要从loss一路传播到 θ \theta θ,因为
 
                     y ′ = x w 1 w 2 w 3 w 4 w 5 y^{'}=xw_1w_2w_3w_4w_5 y=xw1w2w3w4w5
                     
其中 h 1 , h 2 , h 3 , h 4 h_1,h_2,h_3,h_4 h1,h2,h3,h4是中间层的值。根据链式求导法则,如图2所示我们需要经过层层求导,才能得到loss对 θ \theta θ的偏微分。
在这里插入图片描述
                       图2. 链式求导

我们可以看到每一个环节之间都是以级乘的形式进行连接的,当每个环节的值过大或者过小时,传递到最终的位置,可能就会得到一个非常小(梯度消失)或者非常大(梯度爆炸)的值,这对于模型学习和参数更新都会造成致命的伤害,导致模型无法收敛或学习不到有用的信息。
除了反向传播,在正向传播的过程中,也会存在问题。在每一层隐藏层之后,我们都会使用一种非线性的激活函数(sigmoid,relu…)来增加神经网络的非线性拟合能力,因为激活函数的特性曲线,只有特定范围内的输入才能激活当前神经元。例如sigmoid,如图3所示。
在这里插入图片描述
               图3. Sigmoid激活函数曲线
可以看出,sigmoid激活函数只对中间一段范围内的输入反应较为灵敏,两边的发散范围不太灵敏。那么我们为了提高神经元的灵敏程度,就要控制一下输入数据的分布范围,怎么控制呢?BatchNormalization就是方法之一。

解决方法—BatchNormalization

BatchNormalization(BN)层是放在特征提取层和激活函数层之间的,即在送入激活函数之前,对数据进行批量归一化,调整输入数据的分布,送入激活函数进行激活。BN的算法原理如图4所示。
在这里插入图片描述
              图4. BatchNormalization的算法原理
BN层首先将输入的数据进行去均值去方差(白化)操作,使其服从于均值和方差都为零的高斯分布,然后进行一个可学习的线性变化,这样做的目的是为了让模型自己去拟合数据原先的分布,不至于使得数据经过白化操作后失去了原先的分布。
在卷积神经网络里,BN层对于输入的数据(Batch,channel,height,width),在整个batch内计算每个通道(batch, height,width)的均值和方差,然后进行白化操作,最后进行线性变换。图5为经过BN之后的数据分布效果图。即,每一个batch有channel个均值和方差,新参数 γ , β \gamma,\beta γ,β。BN算法只关注每一个channel,没有关注channel之间的关系,可以说只关注了spacial信息没有关注channel信息。
在这里插入图片描述
               图5. BN之后的数据分布

BatchNormalization注意事项

我们训练时使用一个batch的数据,因此可以计算batch内多个样本的均值和方差,但是预测时只有一个样本数据,所以均值方差都是0。这时BN层什么也不干,原封不动的输出,这肯定会用问题,因为模型训练时都是进过处理的,但是测试时又没有,那么训练与测试不一致,结果肯定不对。

解决的方法是使用训练的所有数据,也就是所谓的数量上的统计。原文中使用的就是这种方法,不过这需要训练完成之后在多出一个步骤。另外一种常见的办法就是基于momentum的指数衰减,这种方法就是我们下面作业要完成的算法。
公式如下
running_mean = momentum * running_mean + (1 - momentum) * sample_mean
running_var = momentum * running_var + (1 - momentum) * sample_var
这类似于低通滤波的原理,通过滑动平均所有batch的均值和方差,得到最接近于整体数据的均值和方差应用于测试阶段。

Pytorch里的BatchNormalization实现

torch.nn.BatchNorm2d(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)

参数说明:

  • num_features: 输入数据的通道数(因为BN是按照通道计算均值和方差的)
  • eps: 用来防止归一化时分母为零的情况
  • momentum: 滑动平均的参数,用来计算running_mean和running_var
  • affine:是否进行缩放操作
  • track_running_stats: 是否记录训练阶段的均值和方差,即running_mean和running_var

Batchnormalization的作用

  • BN层在激活函数层之前,将输入数据归一化到正态分布下,这样在反向求导的时候模型可以工作在激活函数的梯度敏感区域,从而提高模型的学习能力;
  • 归一化之后,模型就不用每次面对未知的分布去做拟合,提高模型的训练速度;
  • BN层中可学习的缩放操作可以看作引入了随机的小噪声,还可以在一定程度上提升模型的泛化性能;(也有可能是因为文章开头所说的BN将输入归一化到相同分布下,满足统计学习独立同分布的假设)。

猜你喜欢

转载自blog.csdn.net/Just_do_myself/article/details/124040299