文本摘要项目中对loss进行mask处理,损失为0的,代表不进行梯度更新,打pad后的<PAD>如何处理?

文本摘要项目中对loss进行mask处理,损失为0的,代表不进行梯度更新,打pad后的如何处理?

代码如下:

def loss_function(pred, real):  # [64,40,32217], [64, 40]

    # 求pad unk的索引值
    pad_index = word_to_id['<PAD>']  # pad_index = 40
    unk_index = word_to_id['<UNK>']  # pad_index = 42

    # 根据真实标签 求pad_mask unk_mask
    pad_mask = torch.eq(real, pad_index)  # [64, 40] --> [[False, False, ....]]
    unk_mask = torch.eq(real, unk_index)  # [64, 40] --> [[False, False, ....]]
    # 求mask矩阵 有效数据1 无效数据0, (取反后0的位置就对应pad或者unk 的位置)
    mask = torch.logical_not(torch.logical_or(pad_mask, unk_mask))  # mask:[64,40] [[True, True, ...False...], ...]

    # 计算损失 对pred转置 N C放前头
    pred2 = pred.transpose(2, 1)  # [64,40,32217] ---> [64,32217,40]
    loss_ = criterion(pred2, real)  # [64,32217,40], [64,40] 注意细节

    # 对pad unk位置产生的损失 mask
    loss_ = loss_ * mask

    # 计算批次平均损失
    len = mask.sum()  # 对mask矩阵sum 相当于求所有有效单词的个数
    loss = torch.sum(loss_) / len

    # 返回批次平均损失
    return loss

对pad unk位置产生的损失 mask

loss_ = loss_ * mask

mask是由0,1组成的与loss相同shape的矩阵,loss_ = loss_ * mask 就代表着用已经算好的loss对其进行遮掩,遮掩掉’’ 、''对应的loss值,这样就对应位置loss值为0,代表着这部分不进行梯度的更新。

为什么对应位置loss值为0,就代表着这部分不进行梯度的更新呢?

原因:
l o s s = △ y = y ^ − y loss = \triangle y = \hat{y} - y loss=y=y^y
k = △ y / △ x = 0 k = \triangle y / \triangle x = 0 k=y/△x=0
loss值为0,就代表 △ y = 0 \triangle y = 0 y=0,然后 w = w = w= w − w - w α \alpha α △ y / △ x = w \triangle y / \triangle x = w y/△x=w
所以该权重没有更新。


当然在使用这一特性的时候,必须设置以下代码:

# 使用reduction=none形式交叉熵算 不使用默认计算均值 原因:需要手工屏蔽pad位置产生的损失
criterion = nn.CrossEntropyLoss(reduction='none')

猜你喜欢

转载自blog.csdn.net/wtl1992/article/details/131607789