基于编辑距离的中英文自动纠错

#!/usr/bin/env python3
# coding: utf-8
from pypinyin import *
from editing_distance import auto_correct_ch, auto_correct_en
from text_utils import is_chinese_string, is_alphabet_string
import config
import pickle
from flask import Flask, request
# Flask application object; the /test route defined below is registered on it.
app = Flask(__name__)


class WordCorrect:
    """Resources used by the web correction endpoint.

    Attributes:
        char_path: path of the character-list file (one entry per line).
        model_path: path of the pickled pinyin dictionary.
        charlist: the stripped, non-blank lines of ``char_path``.
        pinyin_dict: the object unpickled from ``model_path``. The route looks
            up comma-joined pinyin strings in it and takes ``max`` over the
            value's keys (presumably {pinyin: {word: score}} -- TODO confirm).
    """

    def __init__(self, char_path=None, model_path=None):
        """Load the character list and the pinyin model.

        :param char_path: character-list file; defaults to ``config.char_path``.
        :param model_path: pickled model file; defaults to ``config.model_path``.
        """
        # Defaults keep the original no-argument ``WordCorrect()`` call working.
        self.char_path = config.char_path if char_path is None else char_path
        self.model_path = config.model_path if model_path is None else model_path
        self.charlist = self._load_charlist(self.char_path)
        self.pinyin_dict = self.load_model(self.model_path)

    @staticmethod
    def _load_charlist(char_path):
        """Return the stripped, non-blank lines of *char_path* (file is closed afterwards)."""
        with open(char_path, "r", encoding="utf-8") as handle:
            return [line.strip() for line in handle if line.strip()]

    def load_model(self, model_path):
        """Unpickle and return the object stored in *model_path*.

        SECURITY: ``pickle.load`` can execute arbitrary code -- only load
        trusted model files.
        """
        with open(model_path, 'rb') as handle:
            # Warning: if the pickled structure changes, update saveDataset too.
            return pickle.load(handle)


# Process-wide corrector, built on first use: constructing WordCorrect reads the
# character list and unpickles the pinyin model, too costly to redo per request.
_corrector = None


def _get_corrector():
    """Return the shared WordCorrect instance, creating it on first call."""
    global _corrector
    if _corrector is None:
        _corrector = WordCorrect()
    return _corrector


def _best_by_pinyin(corrector, phrase):
    """Return the highest-valued word stored under *phrase*'s pinyin key, or None.

    The key is the comma-joined ``lazy_pinyin`` of *phrase*. None is returned
    when the key is missing or maps to an empty collection.
    """
    candidates = corrector.pinyin_dict.get(','.join(lazy_pinyin(phrase)))
    if not candidates:
        return None
    return max(candidates, key=candidates.get)


@app.route('/test', methods=['get', 'post'])
def test():
    """Correct the ``err_phrase`` request parameter and return the result as text.

    - Pure Chinese: look up by pinyin; fall back to edit-distance correction.
    - Pure English letters: edit-distance correction on the lowercased input.
    - Mixed Chinese + pinyin: look up by pinyin; fall back to the input itself.

    Responds 400 when ``err_phrase`` is missing.
    """
    error_phrase = request.values.get('err_phrase')  # request parameter
    if error_phrase is None:
        return 'missing parameter: err_phrase', 400

    corrector = _get_corrector()
    if is_chinese_string(error_phrase):
        corrected = _best_by_pinyin(corrector, error_phrase)
        if corrected is None:
            corrected = auto_correct_ch(error_phrase)  # edit-distance search
    elif is_alphabet_string(error_phrase):
        corrected = auto_correct_en(error_phrase.lower())
    else:
        corrected = _best_by_pinyin(corrector, error_phrase)
        if corrected is None:
            # Same convention as auto_correct_*: no match -> return the input.
            corrected = error_phrase

    print(corrected)
    return corrected


"""web接口测试模式"""
if __name__ == '__main__':
    # Try it at http://<server-ip>:5000/test?err_phrase=中国国际航空航天博览会
    # WARNING: debug=True together with host='0.0.0.0' exposes the Werkzeug
    # interactive debugger to the whole network -- never do this on an
    # untrusted network or in production.
    app.run(host='0.0.0.0', port=5000, debug=True)
#!/usr/bin/env python3
# coding: utf-8
import re

import pypinyin
from pypinyin import pinyin


def is_chinese(uchar):
    """Return True if *uchar* falls in the code-point range U+4E00..U+9FA5 (common CJK ideographs)."""
    return '\u4e00' <= uchar <= '\u9fa5'


def is_chinese_string(string):
    """Return True when every character of *string* passes is_chinese (True for "")."""
    return all(is_chinese(ch) for ch in string)


def is_number(uchar):
    """Return True if *uchar* is an ASCII decimal digit ('0'..'9').

    The bounds must be real escapes ('\\u0030'); the unescaped text 'u0030'
    would be compared as an ordinary 5-letter string and match no digit.
    """
    return '\u0030' <= uchar <= '\u0039'


def is_alphabet(uchar):
    """Return True if *uchar* is an ASCII English letter (A-Z or a-z)."""
    # Real code points U+0041..U+005A and U+0061..U+007A; unescaped 'u0041'
    # literals would compare as plain text and never match a single letter.
    return '\u0041' <= uchar <= '\u005a' or '\u0061' <= uchar <= '\u007a'


def is_alphabet_string(string):
    """Return True if every character of *string* is an ASCII letter (A-Z or a-z).

    Uppercase letters count as English letters too. An empty string yields
    True (vacuous truth), the same convention as is_chinese_string.
    """
    return all('a' <= ch <= 'z' or 'A' <= ch <= 'Z' for ch in string)


def is_other(uchar):
    """Return True when *uchar* is rejected by is_chinese, is_number and is_alphabet alike."""
    return not is_chinese(uchar) and not is_number(uchar) and not is_alphabet(uchar)


def B2Q(uchar):
    """Convert one half-width character (U+0020..U+007E) to its full-width form.

    The space maps to the ideographic space U+3000; every other printable
    ASCII character is shifted up by 0xFEE0. Characters outside that range
    are returned unchanged.
    """
    code = ord(uchar)
    if 0x0020 <= code <= 0x7e:
        return chr(0x3000 if code == 0x0020 else code + 0xfee0)
    return uchar


def Q2B(uchar):
    """Convert one full-width character to half-width.

    U+3000 becomes a space; other characters are shifted down by 0xFEE0. If
    the shifted value is not printable ASCII (U+0020..U+007E), the original
    character is returned unchanged.
    """
    code = ord(uchar)
    half = 0x0020 if code == 0x3000 else code - 0xfee0
    return chr(half) if 0x0020 <= half <= 0x7e else uchar


def stringQ2B(ustring):
    """Return *ustring* with every character passed through Q2B (full-width -> half-width)."""
    return "".join(map(Q2B, ustring))


def uniform(ustring):
    """Normalize *ustring*: convert full-width characters to half-width, then lowercase."""
    half_width = stringQ2B(ustring)
    return half_width.lower()


def get_homophones_by_char(input_char):
    """
    Return the characters in U+4E00..U+9FA5 whose toneless pinyin equals input_char's.

    :param input_char: a Chinese character. Only the first element of its
        pinyin result is compared, so longer inputs are matched on that alone.
    :return: matching characters in code-point order; [] when ``input_char``
        produces no pinyin at all (e.g. an empty string).
    """
    # Compute the target once rather than on each of the ~20,902 iterations.
    target = pinyin(input_char, style=pypinyin.NORMAL)
    if not target:
        return []
    target_pinyin = target[0][0]
    # CJK Unified Ideographs range 0x4E00-0x9FA5 (20,902 characters).
    return [chr(code) for code in range(0x4e00, 0x9fa6)
            if pinyin([chr(code)], style=pypinyin.NORMAL)[0][0] == target_pinyin]




import config
from pypinyin import *


# Build the word dictionary {word: freq}.
def construct_dict(file_path):
    """Read a whitespace-separated ``word frequency [...]`` file into {word: frequency}.

    Frequencies are parsed as numbers (int, or float when not an integer) so
    that ``max(..., key=phrase_freq.get)`` ranks candidates numerically; as
    strings "9" would outrank "10". Blank lines are skipped; a later duplicate
    word overwrites an earlier one.

    :param file_path: path of the UTF-8 dictionary file.
    :return: dict mapping each word to its numeric frequency.
    :raises IndexError: if a non-blank line has no frequency field.
    :raises ValueError: if a frequency field is not numeric.
    """
    word_freq = {}
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            info = line.split()
            if not info:
                continue
            word, frequency = info[0], info[1]
            try:
                word_freq[word] = int(frequency)
            except ValueError:
                word_freq[word] = float(frequency)
    return word_freq


# Word dictionary {word: freq}, built once at import time from config.dict_freq.
# known(), auto_correct_ch/auto_correct_en and auto_correct_sentence all read it.
phrase_freq = construct_dict(config.dict_freq)


# Load the character inventory that edits1 substitutes into candidate phrases.
def load_cn_words_dict(file_path):
    """Return every line of *file_path*, stripped, concatenated into one string.

    Callers iterate the result character by character, so a multi-character
    line contributes each of its characters separately.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        # One join instead of repeated ``+=`` on a growing string.
        return "".join(word.strip() for word in f)


# Candidates at edit distance 1 (adjacent transposition or single substitution).
def edits1(phrase, cn_words_dict):
    """Return the set of strings one transposition or one substitution away from *phrase*.

    :param phrase: the string to mutate.
    :param cn_words_dict: iterable of characters used for substitution; if it
        contains a character already at a position, *phrase* itself is included.
    """
    candidates = set()
    for i in range(len(phrase)):
        head, tail = phrase[:i], phrase[i:]
        if len(tail) > 1:
            # Swap the two characters at positions i and i + 1.
            candidates.add(head + tail[1] + tail[0] + tail[2:])
        for ch in cn_words_dict:
            # Replace the character at position i.
            candidates.add(head + ch + tail[1:])
    return candidates

# Keep only candidates spelled exactly like a dictionary word.
def known(phrases):
    """Return the set of items in *phrases* that are keys of ``phrase_freq``."""
    return {candidate for candidate in phrases if candidate in phrase_freq}


def get_candidates_ch(error_phrase):
    """Split known words at edit distance 1 from *error_phrase* into three tiers.

    Pinyin strings are the '/'-joined ``lazy_pinyin`` of each phrase.

    :return: ``(same_pinyin, same_first_syllable, others)`` lists, where
        *same_pinyin* has a pinyin string identical to *error_phrase*'s,
        *same_first_syllable* only matches on the text before the first '/',
        and *others* holds the remaining candidates.
    """
    error_pinyin = '/'.join(lazy_pinyin(error_phrase))
    error_head = error_pinyin.split("/")[0]
    substitutes = load_cn_words_dict(config.char_path)

    same_pinyin, same_head, others = [], [], []
    for phrase in known(edits1(error_phrase, substitutes)):
        phrase_pinyin = '/'.join(lazy_pinyin(phrase))
        if phrase_pinyin == error_pinyin:
            same_pinyin.append(phrase)
        elif phrase_pinyin.split("/")[0] == error_head:
            same_head.append(phrase)
        else:
            others.append(phrase)

    return same_pinyin, same_head, others


def get_candidates_en(error_phrase):
    """Return ``(distance_1, distance_2)`` lists of known words near *error_phrase*.

    Candidates are produced with edits1 using the letters loaded from
    ``config.char_en_path`` and filtered through known(). The second list
    applies edits1 twice, so it can also contain words from the first.
    """
    letters = load_cn_words_dict(config.char_en_path)
    first_ring = edits1(error_phrase, letters)
    # Distance-2 strings are computed inline from edits1 (no edits2 helper is
    # defined or imported in this module).
    second_ring = {twice for once in first_ring for twice in edits1(once, letters)}

    candidates_1st_order = list(known(first_ring))
    candidates_2nd_order = list(known(second_ring))
    return candidates_1st_order, candidates_2nd_order


# Correct a single Chinese phrase.
def auto_correct_ch(error_phrase):
    """Return the most frequent candidate from the first non-empty tier of get_candidates_ch.

    Tiers are tried in order (same pinyin, same first syllable, others). When
    no dictionary word lies within edit distance 1, *error_phrase* is returned
    unchanged.
    """
    for tier in get_candidates_ch(error_phrase):
        if tier:
            return max(tier, key=phrase_freq.get)
    return error_phrase


# Correct a single English word.
def auto_correct_en(error_phrase):
    """Return the most frequent candidate from the first non-empty tier of get_candidates_en.

    Distance-1 candidates are preferred over distance-2 ones. When neither
    tier has a dictionary word, *error_phrase* is returned unchanged.
    """
    for tier in get_candidates_en(error_phrase):
        if tier:
            return max(tier, key=phrase_freq.get)
    return error_phrase


# Correct a whole sentence.
def auto_correct_sentence(error_sentence, verbose=True):
    """Segment *error_sentence* with jieba and correct every out-of-dictionary token.

    Tokens made only of whitespace and/or punctuation are kept verbatim. Every
    other token that is not a key of ``phrase_freq`` is replaced by
    ``auto_correct_ch(token)``.

    :param error_sentence: sentence to correct.
    :param verbose: print each (token, correction) pair and the final sentence.
    :return: the corrected sentence.
    """
    import jieba
    import string

    punctuation = set(string.punctuation)
    punctuation.update("。,?:;{}[]‘“”《》/!%……()")

    corrected_parts = []
    # Iterate jieba's tokens directly: a join/split on "\t" would drop tab tokens.
    for phrase in jieba.cut(error_sentence, cut_all=False):
        correct_phrase = phrase
        # Per-character test: a substring test against the punctuation string
        # misclassifies multi-character tokens. Whitespace-only tokens must not
        # be "corrected" into Chinese characters.
        skip = all(ch in punctuation or ch.isspace() for ch in phrase)
        if not skip and phrase not in phrase_freq:
            correct_phrase = auto_correct_ch(phrase)
            if verbose:
                print(phrase, correct_phrase)
        corrected_parts.append(correct_phrase)

    correct_sentence = "".join(corrected_parts)
    if verbose:
        print(correct_sentence)
    return correct_sentence


# Compute the Levenshtein edit distance.
def levenshtein3(s, t):
    '''Return the Levenshtein distance between *s* and *t*.

    Standard dynamic programme that keeps only the previous and current rows
    (after the Wikipedia "iterative with two matrix rows" formulation).
    '''
    if s == t:
        return 0
    if len(s) == 0:
        return len(t)
    if len(t) == 0:
        return len(s)

    previous = list(range(len(t) + 1))
    for row, s_char in enumerate(s, start=1):
        current = [row] + [0] * len(t)
        for col, t_char in enumerate(t, start=1):
            substitution = previous[col - 1] + (0 if s_char == t_char else 1)
            current[col] = min(current[col - 1] + 1, previous[col] + 1, substitution)
        previous = current

    return previous[len(t)]


def read_name(path=None):
    """Return all lines of the names file, newline characters included.

    :param path: file to read; defaults to ``config.dict`` (the original,
        argument-less behaviour).
    :return: list of lines as produced by ``readlines()``.
    """
    if path is None:
        path = config.dict
    # ``with`` closes the handle; the previous version leaked it.
    with open(path, "r", encoding='UTF-8') as f_open:
        return f_open.readlines()

猜你喜欢

转载自blog.csdn.net/qq_20780183/article/details/85599281