前言
本节将介绍循环神经网络中梯度的计算和存储方法,即 通过时间反向传播(back-propagation through time) 。
正向传播在循环神经网络中比较直观,而通过时间反向传播其实是反向传播在循环神经网络中的具体应用。我们需要将循环神经网络按时间步展开,从而得到模型变量和参数之间的依赖关系,并依据链式法则应用反向传播计算并存储梯度。
1. 定义模型
简单起见,我们考虑一个无偏差项的循环神经网络,且激活函数为恒等映射(
ϕ
(
x
)
=
x
\phi(x)=x
ϕ ( x ) = x )。设时间步
t
t
t 的输入为单样本
x
t
∈
R
d
\boldsymbol{x}_t \in \mathbb{R}^d
x t ∈ R d ,标签为
y
t
y_t
y t ,那么隐藏状态
h
t
∈
R
h
\boldsymbol{h}_t \in \mathbb{R}^h
h t ∈ R h 的计算表达式为
h
t
=
W
h
x
x
t
+
W
h
h
h
t
−
1
,
\boldsymbol{h}_t = \boldsymbol{W}_{hx} \boldsymbol{x}_t + \boldsymbol{W}_{hh} \boldsymbol{h}_{t-1},
h t = W h x x t + W h h h t − 1 ,
其中
W
h
x
∈
R
h
×
d
\boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d}
W h x ∈ R h × d 和
W
h
h
∈
R
h
×
h
\boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h}
W h h ∈ R h × h 是隐藏层权重参数。设输出层权重参数
W
q
h
∈
R
q
×
h
\boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h}
W q h ∈ R q × h ,时间步
t
t
t 的输出层变量
o
t
∈
R
q
\boldsymbol{o}_t \in \mathbb{R}^q
o t ∈ R q 计算为
o
t
=
W
q
h
h
t
.
\boldsymbol{o}_t = \boldsymbol{W}_{qh} \boldsymbol{h}_{t}.
o t = W q h h t .
设时间步
t
t
t 的损失为
ℓ
(
o
t
,
y
t
)
\ell(\boldsymbol{o}_t, y_t)
ℓ ( o t , y t ) 。时间步数为
T
T
T 的损失函数
L
L
L 定义为
L
=
1
T
∑
t
=
1
T
ℓ
(
o
t
,
y
t
)
.
L = \frac{1}{T} \sum_{t=1}^T \ell (\boldsymbol{o}_t, y_t).
L = T 1 t = 1 ∑ T ℓ ( o t , y t ) .
我们将
L
L
L 称为有关给定时间步的数据样本的目标函数,并在本节后续讨论中简称为目标函数。
2. 模型计算图
为了可视化循环神经网络中模型变量和参数在计算中的依赖关系,我们可以绘制模型计算图,如图6.3所示。例如,时间步3的隐藏状态
h
3
\boldsymbol{h}_3
h 3 的计算依赖模型参数
W
h
x
\boldsymbol{W}_{hx}
W h x 、
W
h
h
\boldsymbol{W}_{hh}
W h h 、上一时间步隐藏状态
h
2
\boldsymbol{h}_2
h 2 以及当前时间步输入
x
3
\boldsymbol{x}_3
x 3 。
3. 方法
刚刚提到,图6.3中的模型的参数是
W
h
x
\boldsymbol{W}_{hx}
W h x ,
W
h
h
\boldsymbol{W}_{hh}
W h h 和
W
q
h
\boldsymbol{W}_{qh}
W q h 。与3.14节(正向传播、反向传播和计算图)中的类似,训练模型通常需要模型参数的梯度
∂
L
/
∂
W
h
x
\partial L/\partial \boldsymbol{W}_{hx}
∂ L / ∂ W h x 、
∂
L
/
∂
W
h
h
\partial L/\partial \boldsymbol{W}_{hh}
∂ L / ∂ W h h 和
∂
L
/
∂
W
q
h
\partial L/\partial \boldsymbol{W}_{qh}
∂ L / ∂ W q h 。 根据图6.3中的依赖关系,我们可以按照其中箭头所指的反方向依次计算并存储梯度。为了表述方便,我们采用运算符prod表达链式法则。
首先,目标函数有关各时间步输出层变量的梯度
∂
L
/
∂
o
t
∈
R
q
\partial L/\partial \boldsymbol{o}_t \in \mathbb{R}^q
∂ L / ∂ o t ∈ R q 很容易计算:
∂
L
∂
o
t
=
∂
ℓ
(
o
t
,
y
t
)
T
⋅
∂
o
t
.
\frac{\partial L}{\partial \boldsymbol{o}_t} = \frac{\partial \ell (\boldsymbol{o}_t, y_t)}{T \cdot \partial \boldsymbol{o}_t}.
∂ o t ∂ L = T ⋅ ∂ o t ∂ ℓ ( o t , y t ) .
下面,我们可以计算目标函数有关模型参数
W
q
h
\boldsymbol{W}_{qh}
W q h 的梯度
∂
L
/
∂
W
q
h
∈
R
q
×
h
\partial L/\partial \boldsymbol{W}_{qh} \in \mathbb{R}^{q \times h}
∂ L / ∂ W q h ∈ R q × h 。根据图6.3,
L
L
L 通过
o
1
,
…
,
o
T
\boldsymbol{o}_1, \ldots, \boldsymbol{o}_T
o 1 , … , o T 依赖
W
q
h
\boldsymbol{W}_{qh}
W q h 。依据链式法则,
∂
L
∂
W
q
h
=
∑
t
=
1
T
prod
(
∂
L
∂
o
t
,
∂
o
t
∂
W
q
h
)
=
∑
t
=
1
T
∂
L
∂
o
t
h
t
⊤
.
\frac{\partial L}{\partial \boldsymbol{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{o}_t} \boldsymbol{h}_t^\top.
∂ W q h ∂ L = t = 1 ∑ T prod ( ∂ o t ∂ L , ∂ W q h ∂ o t ) = t = 1 ∑ T ∂ o t ∂ L h t ⊤ .
其次,我们注意到隐藏状态之间也存在依赖关系。 在图6.3中,
L
L
L 只通过
o
T
\boldsymbol{o}_T
o T 依赖最终时间步
T
T
T 的隐藏状态
h
T
\boldsymbol{h}_T
h T 。因此,我们先计算目标函数有关最终时间步隐藏状态的梯度
∂
L
/
∂
h
T
∈
R
h
\partial L/\partial \boldsymbol{h}_T \in \mathbb{R}^h
∂ L / ∂ h T ∈ R h 。依据链式法则,我们得到
∂
L
∂
h
T
=
prod
(
∂
L
∂
o
T
,
∂
o
T
∂
h
T
)
=
W
q
h
⊤
∂
L
∂
o
T
.
\frac{\partial L}{\partial \boldsymbol{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{o}_T}, \frac{\partial \boldsymbol{o}_T}{\partial \boldsymbol{h}_T} \right) = \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_T}.
∂ h T ∂ L = prod ( ∂ o T ∂ L , ∂ h T ∂ o T ) = W q h ⊤ ∂ o T ∂ L .
接下来对于时间步
t
<
T
t < T
t < T , 在图6.3中,
L
L
L 通过
h
t
+
1
\boldsymbol{h}_{t+1}
h t + 1 和
o
t
\boldsymbol{o}_t
o t 依赖
h
t
\boldsymbol{h}_t
h t 。依据链式法则, 目标函数有关时间步
t
<
T
t < T
t < T 的隐藏状态的梯度
∂
L
/
∂
h
t
∈
R
h
\partial L/\partial \boldsymbol{h}_t \in \mathbb{R}^h
∂ L / ∂ h t ∈ R h 需要按照时间步从大到小依次计算:
∂
L
∂
h
t
=
prod
(
∂
L
∂
h
t
+
1
,
∂
h
t
+
1
∂
h
t
)
+
prod
(
∂
L
∂
o
t
,
∂
o
t
∂
h
t
)
=
W
h
h
⊤
∂
L
∂
h
t
+
1
+
W
q
h
⊤
∂
L
∂
o
t
\frac{\partial L}{\partial \boldsymbol{h}_t} = \text{prod} (\frac{\partial L}{\partial \boldsymbol{h}_{t+1}}, \frac{\partial \boldsymbol{h}_{t+1}}{\partial \boldsymbol{h}_t}) + \text{prod} (\frac{\partial L}{\partial \boldsymbol{o}_t}, \frac{\partial \boldsymbol{o}_t}{\partial \boldsymbol{h}_t} ) = \boldsymbol{W}_{hh}^\top \frac{\partial L}{\partial \boldsymbol{h}_{t+1}} + \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_t}
∂ h t ∂ L = prod ( ∂ h t + 1 ∂ L , ∂ h t ∂ h t + 1 ) + prod ( ∂ o t ∂ L , ∂ h t ∂ o t ) = W h h ⊤ ∂ h t + 1 ∂ L + W q h ⊤ ∂ o t ∂ L
将上面的递归公式展开,对任意时间步
1
≤
t
≤
T
1 \leq t \leq T
1 ≤ t ≤ T ,我们可以得到目标函数有关隐藏状态梯度的通项公式
∂
L
∂
h
t
=
∑
i
=
t
T
(
W
h
h
⊤
)
T
−
i
W
q
h
⊤
∂
L
∂
o
T
+
t
−
i
.
\frac{\partial L}{\partial \boldsymbol{h}_t} = \sum_{i=t}^T {\left(\boldsymbol{W}_{hh}^\top\right)}^{T-i} \boldsymbol{W}_{qh}^\top \frac{\partial L}{\partial \boldsymbol{o}_{T+t-i}}.
∂ h t ∂ L = i = t ∑ T ( W h h ⊤ ) T − i W q h ⊤ ∂ o T + t − i ∂ L .
由上式中的指数项可见,当时间步数
T
T
T 较大或者时间步
t
t
t 较小时,目标函数有关隐藏状态的梯度较容易出现 衰减 和 爆炸 。这也会影响其他包含
∂
L
/
∂
h
t
\partial L / \partial \boldsymbol{h}_t
∂ L / ∂ h t 项的梯度,例如隐藏层中模型参数的梯度
∂
L
/
∂
W
h
x
∈
R
h
×
d
\partial L / \partial \boldsymbol{W}_{hx} \in \mathbb{R}^{h \times d}
∂ L / ∂ W h x ∈ R h × d 和
∂
L
/
∂
W
h
h
∈
R
h
×
h
\partial L / \partial \boldsymbol{W}_{hh} \in \mathbb{R}^{h \times h}
∂ L / ∂ W h h ∈ R h × h 。 在图6.3中,
L
L
L 通过
h
1
,
…
,
h
T
\boldsymbol{h}_1, \ldots, \boldsymbol{h}_T
h 1 , … , h T 依赖这些模型参数。 依据链式法则,我们有
∂
L
∂
W
h
x
=
∑
t
=
1
T
prod
(
∂
L
∂
h
t
,
∂
h
t
∂
W
h
x
)
=
∑
t
=
1
T
∂
L
∂
h
t
x
t
⊤
,
∂
L
∂
W
h
h
=
∑
t
=
1
T
prod
(
∂
L
∂
h
t
,
∂
h
t
∂
W
h
h
)
=
∑
t
=
1
T
∂
L
∂
h
t
h
t
−
1
⊤
.
\begin{aligned} \frac{\partial L}{\partial \boldsymbol{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{x}_t^\top,\\ \frac{\partial L}{\partial \boldsymbol{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \boldsymbol{h}_t}, \frac{\partial \boldsymbol{h}_t}{\partial \boldsymbol{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \boldsymbol{h}_t} \boldsymbol{h}_{t-1}^\top. \end{aligned}
∂ W h x ∂ L ∂ W h h ∂ L = t = 1 ∑ T prod ( ∂ h t ∂ L , ∂ W h x ∂ h t ) = t = 1 ∑ T ∂ h t ∂ L x t ⊤ , = t = 1 ∑ T prod ( ∂ h t ∂ L , ∂ W h h ∂ h t ) = t = 1 ∑ T ∂ h t ∂ L h t − 1 ⊤ .
每次迭代中,我们在依次计算完以上各个梯度后,会将它们存储起来,从而避免重复计算。例如,由于隐藏状态梯度
∂
L
/
∂
h
t
\partial L/\partial \boldsymbol{h}_t
∂ L / ∂ h t 被计算和存储,之后的模型参数梯度
∂
L
/
∂
W
h
x
\partial L/\partial \boldsymbol{W}_{hx}
∂ L / ∂ W h x 和
∂
L
/
∂
W
h
h
\partial L/\partial \boldsymbol{W}_{hh}
∂ L / ∂ W h h 的计算可以直接读取
∂
L
/
∂
h
t
\partial L/\partial \boldsymbol{h}_t
∂ L / ∂ h t 的值,而无须重复计算它们。此外,反向传播中的梯度计算可能会依赖变量的当前值。它们正是通过正向传播计算出来的。 举例来说,参数梯度
∂
L
/
∂
W
h
h
\partial L/\partial \boldsymbol{W}_{hh}
∂ L / ∂ W h h 的计算需要依赖隐藏状态在时间步
t
=
0
,
…
,
T
−
1
t = 0, \ldots, T-1
t = 0 , … , T − 1 的当前值
h
t
\boldsymbol{h}_t
h t (
h
0
\boldsymbol{h}_0
h 0 是初始化得到的)。这些值是通过从输入层到输出层的正向传播计算并存储得到的。
小结
通过时间反向传播是反向传播在循环神经网络中的具体应用。
当总的时间步数较大或者当前时间步较小时,循环神经网络的梯度较容易出现衰减或爆炸。