计算前向传播函数
def squad_pos_forward_func(input_ids, attention_mask, position=0):
pred = model(input_ids, attention_mask) # 获取预测结果
pred = pred[position] # 当 position 为 0 时,取的是起始位置所在的分布,为 1 时,取的是结束位置所在的分布
return pred.max(1).values # 取分布的最大值,即预测结果
pred[0],pred[1]就是起始位置的分布,是一个矩阵
from captum.attr import LayerIntegratedGradients
lig = LayerIntegratedGradients(
squad_pos_forward_func, model.distilbert.embeddings) # 输入前向函数以及模型中的某一层
# 对于输出答案初始位置,词向量层的贡献计算
attributions_start, delta_start = lig.attribute(inputs=input_ids, baselines=input_base,
additional_forward_args=(
attention_mask, 0),
return_convergence_delta=True)
# 对于输出答案结束位置,词向量层的贡献计算
attributions_end, delta_end = lig.attribute(inputs=input_ids, baselines=input_base,
additional_forward_args=(
attention_mask, 1),
return_convergence_delta=True)
目的是获得所有token对结果的贡献,使用attributions_start.sum()后进行归一化
tensor([[[-0.0000e+00, -0.0000e+00, 0.0000e+00, ..., 0.0000e+00, 0.0000e+00, -0.0000e+00], [ 7.9740e-03, 1.1938e-03, 1.4542e-03, ..., -1.6658e-03, -2.3311e-04, -1.6646e-03], [-6.5184e-04, 8.4541e-04, 9.1006e-03, ..., 4.9416e-03, 4.0402e-04, -7.4628e-05], ..., [ 1.9171e-02, 8.3873e-03, 2.4527e-02, ..., -1.3171e-03, 3.0566e-02, 1.0291e-02], [ 6.1036e-04, -3.1783e-04, 1.2646e-03, ..., 5.8634e-04, 2.5525e-03, -1.6722e-04], [-0.0000e+00, 0.0000e+00, 0.0000e+00, ..., -0.0000e+00, 0.0000e+00, -0.0000e+00]]], dtype=torch.float64)
以上是每个token对答案起始位置的贡献
使用attributions_start.sum(dim=-1)/attributions_start.sum() 进行归一化
可视化
ground_truth = "New York's Roseland Ballroom"
ground_truth_tokens = tokenizer.encode(
ground_truth, add_special_tokens=False) # 对真实答案进行 encode
ground_truth_end_ind = input_ids[0].detach().tolist().index(
ground_truth_tokens[-1]) # 通过获取真实答案结束位置在所有输入中的位置得到结束位置
ground_truth_start_ind = ground_truth_end_ind - \
len(ground_truth_tokens) + 1 # 再获取答案初始位置
from captum.attr import visualization
print("各单词对于答案起始位置的影响:")
start_position_vis = visualization.VisualizationDataRecord(
attributions_start_sum, # 贡献
torch.max(torch.softmax(outputs["start_logits"][0], dim=0)),
start_pred.item(), # 预测的起始位置
ground_truth_start_ind, # 真实起始位置
str(ground_truth_start_ind), # 真实起始位置的字符串形式
attributions_start_sum.sum(), # 所有贡献量的加和
all_tokens, # 输入的 token
delta_start) # 计算误差
visualization.visualize_text([start_position_vis])