BERT mask prediction with transformers: predicting the missing masked character

Today I needed to use BERT from the transformers library for mask prediction, so I am sharing my code here (see reference [1] for the original Stack Overflow discussion):

import torch
from transformers import BertTokenizer, BertForMaskedLM

# OPTIONAL: activate the logger for more information on what's happening
import logging
logging.basicConfig(level=logging.INFO)

# Load the pre-trained tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

text = '[CLS] 我 是 [MASK] 国 人 [SEP]'
tokenized_text = tokenizer.tokenize(text)
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)

# Create the segment ids (a single sentence, so all zeros)
segments_ids = [0] * len(tokenized_text)

# Convert inputs to PyTorch tensors (batch size 1)
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

# Load the pre-trained model (weights)
model = BertForMaskedLM.from_pretrained('bert-base-chinese')
model.eval()

# Position of the [MASK] token to be predicted
masked_index = tokenized_text.index('[MASK]')

# Predict all tokens; pass the segment ids by keyword, because the second
# positional argument of forward() here is attention_mask, not token_type_ids
with torch.no_grad():
    predictions = model(tokens_tensor, token_type_ids=segments_tensors)

# predictions[0] holds the logits of shape (batch, seq_len, vocab_size);
# take the highest-scoring vocabulary id at the masked position
predicted_index = torch.argmax(predictions[0][0][masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
print(predicted_token)

The predicted result:

.....
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
中
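
Besides the single argmax prediction, it is often useful to inspect the top-k candidates for the masked position. A minimal continuation of the snippet above using torch.topk (the printed scores are raw logits; apply a softmax if you want probabilities):

# Continues the snippet above: show the 5 highest-scoring fills for [MASK]
values, indices = torch.topk(predictions[0][0][masked_index], k=5)
for score, idx in zip(values, indices):
    token = tokenizer.convert_ids_to_tokens([idx.item()])[0]
    print(token, score.item())

For '我 是 [MASK] 国 人', the character 中 should rank at or near the top, matching the argmax result shown above.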

Not bad, right? Quite magical. The transformers version used here is:

transformers                       3.0.2
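
As an aside, the same prediction can be done in a couple of lines with the fill-mask pipeline that transformers also ships (it exists in 3.0.2, though the exact keys of the returned dicts vary slightly across versions); a minimal sketch:

from transformers import pipeline

# The fill-mask pipeline bundles tokenization, the model call and decoding;
# by default it returns the top 5 candidate fills, ranked by score
fill_mask = pipeline('fill-mask', model='bert-base-chinese')

for result in fill_mask('我 是 [MASK] 国 人'):
    print(result['sequence'], result['score'])

Note that the pipeline adds [CLS] and [SEP] by itself, so the input only needs the [MASK] placeholder.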

References

[1] "Predicting missing words in a sentence - natural language processing model". Stack Overflow. https://stackoverflow.com/questions/54978443/predicting-missing-words-in-a-sentence-natural-language-processing-model

Reposted from blog.csdn.net/w5688414/article/details/109284618