线性条件随机场代码解读

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/jeryjeryjery/article/details/82346482

  NERCRF是必不可少的环节,特地看了一遍CRF相关理论以及allennlpCRF的代码,特在这里笔记记录下来!

1.线性CRF简介

1.1一般形式

  关于线性条件随机场的详细介绍,请参考李航老师的《统计学习方法》或者这里,这里仅仅给出一般的公式定义。
  设 P ( Y | X ) 为线性链条件随机场,则在随机变量 X 取值为 x 的条件下,随机变量 Y 取值为 y 的条件概率具有如下形式(注意 x , y 都是序列):

P ( y | x ) = 1 Z ( x ) e x p ( i , k λ k t k ( y i 1 , y i , x , i ) + i , l μ l s l ( y i , x , i ) ) ( 11.10 )

其中,
Z ( x ) = y e x p ( i , k λ k t k ( y i 1 , y i , x , i ) + i , l μ l s l ( y i , x , i ) ) ( 11.11 )

式子中, t k 转移特征函数,依赖于当前和前一个位置 s l 状态特征函数,依赖于当前位置 λ k μ l 是对应的权重。 Z ( x ) 是规范化因子,求和是在所有可能的输出序列上进行的(注意这个所有可能不是任意的组合,这需要依赖于 x 的取值)。

1.2简化形式

  注意到条件随机场式(11.10)中同一特征在各个位置都有定义,可以对同一个特征在各个位置求和,将局部特征函数转哈U为一个全局特征函数,这样就可以将条件随机场写成权值向量和特征向量(包括转移特征和状态特征)的内积形式,即条件随机场的简化形式。
  首先将转移特征和状态特征及其权值用统一的符号表示,设有 K 1 个转换特征, K 2 个状态特征, K = K 1 + K 2 ,记:

f k ( y i 1 , y i , x , i ) = { t k ( y i 1 , y i , x , i ) , k = 1 , 2 , . . . , K s l ( y i , x , i ) , k = K 1 + l ; l = 1 , 2 , . . . , K 2 ( 11.12 )

然后,对转移与状态特征在各个位置i求和,记作:
f k ( y , x ) = i n f k ( y i 1 , y i , x , i ) , k = 1 , 2 , . . . , K ( 11.13 )

w k 表示特征 f k ( y , x ) 的权值,即:
w k = { λ k , k = 1 , 2 , . . . , K μ l , k = K 1 + l ; l = 1 , 2 , . . . , K 2 ( 11.14 )

于是用上面的简化形式,条件随机场可以表示为:
P ( y | x ) = 1 Z ( x ) e x p k = 1 K w k f k ( y , x ) ( 11.15 )

Z ( x ) = y e x p k = 1 K w k f k ( y , x ) ( 11.16 )

若以w表示权值向量,即:
w = ( w 1 , w 2 , . . . , w K ) T ,以 F ( y , x ) = ( f 1 ( y , x ) , f 2 ( y , x ) , . . . , f K ( y , x ) ) T ,则条件随机场可以写成向量 w F ( y , x ) 的内积形式:
P w ( y | x ) = 1 Z ( x ) e x p ( w F ( y , x ) ) ( 11.19 )

其中,
Z w ( x ) = y e x p ( w F ( y , x ) ) ( 11.20 )

1.3条件随机场的矩阵形式

  条件随机场还可以由矩阵表示,事实上,在代码实现中,我们肯定需要用到矩阵运算的!假设 P w ( y | x ) 是由式子 ( 11.15 ) ( 11.16 ) 给出的线性链条件随机场,表示对给定观测序列 x ,相应的标记序列 y 的条件概率。引进特殊的起点和终点标记 y 0 = s t a r t , y n + 1 = s t o p ,这时 P w ( y | x ) 可以通过矩阵形式表示。
  对观测序列 x 的每一个位置 i = 1 , 2 , . . . , n + 1 ,定义一个 m 阶矩阵( m 是标记 y i 取值的个数)

M i ( x ) = [ M i ( y i 1 , y i | x ) ] ( 11.21 )

M i ( y i 1 , y i | x ) = e x p ( W i ( y i 1 , y i | x ) ) ( 11.22 )

W i ( y i 1 , y i | x ) = k = 1 K w k f k ( y i 1 , y i | x ) ( 11.23 )

这样,给定观察序列x,相应标记序列y的非规范化概率可以通过该序列 n + 1 个矩阵适当元素的乘积 i = 1 n + 1 M i ( y i , y i | x ) 表示.于是,条件概率 P w ( y | x ) 是:
P w ( y | x ) = 1 Z ( x ) i = 1 n + 1 M i ( y i , y i | x ) ( 11.24 )

其中 Z w ( x ) 为规范化因子,是 ( n + 1 ) 个矩阵的乘积的 ( s t a r t , s t o p ) 元素:
Z w ( x ) = ( M 1 ( x ) M 2 ( x ) . . . M n + 1 ( x ) ) s t a r t , s t o p

注意, y 0 = s t a r t y n + 1 = s t o p 表示开始和终止状态,规范化因子 Z w ( x ) 是以 s t a r t 为起点 s t o p 为终点通过状态的所有路径 y 1 y 2 . . . y n 的非规范化概率 i = 1 n + 1 M i ( y i , y i | x ) 之和, 这个所有路径与 x 的取值也是息息相关的, x 能够决定各位置各标签的得分!.

  下面,我们将1.1和1.3的内容拼接起来,证明二者的一致性!

P ( y | x ) = 1 Z ( x ) e x p ( i , k λ k t k ( y i 1 , y i , x , i ) + i , l μ l s l ( y i , x , i ) )

这里,我们仅仅考虑后面的非规范化项。
e x p ( i , k λ k t k ( y i 1 , y i , x , i ) + i , l μ l s l ( y i , x , i ) ) =

e x p ( i ( k λ k t k ( y i 1 , y i , x , i ) + l μ l s l ( y i , x , i ) ) ) =

i e x p ( k λ k t k ( y i 1 , y i , x , i ) + l μ l s l ( y i , x , i ) ) =

i e x p ( k = 1 K w k f k ( y i 1 , y i , x , i ) ) =

i e x p ( W i ( y i 1 , y i | x ) ) = i M i ( y i 1 , y i | x ) = i M i ( x )

其中第二步到第三步是根据 e x p 相加可以展开为连乘的特性,第三步到第四步用到了 1.2中的简化形式,后面就是直接套用 M 的定义了。通过这个证明可以发现: 无论是先将所有得分先加起来做 e x p 还是直接先 e x p 再连乘,答案都是一样的!实现的时候,可以考虑这两种不同的方式!

2.前向-后向算法

  条件随机场(CRF)完全由特征函数 t k , s l 和对应的权重 λ k , μ l 确定,我们需要利用前向-后向算法,计算出给定输入序列和对应的实际标签序列的 l o g l i k e l i h o o d 概率值,然后通过最大化这个概率值,来更新上面特征和权重中的参数,实现学习的效果! 学习完这些参数之后,对于一个给定的输入序列,我们可以用维特比算法找出当前参数下得分最高的预测标签序列!
  这里讲解学习过程中一个很重要的算法,前向-后向算法!
  对于每个指标 i = 0 , 1 , . . . , n + 1 (包括了start和stop),定义前向向量 α i ( x ) :

α 0 ( y | x ) = { 1 , y = s t a r t 0 , 否则 ( 11.26 )

递推公式为:
α i T ( y i | x ) = α i 1 T ( y i 1 | x ) [ M i ( y i 1 , y i | x ) ] , i = 1 , 2 , . . . , n + 1 ( 11.27 )
又可以表示为:
α i T ( x ) = α i 1 T ( x ) M i ( x ) ( 11.28 )

α i T ( y i | x ) 表示在位置 i 的标记是 y i 并且到位置 i 的前部分标记序列的非规范化概率, y i 可取的值由 m 个,所以 α i ( x ) m 维列向量。为了更好的理解递推过程,我们可以对前几个 α 进行展开,当然 M 也进行相应的展开。
α 1 ( x ) = α 0 ( x ) M 1 ( x ) = α 0 ( x ) e x p ( k = 1 K w k f k ( y 0 , y 1 , x , 1 ) )

α 2 ( x ) = α 0 ( x ) e x p ( k = 1 K w k f k ( y 0 , y 1 , x , 1 ) ) e x p ( k = 1 K w k f k ( y 1 , y 2 , x , 2 ) )

. . . .

α i ( x ) = α 0 ( x ) e x p ( k = 1 K w k f k ( y 0 , y 1 , x , 1 ) ) . . . ( k = 1 K w k f k ( y i 1 , y i , x , i ) )

注意,这里的连乘是 e x p 连乘,转换为先连加在 e x p 是等价的。
  同样,对每个指标 i = 0 , 1 , . . . , n + 1 ,定义后向向量 β i ( x ) :
β n + 1 ( y n + 1 | x ) = { 1 , y n + 1 = s t o p 0 , 否则 ( 11.29 )

β i ( y i | x ) = [ M i ( y i , y i + 1 | x ) ] β i + 1 ( y i + 1 | x )

又可以表示为
β i ( x ) = M i + 1 ( x ) β i + 1 ( x )

β i ( y i | x ) 表示在位置 i 的标记为 y i 并且从 i + 1 n 的后部分标记序列的非规范化概率。
  由前向-后向定义不难得到:
Z ( x ) = α n T ( x ) 1 = 1 T β 1 ( x )

这里, 1 是元素均为1的 m 维列向量。
  你会发现,前后向算法本质上差不多,目的也是一样的,只是方向不同!

3.CRF优化问题

3.1正确序列的概率表达式

  我们这里以bi-LSTM + CRF为例子。假设输入为:

X = ( x 1 , x 2 , . . . , x n )

我们假设 P 是通过 b i L S T M 预测的各个位置各标签的得分矩阵,大小为 n k k 是独立的标签的总数量, P i , j 是句中第 i 个词预测第 j 个标签的得分。假设句子预测的标签为:
y = ( y 1 , y 2 , . . . , y n )

我们定义它的得分为:
s ( X , y ) = i = 1 n A y i 1 , y i + i = 1 n P i , y i y 0 s t a r t

你可能会觉得这里为什么和1.1小节中的式(11.10)略有不同,(11.10)中分子中多了exp是因为它做了一个softmax操作!本质上二者是一致的(分子部分)。其中 A i 1 , y i 对应转换特征,仅仅有一个转换特征,也就是 k = 1 ; P i , y i 是状态特征,仅仅有一个,也就是 l = 1 其中 A 标签转换得分矩阵,即从一种标签转化为另一种标签的分数,这是需要学习的参数;我们一般会为 A 加上两个标签 s t a r t e n d 标签,或者称为 s t o p ,分别对应 y 0 y n + 1
  我们的目标是让目标标签序列的总体得分尽可能的大。用 s o f t m a x 表示就是:
p ( y | X ) = e s ( X , y ) y ˘ Y X e s ( X , y ˘ )

这个式子和式 ( 11.10 ) 就完全等价了,其中 Y X 表示输入序列 X 可能预测的所有标签序列集合。在训练的时候,我们一般是最大化正确标签序列对应的 l o g p r o b a b i l i t y 值:
l o g ( p ( y | X ) ) = s ( X , y ) l o g ( y ˘ Y X e s ( X , y ˘ ) ) = s ( X , y ) l o g a d d y ˘ Y X s ( X , y ˘ )

所以,我们在计算这个 l o g l i k e l i h o o d 概率时,需要计算两部分,前一部分对应分子部分,后一部分对应分母部分。我们希望能够迭代计算出相应的值!

3.2 计算log-likelihood概率

  计算分两部分进行,第一部分是分子部分的值,也就是 s ( X , y ) ;第二部分是分母部分的值,也就是 l o g ( y ˘ Y X e s ( X , y ˘ ) )

3.2.1 分子部分

  首先给出 S ( X , y ) 分数计算方式:

S ( X , y ) = i = 1 n A y i 1 , y i + i = 1 n P i , y i

在代码实现中,我们是沿着句子中每个位置进行推进迭代的,也就是使用前向算法,我们列举每一步迭代的结果:
S 1 = i = 1 1 A y i 1 , y i + i = 1 1 P i , y i = A y 0 , y 1 + P 1 , y 1

S 2 = i = 1 2 A y i 1 , y i + i = 1 2 P i , y i = A y 0 , y 1 + A y 1 , y 2 + P 1 , y 1 + P 2 , y 2 = S 1 + A y 1 , y 2 + P 2 , y 2

. . .
S n = i = 1 n A y i 1 , y i + i = 1 n P i , y i = S n 1 + A y n 1 , y n + P n , y n

所以,我们在沿着句子中某个位置 i 进行迭代时,只需要一直记录对应的 S i 1 , A , P 这三项值,就能够计算出分子部分的值!

3.2.2 分母部分

  分母部分的计算相对来说比较麻烦,也需要构造每一步迭代项,分母部分计算公式如下:

Z ( X ) = l o g ( y ˘ Y X e s ( X , y ˘ ) ) = l o g ( y ˘ Y X e x p ( i = 1 n A y i 1 , y i + i = 1 n P i , y i ) )

我们也按照句子中的每一个位置进行展开!
Z 1 = l o g ( y ˘ t a g s e x p ( i = 1 1 A y i 1 , y i + i = 1 1 P i , y i ) ) =

l o g ( y ˘ t a g s e x p ( A y 0 ˘ , y 1 ˘ + P 1 , y ˘ 1 ) ) ( 3 , 1 )

其中 t a g s 表示所有标签集合, y ˘ i 表示位置 i 的对应的任意标签。
Z 2 = l o g ( y ˘ t a g s e x p ( i = 1 2 A y ˘ i 1 , y ˘ i + i = 1 n P i , y ˘ i ) ) =

l o g ( y ˘ t a g s e x p ( A y ˘ 0 , y ˘ 1 + A y ˘ 1 , y ˘ 2 + P 1 , y ˘ 1 + P 2 , y ˘ 2 ) ) =

l o g ( y ˘ t a g s e x p ( A y ˘ 0 , y ˘ 1 + P 1 , y ˘ 1 ) e x p ( A y ˘ 1 , y ˘ 2 + P 2 , y ˘ 2 ) ) =

l o g ( y ˘ t a g s e x p ( A y ˘ 0 , y ˘ 1 + P 1 , y ˘ 1 ) y ˘ t a g s e x p ( A y ˘ 1 , y ˘ 2 + P 2 , y ˘ 2 ) ) ( 3 , 2 )

从第三步到第四步可以展开为两个求和,因为长度为2的任意标签序列是两个长度为1的任意标签序列的任意组合。注意根据式 ( 3 , 1 )
Z 1 = l o g ( y ˘ t a g s e x p ( A y 0 ˘ , y 1 ˘ + P 1 , y ˘ 1 ) )
,
所以
e x p ( Z 1 ) = y ˘ t a g s e x p ( A y ˘ 0 , y ˘ 1 + P 1 , y ˘ 1 ) ( 3 , 4 )

将这个式子带入到式子 ( 3.2 ) 的前半部分,所以:
Z 2 = l o g ( e x p ( Z 1 ) y ˘ t a g s e x p ( A y ˘ 1 , y ˘ 2 + P 2 , y ˘ 2 ) ) =

l o g ( y ˘ t a g s e x p ( Z 1 + A y ˘ 1 , y ˘ 2 + P 2 , y ˘ 2 ) )

这里 Z 1 可以直接放进去是因为此时的 Z 1 已经计算出来了,是一个常量值了。
进行推广:
. . . .
Z n = l o g ( y ˘ t a g s e x p ( Z n 1 + A y ˘ n 1 , y ˘ n + P 2 , y ˘ 2 ) )

这样我们也找到了递推项,当我们沿着句子的每个位置进行迭代时,只需要一直记录对应的 Z i 1 , A , P 这三个值,就可以计算出分母部分的值。

4.CRF学习算法

  一般使用梯度下降法, t e n s o r f l o w p y t o r c h 等学习工具都提供了梯度下降法的支持!

5.源码解读

  下面对allennlp中提供的CRF源码进行解读!代码如下:

def allowed_transitions(constraint_type: str, tokens: Dict[int, str]) -> List[Tuple[int, int]]:
    """
    Given tokens and a constraint type, returns the allowed transitions. It will
    additionally include transitions for the start and end states, which are used
    by the conditional random field.

    Parameters
    ----------
    constraint_type : ``str``, required
        Indicates which constraint to apply. Current choices are "BIO" and "BIOUL".
    tokens : ``Dict[int, str]``, required
        A mapping {token_id -> token}. Most commonly this would be the value from
        Vocabulary.get_index_to_token_vocabulary()
        这应该是标签的tokens, 即所有的标签列表->id列表,类似于idx2tag

    Returns
    -------
    ``List[Tuple[int, int]]``
        The allowed transitions (from_token_id, to_token_id).

    这个方法的作用是,预选准备好所有的可能的标签之间的互换,可以排除很多不可能出现的情况
    """
    # 一般都需要先计算总的tag数,记住要加上起始和终止
    n_tags = len(tokens)
    start_tag = n_tags
    end_tag = n_tags + 1

    allowed = []
    # begin, inside, other, unique, Last?
    if constraint_type == "BIOUL":
        for i, (from_bioul, *from_entity) in tokens.items():
            for j, (to_bioul, *to_entity) in tokens.items():

                # 预先准备好可能的转换,避免做完全遍历的维特比?
                is_allowed = any([
                        # O can transition to O, B-* or U-*
                        # L-x can transition to O, B-*, or U-*
                        # U-x can transition to O, B-*, or U-*
                        from_bioul in ('O', 'L', 'U') and to_bioul in ('O', 'B', 'U'),
                        # B-x can only transition to I-x or L-x
                        # I-x can only transition to I-x or L-x
                        from_bioul in ('B', 'I') and to_bioul in ('I', 'L') and from_entity == to_entity
                ])

                if is_allowed:
                    allowed.append((i, j))

        # start transitions, 开始可以转换为other, begin unique
        for i, (to_bioul, *to_entity) in tokens.items():
            if to_bioul in ('O', 'B', 'U'):
                allowed.append((start_tag, i))

        # end transitions, other, unique, last可以转换为end
        for i, (from_bioul, *from_entity) in tokens.items():
            if from_bioul in ('O', 'L', 'U'):
                allowed.append((i, end_tag))
    # begin, inside, other
    elif constraint_type == "BIO":
        for i, (from_bio, *from_entity) in tokens.items():
            for j, (to_bio, *to_entity) in tokens.items():

                is_allowed = any([
                        # Can always transition to O or B-x
                        to_bio in ('O', 'B'),
                        # Can only transition to I-x from B-x or I-x
                        to_bio == 'I' and from_bio in ('B', 'I') and from_entity == to_entity
                ])

                if is_allowed:
                    allowed.append((i, j))

        # start transitions, 以start tag开始可以转换的情况
        for i, (to_bio, *to_entity) in tokens.items():
            if to_bio in ('O', 'B'):
                allowed.append((start_tag, i))

        # end transitions, 以end_tag结束可以转换的情况
        for i, (from_bio, *from_entity) in tokens.items():
            if from_bio in ('O', 'B', 'I'):
                allowed.append((i, end_tag))

    else:
        raise ConfigurationError(f"Unknown constraint type: {constraint_type}")

    return allowed


class ConditionalRandomField(torch.nn.Module):
    """
    This module uses the "forward-backward" algorithm to compute
    the log-likelihood of its inputs assuming a conditional random field model.

    See, e.g. http://www.cs.columbia.edu/~mcollins/fb.pdf

    Parameters
    ----------
    num_tags : int, required
        The number of tags.
    constraints : List[Tuple[int, int]], optional (default: None)
        An optional list of allowed transitions (from_tag_id, to_tag_id).
        These are applied to ``viterbi_tags()`` but do not affect ``forward()``.
        These should be derived from `allowed_transitions` so that the
        start and end transitions are handled correctly for your tag type.
    include_start_end_transitions : bool, optional (default: True)
        Whether to include the start and end transition parameters.
    """
    def __init__(self,
                 num_tags: int,
                 constraints: List[Tuple[int, int]] = None,
                 include_start_end_transitions: bool = True) -> None:
        super().__init__()
        self.num_tags = num_tags

        # transitions[i, j] is the logit for transitioning from state i to state j.
        self.transitions = torch.nn.Parameter(torch.Tensor(num_tags, num_tags))

        # _constraint_mask indicates valid transitions (based on supplied constraints).
        # Include special start of sequence (num_tags + 1) and end of sequence tags (num_tags + 2)
        if constraints is None:
            # All transitions are valid.
            constraint_mask = torch.Tensor(num_tags + 2, num_tags + 2).fill_(1.)
        else:
            constraint_mask = torch.Tensor(num_tags + 2, num_tags + 2).fill_(0.)
            for i, j in constraints:
                constraint_mask[i, j] = 1.

        self._constraint_mask = torch.nn.Parameter(constraint_mask, requires_grad=False)

        # Also need logits for transitioning from "start" state and to "end" state.
        self.include_start_end_transitions = include_start_end_transitions
        if include_start_end_transitions:
            self.start_transitions = torch.nn.Parameter(torch.Tensor(num_tags))
            self.end_transitions = torch.nn.Parameter(torch.Tensor(num_tags))

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_normal_(self.transitions)
        if self.include_start_end_transitions:
            torch.nn.init.normal_(self.start_transitions)
            torch.nn.init.normal_(self.end_transitions)

    def _input_likelihood(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Computes the (batch_size,) denominator term for the log-likelihood, which is the
        sum of the likelihoods across all possible state sequences.
        logit: [batch_size, sequence_lenth, num_tag]表示这个batch训练集的目标序列

        这里计算的是log-likelihood的分母部分
        """
        batch_size, sequence_length, num_tags = logits.size()

        # Transpose batch size and sequence dimensions
        # 无论前向还是后向算法,都是根据词位置逐个推进的
        mask = mask.float().transpose(0, 1).contiguous()
        logits = logits.transpose(0, 1).contiguous()

        # Initial alpha is the (batch_size, num_tags) tensor of likelihoods combining the
        # transitions to the initial states and the logits for the first timestep.
        # 前向算法使用的是alpha,注意每个词对应位置的alpha都是num_tag维向量
        if self.include_start_end_transitions:
            alpha = self.start_transitions.view(1, num_tags) + logits[0]
        else:
            alpha = logits[0]

        # For each i we compute logits for the transitions from timestep i-1 to timestep i.
        # We do so in a (batch_size, num_tags, num_tags) tensor where the axes are
        # (instance, current_tag, next_tag), 即size对应为(实例,当前tag位置,下一tag位置)
        # 根据词的位置,依次推进前向算法
        for i in range(1, sequence_length):
            # 哪个维度对分数不影响,我们在那个维度扩散分数,
            # The emit scores are for time i ("next_tag") so we broadcast along the current_tag axis.
            emit_scores = logits[i].view(batch_size, 1, num_tags)
            # Transition scores are (current_tag, next_tag) so we broadcast along the instance axis.
            transition_scores = self.transitions.view(1, num_tags, num_tags)
            # Alpha is for the current_tag, so we broadcast along the next_tag axis.
            broadcast_alpha = alpha.view(batch_size, num_tags, 1)

            # Add all the scores together and logexp over the current_tag axis
            # 大牛这里都没有将它们expand成(batch_size, num_tags, num_tags),但结果是一样的
            inner = broadcast_alpha + emit_scores + transition_scores

            # In valid positions (mask == 1) we want to take the logsumexp over the current_tag dimension
            # of ``inner``. Otherwise (mask == 0) we want to retain the previous alpha.
            # tag有效,则持续累加; 否则保留之前的alpha值
            alpha = (util.logsumexp(inner, 1) * mask[i].view(batch_size, 1) +
                     alpha * (1 - mask[i]).view(batch_size, 1))

        # Every sequence needs to end with a transition to the stop_tag.
        # 到end-tag时,只有转换得分,没有logit得分,因为logit的长度仅仅是词长
        if self.include_start_end_transitions:
            stops = alpha + self.end_transitions.view(1, num_tags)
        else:
            stops = alpha

        # Finally we log_sum_exp along the num_tags dim, result is (batch_size,)
        return util.logsumexp(stops)

    def _joint_likelihood(self,
                          logits: torch.Tensor,
                          tags: torch.Tensor,
                          mask: torch.LongTensor) -> torch.Tensor:
        """
        Computes the numerator term for the log-likelihood, which is just score(inputs, tags)
        计算log-likelihood的分子部分

        logits: [batch_size, seq_len, tag_num]表示batch中每个词序列中,每个词预测各标签的分数
        tags: [batch_size, seq_len], 表示batch中每个词序列的真是标签序列
        mask: [batch_size, seq_len], 提示实际长度?
        """
        batch_size, sequence_length, num_tags = logits.data.shape

        # Transpose batch size and sequence dimensions:
        # 需要按照词的position推进
        logits = logits.transpose(0, 1).contiguous()
        mask = mask.float().transpose(0, 1).contiguous()
        tags = tags.transpose(0, 1).contiguous()

        # Start with the transition scores from start_tag to the first tag in each input
        # tag是该
        if self.include_start_end_transitions:
            score = self.start_transitions.index_select(0, tags[0])
        else:
            score = 0.0

        # Broadcast the transition scores to one per batch element
        # batch中各位置的转换时一致的
        broadcast_transitions = self.transitions.view(1, num_tags, num_tags).expand(batch_size, num_tags, num_tags)

        # Add up the scores for the observed transitions and all the inputs but the last
        # 将给定的序列的转换分数加上
        for i in range(sequence_length - 1):
            # Each is shape (batch_size,)
            current_tag, next_tag = tags[i], tags[i+1]

            # The scores for transitioning from current_tag to next_tag
            transition_score = (
                    broadcast_transitions
                    # Choose the current_tag-th row for each input
                    # 先gather row
                    .gather(1, current_tag.view(batch_size, 1, 1).expand(batch_size, 1, num_tags))
                    # Squeeze down to (batch_size, num_tags)
                    .squeeze(1)
                    # Then choose the next_tag-th column for each of those
                    # 从row中再gather col
                    .gather(1, next_tag.view(batch_size, 1))
                    # And squeeze down to (batch_size,)
                    .squeeze(1)
            )

            # The score for using current_tag
            # 使用当前标签对应的得分
            emit_score = logits[i].gather(1, current_tag.view(batch_size, 1)).squeeze(1)

            # Include transition score if next element is unmasked,
            # input_score if this element is unmasked.
            # 只有标签是有效的情况下,才会进行累加
            score = score + transition_score * mask[i + 1] + emit_score * mask[i]

        # Transition from last state to "stop" state. To start with, we need to find the last tag
        # for each instance.
        # 用mask来确定batch中每个序列的最后一个位置,然后取出每个序列的最后一个tag
        last_tag_index = mask.sum(0).long() - 1
        last_tags = tags.gather(0, last_tag_index.view(1, batch_size).expand(sequence_length, batch_size))

        # Is (sequence_length, batch_size), but all the columns are the same, so take the first.
        last_tags = last_tags[0]

        # Compute score of transitioning to `stop_tag` from each "last tag".
        if self.include_start_end_transitions:
            last_transition_score = self.end_transitions.index_select(0, last_tags)
        else:
            last_transition_score = 0.0

        # Add the last input if it's not masked.
        last_inputs = logits[-1]                                         # (batch_size, num_tags)
        last_input_score = last_inputs.gather(1, last_tags.view(-1, 1))  # (batch_size, 1)
        last_input_score = last_input_score.squeeze()                    # (batch_size,)

        score = score + last_transition_score + last_input_score * mask[-1]

        return score

    def forward(self,
                inputs: torch.Tensor,
                tags: torch.Tensor,
                mask: torch.ByteTensor = None) -> torch.Tensor:
        """
        Computes the log likelihood.
        """
        # pylint: disable=arguments-differ
        if mask is None:
            mask = torch.ones(*tags.size(), dtype=torch.long)

        log_denominator = self._input_likelihood(inputs, mask)
        log_numerator = self._joint_likelihood(inputs, tags, mask)

        # log-likelihood是两者相减的结果
        return torch.sum(log_numerator - log_denominator)

实现的重点是 l o g l i k e l i h o o d 中分子和分母部分的计算,在代码中分子部分是按照 ( A s t a r t , y 0 , ) , ( A y 0 , y 1 , P 0 ) , ( A y 1 , y 2 , P 1 ) , . . . , ( A y n 1 , y n , P n 1 ) , ( A y n , e n d , P n ) 这样的组合方式递推的;而在分母计算部分是按照 ( A s t a r t y 0 , P 0 ) , ( A y 0 , y 1 , P 1 ) , . . . , ( A y n 1 , y n , P n ) , ( A y n , e n d ) 两者结果本质是等价的,因为都是累加的结果!
  至于预测算法维特比算法,以后有机会再写吧,写博客太麻烦了!

猜你喜欢

转载自blog.csdn.net/jeryjeryjery/article/details/82346482