构建基线数据,用pad=0填充构造与输入等长的数据
获取原数据的填充字符,起始,结束 使用tokenizer直接获取
PAD_ID = tokenizer.pad_token_id # 补齐符号 pad 的 id
SEP_ID = tokenizer.sep_token_id # 输入文本中句子的间隔符号 sep 的 id
# 置于文本初始位置的符号 cls 的 id,在 bert 预训练过程中其对应的输出用于预测句子是否相关
CLS_ID = tokenizer.cls_token_id
对原数据进行编码,用len获取长度
def construct_input_base(context, question):
# question 指问题,context 指需要从中寻求问题答案的文本
# 将 context 转化为 id
context_ids = tokenizer.encode(context, add_special_tokens=False)
# 将 question 转化为 id
question_ids = tokenizer.encode(question, add_special_tokens=False)
# 基线输入,与 input_ids 等长
base_input_ids = [CLS_ID] + [PAD_ID] * len(question_ids) + [SEP_ID] + \
[PAD_ID] * len(context_ids) + [SEP_ID]
return torch.tensor([base_input_ids], device=DEVICE)
tokenizer.encode(question, add_special_tokens=False)
使用encoder函数 获取编码
[1999, 2249, 1010, 20773, 2864, 2005, 2176, 6385, 2073, 1029]
调用construct_input_base函数
input_base = construct_input_base(context, question)
input_base