γ、β
存在问题
在对输入做完标准化后,可能会出现以下情况:网络中间某一层学习到特征数据本身分布在sigmoid激活函数的两侧,标准化会强制把输入的均值限制为0、标准差限制为1,这样就把数据变换成分布在sigmoid激活函数的中间部分,即破坏了网络中间某一层所学习到的特征分布。
解决方法
变换重构,引入可学习参数γ、β,所有γ和所有β分别初始化为1和0,每一层中的每一个神经元都有自己的γ和β。前向传播时,要使用每个神经元自己的γ和β。在反向传播时,会计算所有γ和所有β的梯度,即dγ和dβ。在反向传播结束后,每一个γ和β就可以分别通过优化器(如SGD、Adam等)与对应的梯度dγ和dβ进行更新了。
前向传播
import numpy as np
def batchnorm_forward(x, gamma, beta, bn_param):
r"""
args:
- x: Data of shape (N, D)
- gamma: Scale parameter of shape (D,)
- beta: Shift paremeter of shape (D,)
- bn_param: Dictionary with the following keys:
Returns:
- out: of shape (N, D)
- cache: A tuple of values needed in the backward pass
"""
mode = bn_param['mode']
eps = bn_param.get('eps', 1e-5)
momentum = bn_param.get('momentum', 0.9)
N, D = x.shape
running_mean = bn_param.get('running_mean', np.zeros(D, dtype=x.dtype))
running_var = bn_param.get('running_var', np.zeros(D, dtype=x.dtype))
out, cache = None, None
if mode == 'train':
sample_mean = np.mean(x, axis=0)
sample_var = np.var(x, axis=0)
out_ = (x - sample_mean) / np.sqrt(sample_var + eps)
running_mean = momentum * running_mean + (1 - momentum) * sample_mean
running_var = momentum * running_var + (1 - momentum) * sample_var
out = gamma * out_ + beta
cache = (out_, x, sample_var, sample_mean, eps, gamma, beta)
elif mode == 'test':
scale = gamma / np.sqrt(running_var + eps)
out = x * scale + (beta - running_mean * scale)
bn_param['running_mean'] = running_mean
bn_param['running_var'] = running_var
return out, cache
反向传播
import numpy as np
def batchnorm_backward(dout, cache):
r"""
args:
- dout: Upstream derivatives, of shape (N, D)
- cache: Variable of intermediates from batchnorm_forward.
Returns:
- dx: Gradient with respect to inputs x, of shape (N, D)
- dgamma: Gradient with respect to scale parameter gamma, of shape (D,)
- dbeta: Gradient with respect to shift parameter beta, of shape (D,)
"""
dx, dgamma, dbeta = None, None, None
out_, x, sample_var, sample_mean, eps, gamma, beta = cache
N = x.shape[0]
dout_ = gamma * dout
dvar = np.sum(dout_ * (x - sample_mean) * -0.5 * (sample_var + eps) ** -1.5, axis=0)
dx_ = 1 / np.sqrt(sample_var + eps)
dvar_ = 2 * (x - sample_mean) / N
di = dout_ * dx_ + dvar * dvar_
dmean = -1 * np.sum(di, axis=0)
dmean_ = np.ones_like(x) / N
dx = di + dmean * dmean_
dgamma = np.sum(dout * out_, axis=0)
dbeta = np.sum(dout, axis=0)
return dx, dgamma, dbeta
【参考文献】