代码:绘制多个子图
主要返回值
ax.flat 找了好久不知道什么意思,先记在这里吧
fig, ax = plt.subplots(4, 6) for i, axi in enumerate(ax.flat): axi.imshow(Xtest[i].reshape(62, 47), cmap='bone') axi.set(xticks=[], yticks=[]) axi.set_ylabel(faces.target_names[yfit[i]].split()[-1], color='black' if yfit[i] == ytest[i] else 'red') fig.suptitle('Predicted Names; Incorrect Labels in Red', size=14); plt.show()
最后要画出图必须plt.show,imshow只是对矩阵的处理,不作出图
二. 混淆矩阵热点图绘制:
from sklearn.metrics import confusion_matrix import seaborn as sns; mat = confusion_matrix(ytest, yfit) sns.heatmap(mat.T, square=True, annot=True, fmt='d', cbar=False, xticklabels=faces.target_names, yticklabels=faces.target_names) plt.ylabel('predicted label');
效果: