# 1. PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, accuracy_score
import numpy as np
import pandas as pd
from sklearn.preprocessing import label_binarize
from tensorboardX import SummaryWriter
from sklearn.metrics import roc_curve, auc,average_precision_score
class LeNet(nn.Module):
    """Small LeNet-style CNN mapping 1-channel images to 10 class scores.

    Spatial sizes for an example 28x28 input:
    28 -> conv 3x3, stride 1, pad 2 -> 30 -> max-pool 2 -> 15
       -> conv 5x5 -> 11 -> max-pool 2 -> 5,
    which yields the 16 * 5 * 5 flattened features expected by ``fc1``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and the layer order inside each Sequential form the
        # state_dict keys (e.g. "conv1.0.weight", "fc1.1.running_mean"), so
        # they are kept unchanged for checkpoint compatibility.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 6, 3, 1, 2),  # in=1, out=6, kernel=3, stride=1, padding=2
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.BatchNorm1d(120),
            nn.ReLU(),
        )
        # The final layer must output 10 values, one per digit class 0-9.
        self.fc2 = nn.Sequential(
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Return raw class scores (no softmax) of shape (batch, 10) for *x*."""
        for conv_stage in (self.conv1, self.conv2):
            x = conv_stage(x)
        # Flatten everything except the batch dimension for the dense layers.
        flat = x.view(x.shape[0], -1)
        return self.fc2(self.fc1(flat))
def get_w_value(array_value):
    """Return the sum of every row of *array_value* as a column vector.

    Args:
        array_value: an iterable of rows (e.g. a list of lists or a 2-D array).

    Returns:
        np.ndarray whose i-th entry is the one-element row ``[np.sum(row_i)]``;
        for a non-empty input this has shape (number_of_rows, 1).
    """
    return np.array([[np.sum(row)] for row in array_value])
def confusion_matrix_plot(model, test_loader, filename, num_classes=10):
    """Plot and save a row-normalised confusion matrix of *model* on *test_loader*.

    Each row is divided by its total, so cell (i, j) is the fraction of
    samples with true label i that were predicted as j. Rows for classes that
    never occur as a true label are shown as all zeros.

    Args:
        model: torch classifier; the argmax of its output over dim 1 is taken
            as the predicted label. It is evaluated in eval mode and its
            original train/eval mode is restored afterwards.
        test_loader: iterable of ``(images, labels)`` tensor batches.
        filename: path the figure is saved to.
        num_classes: number of classes; labels are assumed to be
            0 .. num_classes - 1. Defaults to 10.

    Returns:
        The raw (un-normalised) confusion matrix as an ndarray of shape
        (num_classes, num_classes).
    """
    real_labels, pred_labels = _collect_labels_and_predictions(model, test_loader)

    class_ids = list(range(num_classes))
    # Pass labels explicitly so the matrix is always num_classes x num_classes,
    # even when some class appears in neither the true nor predicted labels.
    confusion_numpy = confusion_matrix(real_labels, pred_labels, labels=class_ids)

    # Normalise each row by its total; empty rows stay 0 instead of 0/0 = NaN.
    row_totals = confusion_numpy.sum(axis=1, keepdims=True)
    matrix_data = np.divide(
        confusion_numpy,
        row_totals,
        out=np.zeros(confusion_numpy.shape, dtype=float),
        where=row_totals != 0,
    )

    tick_names = [str(i) for i in class_ids]
    conf_df = pd.DataFrame(matrix_data, index=tick_names, columns=tick_names)

    plt.rcParams['font.sans-serif'] = ['SimHei']  # SimHei: a Chinese (Hei-style) font
    plt.rcParams['font.size'] = 12  # default font size
    # Start a fresh figure so nothing left over from an earlier plot shows up.
    plt.figure()
    sns.heatmap(conf_df, annot=True, fmt=".2f", cmap="gist_gray_r", annot_kws={"size": 16})
    plt.xticks(rotation=0)
    plt.yticks(rotation=0)
    plt.title('Confusion Matrix')
    plt.ylabel('True Label', fontsize=14)
    plt.xlabel('Predict Label', fontsize=14)
    # Lay out only after the title and labels exist so they are not clipped.
    plt.tight_layout()
    plt.savefig(filename)
    plt.show()
    return confusion_numpy


def _collect_labels_and_predictions(model, loader):
    """Run *model* over *loader*; return (true_labels, predicted_labels) as lists."""
    # Send inputs to wherever the model's weights live; fall back to the
    # "CUDA if available" rule for a model without parameters.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    was_training = model.training
    # Eval mode makes BatchNorm use (and not update) its running statistics.
    model.eval()
    real_labels, pred_labels = [], []
    try:
        with torch.no_grad():
            for batch_images, batch_labels in loader:
                out = model(batch_images.to(device))
                prediction = torch.max(out, 1)[1]
                real_labels.extend(batch_labels.cpu().tolist())
                pred_labels.extend(prediction.cpu().tolist())
    finally:
        # Restore whichever mode the caller had the model in.
        model.train(was_training)
    return real_labels, pred_labels
def draw_loss(Loss_list, step=100, filename="./Train_loss.png"):
    """Plot the recorded training-loss values, save the figure, and show it.

    Point i of *Loss_list* is drawn at x = i * step. The default step of 100
    reproduces the original spacing. Presumably one loss value was recorded
    every 100 training iterations; confirm this against the training loop.
    NOTE(review): the axis label says "epoches" although x is scaled by step.

    Args:
        Loss_list: sequence of loss values in recording order.
        step: x-axis distance between consecutive points. Defaults to 100.
        filename: path the figure is saved to. Defaults to "./Train_loss.png".
    """
    x_values = [i * step for i in range(len(Loss_list))]
    # Clear the whole figure, not just the current axes, so extra axes left
    # by an earlier plot (e.g. a colorbar) do not end up in this image.
    plt.clf()
    plt.title('Train loss vs. epoches', fontsize=20)
    plt.plot(x_values, Loss_list, '.-')
    plt.xlabel('epoches', fontsize=20)
    plt.ylabel('Train loss', fontsize=20)
    plt.grid()
    plt.savefig(filename)
    plt.show()
def draw_acc(acc_list, step=100, filename="./Train _accuracy.png"):
    """Plot the recorded training-accuracy values, save the figure, and show it.

    Point i of *acc_list* is drawn at x = i * step. The default step of 100
    reproduces the original spacing. Presumably one accuracy value was
    recorded every 100 training iterations; confirm this against the
    training loop.
    NOTE(review): the default filename contains a space ("Train _accuracy.png").
    It is kept so existing output keeps landing at the same path; pass
    *filename* to override it.

    Args:
        acc_list: sequence of accuracy values in recording order.
        step: x-axis distance between consecutive points. Defaults to 100.
        filename: path the figure is saved to.
    """
    x_values = [i * step for i in range(len(acc_list))]
    # Clear the whole figure, not just the current axes, so extra axes left
    # by an earlier plot (e.g. a colorbar) do not end up in this image.
    plt.clf()
    plt.title('Train accuracy vs. epoch', fontsize=20)
    plt.plot(x_values, acc_list, '.-')
    plt.xlabel('epoch', fontsize=20)
    plt.ylabel('Train accuracy', fontsize=20)
    plt.grid()
    plt.savefig(filename)
    plt.show()
def roc_plot(model,test_loader,color=["blue","green","yellow","red","orange","gray","pink","black","purple","blue"]):
real_labels = []
pred_labels = []
for batch_images, batch_labels in iter(test_loader):
# print(batch_labels)
with torch.no_grad():
if torch.cuda.is_available():
batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()
out = model(batch_images)
prediction = torch.max(out, 1)[1]
#print (prediction,batch_labels)
batch_labels = batch_labels.to("cpu").tolist()
prediction = prediction.to("cpu").tolist()
real_labels = real_labels + batch_labels
pred_labels = pred_labels + prediction
#pred_labels = np.array(pred_labels).reshape(-1,1)
#real_labels = np.array(real_labels).reshape(-1,1)
classiers_number = 10
pred_labels = label_binarize(pred_label