一、为什么要Normalization?
ICS问题:由于数据尺度/分布异常,导致训练困难
由上图中的D(H1)=n*D(x)*D(W)=1可知,第一个隐藏层的输出等于上一层的输入的方差和二者之间权重的方差的连乘,所以如果数据的方差发生微小变化,那么随着网络的加深,这个变化会越来越明显,从而导致梯度消失或梯度爆炸
所以数据尺度或分布发生变化,则会导致模型难以训练
进行Normalization就能控制和约束数据的尺度,使得数据在一个良好的尺度和分布范围内,从而有助于模型的训练
二、常见的Normalization方法
2.1 Layer Normalization( LN)
说明:
因为BN是从特征数的维度出发,按照batch计算均值和方差,而在变长的网络中,如RNN,没有办法按照BN的计算方式来计算均值和方差
如上图中,不同的batch中的数据对应的特征数不同,所以没有办法按照batch计算均值和方差
2.1.1 nn.LayerNorm
nn.LayerNorm(normalized_shape,
eps=1e-05,
elementwise_affine=True)
主要参数:
- normalized_shape:该层特征形状,
- eps:分母修正项
- elementwise_affine:是否需要affine transform
注意:
normalized_shape参数输入的特征形状要求是C*H*W,而特征图的shape是B*C*H*W,所以输入时要注意处理——feature_maps_bs.size()[1:]
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1) # 设置随机种子
# ======================================== nn.layer norm
flag = 1
# flag = 0
if flag:
batch_size = 8
num_features = 6
features_shape = (3, 4)
feature_map = torch.ones(features_shape) # 2D
feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0) # 3D
feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0) # 4D
# feature_maps_bs shape is [8, 6, 3, 4], B * C * H * W
# ln = nn.LayerNorm(feature_maps_bs.size()[1:], elementwise_affine=True)
# ln = nn.LayerNorm(feature_maps_bs.size()[1:], elementwise_affine=False)
# ln = nn.LayerNorm([6, 3, 4])
ln = nn.LayerNorm([6, 3])
output = ln(feature_maps_bs)
print("Layer Normalization")
print(ln.weight.shape)
print(feature_maps_bs[0, ...])
print(output[0, ...])
注意:这里的weight是对应Normalization公式里的γ和β,由weight的size可以看到,LN的确是按一个batch逐元素计算的
当elementwise_affine设置为false,ln.weight就没有,所以说明该参数就是对应Normalization公式里的γ和β
LN可以通过normalized_shape参数,使得在指定的shape上进行Normalization
注意:指定的shape,必须是按照BCHW,从后往前的连续形式输入shape,如果不连续,或者不是从W开始的,就会报错
2.2 Instance Normalization( IN)
说明:
该方法起因于图像领域,因为一个batch图像数据它们有不同的风格和内容,所以不能将其混为一谈直接计算均值和方差,所以就提出了逐通道的计算均值和方差
这里的逐通道的计算均值和方差是按照intance的,也就是按照每一个特征图的通道计算
2.2.1 nn.InstanceNorm
nn.InstanceNorm2d(num_features,
eps=1e-05,
momentum=0.1,
affine=False,
track_running_stats=False)
主要参数:
- num_features:一个样本特征数量(最重要)
- eps:分母修正项
- momentum:指数加权平均估计当前mean/var
- affine:是否需要affine transform
- track_running_stats:是训练状态,还是测试状态
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1) # 设置随机种子
# ======================================== nn.instance norm 2d
flag = 1
# flag = 0
if flag:
batch_size = 3
num_features = 3
momentum = 0.3
features_shape = (2, 2)
feature_map = torch.ones(features_shape) # 2D
feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0) # 3D
feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0) # 4D
print("Instance Normalization")
print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))
instance_n = nn.InstanceNorm2d(num_features=num_features, momentum=momentum)
for i in range(1):
outputs = instance_n(feature_maps_bs)
print(outputs)
# print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
# print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))
# print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
# print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))
运行结果:
由上图可知,output结果都是0,因为输入数据的每一个通道的数据都是一样的,每个通道的数据的均值和该通道上的数是一样的,所以求均值和方差的结果都是0,由此就可知,IN确实是以实例出发,按照通道计算均值和方差
2.3 Group Normalization( GN)
说明:
样本数越多,估计的均值和方差就会越准,而在一些模型数据量特别大的情况下,GPU只能容纳两个甚至一个batch的数据,因此在这种情况下均值和方差的估计值不准,导致BN方法失效
2.3.1 nn.GroupNorm
nn.GroupNorm(num_groups,
num_channels,
eps=1e-05,
affine=True)
主要参数:
- num_groups:分组数
- num_channels:通道数(特征数)
- eps:分母修正项
- affine:是否需要affine transform
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed
set_seed(1) # 设置随机种子
# ======================================== nn.grop norm
flag = 1
# flag = 0
if flag:
batch_size = 2
num_features = 4
num_groups = 3 # 3 Expected number of channels in input to be divisible by num_groups
features_shape = (2, 2)
feature_map = torch.ones(features_shape) # 2D
feature_maps = torch.stack([feature_map * (i + 1) for i in range(num_features)], dim=0) # 3D
feature_maps_bs = torch.stack([feature_maps * (i + 1) for i in range(batch_size)], dim=0) # 4D
gn = nn.GroupNorm(num_groups, num_features)
outputs = gn(feature_maps_bs)
print("Group Normalization")
print(gn.weight.shape)
print(outputs[0])
注意到这里的gn.weight.shape等于4,与num_features相同,所以说明了GN的γ和β是逐通道计算的
注意:设置的num_groups必须是要能被通道数整除的,否则会报错,通常会设置为2的n次幂
三、Normalization小结
BN:按照batch size的方向计算均值和方差,而且往往是在batch数较多的情况使用
LN:按照整个网络层计算均值和方差
IN:以每个feature map出发,按照通道来计算均值和方差
GN:对feature map进行分组,按照一个group来计算均值和方差