# coding: utf-8
import sklearn.metrics as sm
import pandas as pd
def _write_mismatched_sentence(out, words, gold_labels, pred_labels):
    """Write one sentence to ``out`` as word / gold label / predicted label rows.

    The sentence is terminated by an empty line so that consecutive sentences
    stay visually separated in the report.
    """
    for w, l, p in zip(words, gold_labels, pred_labels):
        # Single-character gold labels (e.g. 'O') get extra tabs so that the
        # predicted-label column roughly lines up with the longer labels.
        gap = '\t\t\t' if len(l) == 1 else '\t'
        out.write(w + '\t\t' + l + gap + p + '\n')
    out.write('\n')


def error_analysis(src_file, pred_file, tgt):
    """Dump every sentence that contains at least one wrong predicted label.

    Both input files are read line by line and are paired by line index.
    A non-blank line is split on whitespace; the first field of the source
    line is the token, and the second field of each file is the label.
    A blank (or whitespace-only) source line ends the current sentence.

    Sentences whose gold and predicted label sequences are identical are only
    counted; all other sentences are written to ``tgt``. The number of fully
    correct sentences is printed at the end.

    Args:
        src_file: Path to the gold-standard file (UTF-8).
        pred_file: Path to the prediction file (UTF-8), aligned line by line
            with ``src_file``.
        tgt: Path of the report file to create/overwrite (UTF-8).

    Raises:
        IndexError: If a token line has fewer than two fields, or the
            prediction file has fewer lines than the source file.
    """
    with open(src_file, 'r', encoding='utf-8') as f, \
            open(pred_file, 'r', encoding='utf-8') as g:
        src_lines = f.readlines()
        pred_lines = g.readlines()

    count = 0
    word_list, label_list, p_label_list = [], [], []

    with open(tgt, 'w', encoding='utf-8') as h:
        # The appended sentinel blank line guarantees that the final sentence
        # is evaluated even when the source file does not end with a blank line.
        for i, line in enumerate(src_lines + ['\n']):
            fields = line.strip('\ufeff\n').split()
            if fields:
                p_fields = pred_lines[i].strip('\ufeff\n').split()
                word_list.append(fields[0])
                label_list.append(fields[1])
                p_label_list.append(p_fields[1])
                continue

            # Sentence boundary. Skip it when no tokens were collected
            # (repeated blank lines), otherwise an empty sentence would be
            # counted as a correct one.
            if not word_list:
                continue

            if label_list == p_label_list:
                count += 1
            else:
                _write_mismatched_sentence(h, word_list, label_list, p_label_list)
            word_list, label_list, p_label_list = [], [], []

    print('正确的有%d句话.' % count)
def get_list(real, pred):
    """Collect the gold and predicted labels of every token as two flat lists.

    The two files are paired by line index. Each non-blank line is split on
    whitespace and its second field is taken as the label. Blank (or
    whitespace-only) lines in ``real`` are skipped and counted; that count is
    printed before returning.

    Args:
        real: Path to the gold-standard file (UTF-8).
        pred: Path to the prediction file (UTF-8), aligned line by line with
            ``real``.

    Returns:
        A ``(real_list, pred_list)`` tuple of equally long label lists.

    Raises:
        IndexError: If a token line has fewer than two fields, or the
            prediction file has fewer lines than the gold file.
    """
    # Context managers ensure both handles are closed (the files used to stay
    # open for the lifetime of the process).
    with open(real, 'r', encoding='utf-8') as f, \
            open(pred, 'r', encoding='utf-8') as g:
        src_lines = f.readlines()
        pred_lines = g.readlines()

    count = 0
    real_list = []
    pred_list = []
    for i, line in enumerate(src_lines):
        fields = line.strip('\ufeff\n').split()
        if not fields:
            # Separator line between sentences; nothing to collect.
            count += 1
            continue
        real_list.append(fields[1])
        pred_list.append(pred_lines[i].strip('\ufeff\n').split()[1])

    print('count=%d' % count)
    return real_list, pred_list
if __name__ == '__main__':
    # Input/output locations used by the script.
    pred_file = 'pred_test.txt'
    src_file = 'real_test.txt'
    tgt_file = 'error_analysis.txt'

    # Label inventory passed to the report, one row per group of tags.
    label_list = [
        'O',
        'B-PER.NAM', 'I-PER.NAM', 'E-PER.NAM', 'S-PER.NAM',
        'B-PER.NOM', 'I-PER.NOM', 'E-PER.NOM', 'S-PER.NOM',
        'B-GPE.NAM', 'I-GPE.NAM', 'E-GPE.NAM', 'S-GPE.NAM',
        'B-GPE.NOM', 'E-GPE.NOM',
        'B-LOC.NAM', 'I-LOC.NAM', 'E-LOC.NAM',
        'B-LOC.NOM', 'I-LOC.NOM', 'E-LOC.NOM', 'S-LOC.NOM',
        'B-ORG.NAM', 'I-ORG.NAM', 'E-ORG.NAM',
        'B-ORG.NOM', 'I-ORG.NOM', 'E-ORG.NOM',
    ]
    # Position -> label name, in the same order as label_list.
    label_dict = dict(enumerate(label_list))

    # Optional: dump the mislabeled sentences to tgt_file.
    # error_analysis(src_file, pred_file, tgt_file)

    real_list, pred_list = get_list(src_file, pred_file)

    # Optional: build the confusion matrix and save it as an Excel sheet.
    # confusion_matrix = sm.confusion_matrix(real_list, pred_list, labels=label_list)
    # df = pd.DataFrame(confusion_matrix)
    # df.rename(index=label_dict, columns=label_dict, inplace=True)
    # df.to_excel('confusion_matrix.xlsx', index=True)

    # Per-label precision / recall / F1 report.
    print(sm.classification_report(real_list, pred_list, labels=label_list))
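# Article title: computing a confusion matrix with sklearn
# Reposted from blog.csdn.net/tailonh/article/details/112252474
# (Blog page navigation residue: "You may also like", "Today's picks", "Weekly ranking")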