深度学习模型评估指标 Accuracy,Precision,Recall,ROC曲线
以二分类为例,进行说明:
注:
- 判别是否为正例只需要设一个概率阈值T,预测概率大于阈值T的为正类,小于阈值T的为负类,默认就是0.5。
- 如果减小阈值T,更多的样本会被识别为正类,这样可以提高正类的召回率,但同时也会带来更多的负类被错分为正类;
- 如果增加阈值T,则正类的召回率降低,精确度(Precision)通常会增加。如果是多类,比如ImageNet1000分类比赛中的1000类,预测类别就是预测概率最大的那一类。
常用的几种评估指标:
1. 准确度: Accuracy = (TP + TN) / (TP + FN + FP + TN)
注:Top_1 Accuracy和Top_5 Accuracy,Top_1 Accuracy就是计算的Accuracy。而Top_5 Accuracy是给出概率最大的5个预测类别,只要包含了真实的类别,则判定预测正确。
2. 精确度:Precision = TP / (TP + FP)
3. 召回率:Recall = TP / (TP + FN)
多分类情况
4. 混淆矩阵
若想知道各类别之间相互误分的情况,查看是否有特定的类别之间相互混淆,就可以用混淆矩阵画出分类的详细预测结果。对于包含多个类别的任务,混淆矩阵能很清晰地反映出各类别之间的错分概率,如下图:
注: 横坐标表示预测分类,纵坐标表示标签分类,其中(i,j)表示第i类目标被分为第j类的概率(若矩阵中保存的是样本计数,需先按行归一化,即每行除以该行之和,才能得到概率),对角线的值越大越好。
代码实现:
实际图片的标签值labels, 预测的分类值predicted,二者转换为具体的标签值(而不是onehot值),每个矩阵的大小为:[number_total_pictures, 1], 然后将二者按列拼接为[number_total_pictures, 2], 得到的这个矩阵的每一行都代表一个(i,j)值(i为真实类别,j为预测类别),然后进行统计即可。代码如下:
def Confusion_mxtrix(labels, predicted, num_classes):
    """Build a confusion matrix of raw sample counts.

    Args:
        labels: ground-truth class indices (not one-hot), one per sample;
            any shape with one index per element, e.g. [N] or [N, 1].
        predicted: predicted class indices, same number of elements as
            ``labels``.
        num_classes: number of classes; every index must lie in
            ``[0, num_classes)``.

    Returns:
        A float32 CPU tensor of shape [num_classes, num_classes] whose entry
        (i, j) is the number of samples with true class i predicted as j.

    Raises:
        ValueError: if ``labels`` and ``predicted`` differ in element count.
        IndexError: if any index falls outside ``[0, num_classes)``.
    """
    labels = torch.as_tensor(labels).reshape(-1).long()
    predicted = torch.as_tensor(predicted).reshape(-1).long()
    if labels.numel() != predicted.numel():
        raise ValueError(
            "labels and predicted must hold the same number of samples, "
            "got %d and %d" % (labels.numel(), predicted.numel())
        )
    if labels.numel() > 0:
        lowest = min(int(labels.min()), int(predicted.min()))
        highest = max(int(labels.max()), int(predicted.max()))
        # Without this check an out-of-range index would silently land in a
        # neighbouring cell of the flattened matrix below.
        if lowest < 0 or highest >= num_classes:
            raise IndexError(
                "class indices must be in [0, %d), got range [%d, %d]"
                % (num_classes, lowest, highest)
            )

    # Encode each (true, predicted) pair as one flat cell index and count all
    # pairs in a single vectorized pass instead of a Python loop per sample.
    flat_index = labels * num_classes + predicted
    counts = torch.bincount(flat_index, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes).to(torch.float32).cpu()
def plot_confusion_matrix(cm, savename, title='Confusion Matrix', classes=None):
    """Draw a confusion matrix as an annotated heat map, save it and show it.

    Args:
        cm: square matrix indexable as ``cm[row][col]``; rows are actual
            labels and columns are predicted labels.
        savename: path the figure is written to, in PNG format.
        title: title drawn above the plot.
        classes: sequence of class names used as tick labels. When omitted,
            the CIFAR-10 names are used for a 10x10 matrix and the numeric
            indices for any other size.

    Raises:
        ValueError: if ``classes`` is given and its length differs from the
            number of rows in ``cm``.
    """
    n_rows = len(cm)
    if classes is None:
        if n_rows == 10:
            classes = ('airplane', 'car', 'bird', 'cat', 'deer',
                       'dog', 'frog', 'horse', 'ship', 'truck')
        else:
            classes = tuple(str(index) for index in range(n_rows))
    elif len(classes) != n_rows:
        raise ValueError(
            "got %d class names for a matrix with %d rows"
            % (len(classes), n_rows)
        )
    n_classes = len(classes)

    plt.figure(figsize=(12, 8), dpi=100)
    # Kept from the original implementation; note that this changes numpy's
    # global print precision, not anything drawn on the figure.
    np.set_printoptions(precision=2)

    # Write the value of every non-negligible cell on top of the heat map.
    for row in range(n_classes):
        for col in range(n_classes):
            value = cm[row][col]
            if value > 0.001:
                plt.text(col, row, "%0.2f" % (value,), color='red',
                         fontsize=15, va='center', ha='center')

    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.binary)
    plt.title(title)
    plt.colorbar()

    tick_positions = np.arange(n_classes)
    plt.xticks(tick_positions, classes, rotation=90)
    plt.yticks(tick_positions, classes)
    plt.ylabel('Actual label')
    plt.xlabel('Predict label')

    # Minor ticks sit on the cell borders so the grid separates the cells.
    cell_borders = tick_positions + 0.5
    axes = plt.gca()
    axes.set_xticks(cell_borders, minor=True)
    axes.set_yticks(cell_borders, minor=True)
    axes.xaxis.set_ticks_position('none')
    axes.yaxis.set_ticks_position('none')
    plt.grid(True, which='minor', linestyle='-')
    # Leave room at the bottom for the rotated x tick labels.
    plt.gcf().subplots_adjust(bottom=0.15)

    # Save before showing: show() may discard the figure once it returns.
    plt.savefig(savename, format='png')
    plt.show()
根据混淆矩阵Cm计算多分类的Accuracy,Precision & Recall
混淆矩阵Cm的对角线的值就是每一类的TP值,而FN=sum_row(Cm)-TP, FP=sum_col(Cm)-TP, TN=sum(Cm)-TP-FN-FP
def Evaluate(Cmatrixs, classes=None):
    """Print and return per-class and macro-averaged precision and recall.

    Args:
        Cmatrixs: square confusion-matrix tensor; entry (i, j) counts the
            samples of true class i predicted as class j.
        classes: optional sequence of class names for the printed report.
            Defaults to the CIFAR-10 names; classes without a name are
            printed by index.

    Returns:
        ``(Prec, Rec)``: two tensors of length ``n_classes + 1`` holding
        fractions in [0, 1]. The first ``n_classes`` entries are per-class
        values and the last entry is their unweighted (macro) mean. A class
        with no predicted samples gets precision 0 and a class with no true
        samples gets recall 0, instead of NaN.
    """
    if classes is None:
        classes = ('airplane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')
    n_classes = Cmatrixs.size(0)
    Prec, Rec = torch.zeros(n_classes + 1), torch.zeros(n_classes + 1)
    actual_totals = torch.sum(Cmatrixs, dim=1)     # row sums: TP + FN
    predicted_totals = torch.sum(Cmatrixs, dim=0)  # column sums: TP + FP

    print("----------------------------------------")
    for i in range(n_classes):
        TP = float(Cmatrixs[i, i])
        n_actual = float(actual_totals[i])
        n_predicted = float(predicted_totals[i])
        # Guard 0/0 so one empty class cannot turn the macro mean into NaN.
        Prec[i] = TP / n_predicted if n_predicted > 0 else 0.0
        Rec[i] = TP / n_actual if n_actual > 0 else 0.0
        name = classes[i] if i < len(classes) else str(i)
        # The format ends in a percent sign, so scale the fractions by 100.
        print("%s" % name.ljust(10, " "),
              "Presion=%.3f%%, Recall=%.3f%%"
              % (100 * Prec[i].item(), 100 * Rec[i].item()))

    if n_classes > 0:
        Prec[-1] = torch.mean(Prec[0:-1])
        Rec[-1] = torch.mean(Rec[0:-1])
    # Report the macro averages stored in the last slot, not the last class.
    print("ALL".ljust(10, " "),
          "Presion=%.3f%%, Recall=%.3f%%"
          % (100 * Prec[-1].item(), 100 * Rec[-1].item()))
    print("----------------------------------------")
    return Prec, Rec
5. ROC曲线
- Receiver Operating Characteristic (ROC)曲线,评价一个分类器在不同阈值T下的表现情况。曲线横坐标为False Positive Rate(FPR), 纵坐标为True Positive Rate(TPR),描述True Positive和False Positive之间的平衡。所以,绘制ROC曲线的一个重要的事情是要自己定义一组阈值T。
- TPR = TP / (TP + FN):被分类器预测为正类的实际正实例占所有正实例的比例。
- FPR = FP / (FP + TN):被分类器预测为正类的实际负实例占所有负实例的比例。
- ROC曲线有4个关键的点:
- 点(0,0):FPR=TPR=0,分类器预测所有的样本都为负样本;
- 点(1,1):FPR=TPR=1,分类器预测所有的样本都为正样本;
- 点(0,1):FPR=0, TPR=1,此时FN=0且FP=0,所有的样本都正确分类;
- 点(1,0):FPR=1,TPR=0,此时TP=0且TN=0,最差分类器,避开了所有正确答案
def MyROC_i(outputs, labels, n=20):
    """Compute one-vs-rest ROC points for every class plus their average.

    For each class i and each of ``n`` evenly spaced thresholds in [0, 1], a
    sample counts as predicted-positive when ``outputs[:, i]`` is strictly
    greater than the threshold. Positives and negatives are counted from
    ``labels`` directly, so classes do not need to be balanced.

    Args:
        outputs: [num_samples, num_classes] tensor of per-class scores
            (compared against thresholds in [0, 1]).
        labels: class index of every sample, shaped [num_samples] or
            [num_samples, 1].
        n: number of thresholds, i.e. number of points per ROC curve.

    Returns:
        ``(TPR, FPR)``: two CPU tensors of shape [n, num_classes + 1]. Column
        i holds the curve of class i and the last column is the unweighted
        mean over classes. A class with no positive (negative) samples gets
        TPR (FPR) of 0 at every threshold.
    """
    n_total, n_classes = outputs.size()
    labels = labels.reshape(-1).to(outputs.device)
    thresholds = torch.linspace(0, 1, n, device=outputs.device)
    TPR = torch.zeros(n, n_classes + 1)
    FPR = torch.zeros(n, n_classes + 1)

    for i in range(n_classes):
        is_positive = labels == i
        n_positive = int(is_positive.sum())
        n_negative = n_total - n_positive
        # [n, num_samples]: row j marks the samples scored above threshold j.
        above = outputs[:, i].unsqueeze(0) > thresholds.unsqueeze(1)
        TP = (above & is_positive.unsqueeze(0)).sum(dim=1)
        FP = above.sum(dim=1) - TP
        # TP + FN is the true positive count and FP + TN the true negative
        # count of this class; skip the division when a side is empty.
        if n_positive > 0:
            TPR[:, i] = (TP.float() / n_positive).cpu()
        if n_negative > 0:
            FPR[:, i] = (FP.float() / n_negative).cpu()

    if n_classes > 0:
        TPR[:, -1] = torch.mean(TPR[:, 0:-1], dim=1)
        FPR[:, -1] = torch.mean(FPR[:, 0:-1], dim=1)
    return TPR, FPR
def Plot_ROC_i(TPR, FPR, args, cfg):
    """Plot one ROC curve per column of ``TPR``/``FPR`` and save the figure.

    The last column is drawn with a thicker line; the number of curves is
    taken from the input instead of being fixed at 11.

    Args:
        TPR: 2-D array-like of true positive rates, one column per curve.
        FPR: 2-D array-like of false positive rates, same shape as ``TPR``.
        args: object whose ``net`` attribute is used in the file name.
        cfg: config object; ``cfg.PARA.utils_paths.visual_path`` is used as
            the prefix of the output path.
    """
    n_curves = TPR.shape[1]
    # Draw on a dedicated figure so curves from earlier calls, or from an
    # unrelated open figure, do not end up in this plot.
    fig = plt.figure()
    for i in range(n_curves):
        width = 2 if i == n_curves - 1 else 1
        plt.plot(FPR[:, i], TPR[:, i], linewidth=width, label='classes_%d' % i)
    plt.legend()
    plt.title("ROC")
    plt.grid(True)
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.savefig(cfg.PARA.utils_paths.visual_path + args.net + '_ROC_i.png')
    # Release the figure; it is only needed for the saved image.
    plt.close(fig)
6. 主要的测试函数,输出所有的评估值
def test(net, epoch, test_loader, log, args, cfg):
    """Evaluate ``net`` on ``test_loader`` and report all metrics.

    Logs the accuracy and appends it to ``./cache/visual/<net>_test.txt``,
    then builds the confusion matrix, prints precision/recall, and computes
    and plots the per-class ROC curves.

    Args:
        net: model to evaluate; it is switched to eval mode.
        epoch: epoch number, used only in the reported accuracy line.
        test_loader: iterable of ``(images, labels)`` batches; labels are
            reduced with an argmax over dimension 1 (one-hot style labels).
        log: object whose ``logger.info`` is used for progress messages.
        args: object whose ``net`` attribute names the output file.
        cfg: config object; ``cfg.PARA.train.num_classes`` is read here and
            the object is forwarded to ``Plot_ROC_i``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    with torch.no_grad():
        # Collect per-batch tensors and concatenate once afterwards; calling
        # torch.cat inside the loop re-copies everything gathered so far.
        label_batches, predicted_batches, prob_batches = [], [], []
        correct = 0
        total = 0
        net.eval()
        for images, labels_onehot in test_loader:
            images = images.cuda()
            labels_onehot = labels_onehot.cuda()
            _, labels = torch.max(labels_onehot, 1)
            outputs = net(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            label_batches.append(labels)
            predicted_batches.append(predicted)
            prob_batches.append(F.softmax(outputs.data, dim=1))

        if total == 0:
            raise ValueError("test_loader yielded no samples to evaluate")

        labels_value = torch.cat(label_batches, 0)
        predicted_value = torch.cat(predicted_batches, 0)
        outputs_value = torch.cat(prob_batches, 0)

        accuracy = 100 * correct / total
        log.logger.info('epoch=%d,acc=%.5f%%' % (epoch, accuracy))
        # The context manager closes the file even if a later step raises.
        with open("./cache/visual/" + args.net + "_test.txt", "a") as f:
            f.write("epoch=%d,acc=%.5f%%" % (epoch, accuracy))
            f.write('\n')

        log.logger.info("==> Get Confusion_Matrixs <==")
        Cmatrixs = Confusion_mxtrix(labels_value, predicted_value,
                                    cfg.PARA.train.num_classes)
        log.logger.info("==> Precision & Recall <==")
        Evaluate(Cmatrixs)
        log.logger.info("==> Plot_ROC <==")
        TPR_i, FPR_i = MyROC_i(outputs_value, labels_value)
        Plot_ROC_i(TPR_i, FPR_i, args, cfg)
def main():
    """Load the epoch-100 checkpoint of the selected network and test it.

    Parses the command line, reads the config file, builds the logger, the
    test data loader and the network on the GPU, then restores the saved
    weights and runs ``test`` once.
    """
    args = parser()
    cfg = Config.fromfile(args.config)
    log_path = './cache/log/' + args.net + '_testlog.txt'
    log = Logger(log_path, level='info')

    log.logger.info('==> Preparing data <==')
    test_loader = dataLoad(cfg)

    log.logger.info('==> Loading model <==')
    net = get_network(args, cfg).cuda()

    log.logger.info("==> Waiting Test <==")
    checkpoint_dir = './cache/checkpoint/' + args.net + '/'
    first_epoch, last_epoch = 100, 100
    for epoch in range(first_epoch, last_epoch + 1):
        checkpoint_path = checkpoint_dir + str(epoch) + 'ckpt.pth'
        state = torch.load(checkpoint_path)
        net.load_state_dict(state['net'])
        test(net, epoch, test_loader, log, args, cfg)
    log.logger.info('*' * 25)


if __name__ == '__main__':
    main()