topk关键代码预测部分
预测的主要内容是这一段对应的代码
for d in tqdm(valid_data):
text = '\n'.join(d[0])
summary = predict(text)
metrics = compute_metrics(summary, d[2])
for k, v in metrics.items():
total_metrics[k] += v
关键预测部分代码
summary = predict(text)
进入到预测部分的代码
def predict(text, topk=3):
# 抽取
texts = convert.text_split(text)
vecs = vectorize.predict(texts)
preds = extract.model.predict(vecs[None])[0, :, 0]
preds = np.where(preds > extract.threshold)[0]
summary = ''.join([texts[i] for i in preds])
# 生成
summary = seq2seq.autosummary.generate(summary, topk=topk)
# 返回
return summary
这个没啥好讲的内容,本质上就是先抽取出句向量,然后调用抽取模型进行抽取,最后使用seq2seq的预测部分进行生成新的摘要的内容,基本上就是前面的路子顺一边,然后生成新的摘要的内容