混淆矩阵
混淆矩阵可以用来观察模型预测结果,分析模型对各种类别预测能力。它由实际情况和预测情况两个维度构成。
举个例子,你有10张图片,其中有4只狗图片,6只猫图片,这是样本数据的实际情况。此时,你用模型来预测这些样本数据,预测结果为,4只狗图片中3只被预测为狗,1只被预测为猫,而6只猫图片中3只被预测为猫,3只被预测为狗。可以使用下表所示的混淆矩阵来表示模型的预测情况:
实际 \ 预测 | 狗-预测 | 猫-预测 |
---|---|---|
狗-实际 | 3 | 1 |
猫-实际 | 3 | 3 |
Python格式化输出混淆矩阵
想要实现的效果为:输入一个混淆矩阵和类别名称,能够打印输出整洁美观的内容。
在Python中,混淆矩阵可以用numpy矩阵来表示。在矩阵中,每一行都代表着实际情形,而每一列则是模型对数据的预测结果。类别名称作为表头使用,可以很直观地看出分类结果地混淆情况,究竟是哪些类别之间错分较多。类别名称可以用列表来表示。
# -*-coding: utf-8 -*-
import numpy as np
"""
Python格式化输出混淆矩阵
:param confusion_matrix: 混淆矩阵,一个numpy矩阵,元素均为整型
:param type_name: 类别名称,一个字符串列表,默认为None
:param placeholder_length: 占位符宽度,即每个数字占几位,用于对齐,默认为5
"""
def format_print_confusion_matrix(confusion_matrix, type_name=None, placeholder_length=5):
if type_name != None:
type_name.insert(0, 'T \ P') # 头部插入一个元素补齐
for tn in type_name:
fm = '%'+str(placeholder_length)+'s'
print(fm%tn,end='') # 不换行输出每一列表头
print('\n')
for i,cm in enumerate(confusion_matrix):
if type_name != None:
fm = '%'+str(placeholder_length)+'s'
print(fm%type_name[i+1],end='') # 不换行输出每一行表头
for c in cm:
fm = '%'+str(placeholder_length)+'d'
print(fm%c,end='') # 不换行输出每一行元素
print('\n')
if __name__ == '__main__':
confusion_matrix_example = np.array([[3,1],
[3,3]])
type_name_example = ['狗', '猫']
format_print_confusion_matrix(confusion_matrix_example, type_name_example,7)
"""
T \ P 狗 猫
狗 3 1
猫 3 3
"""