MNIST数据集,堪称深度学习界的Hello World,每个图像的大小为\(28\times 28\),标签对应\(0-9\)中的一个数字(共10类)
import numpy as np
import torch
from torchvision import datasets
from torchvision.transforms import transforms
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
'''构建训练集测试集'''
num_workers = 0
batch_size = 32
# ToTensor converts PIL images to float tensors scaled into [0, 1].
transform = transforms.ToTensor()
train_data = datasets.MNIST(root='./data', train=True,
                            download=True, transform=transform)
test_data = datasets.MNIST(root='./data', train=False,
                           download=True, transform=transform)
# shuffle=True so each epoch visits the training samples in a new order.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                           shuffle=True, num_workers=num_workers)
# BUG FIX: the original passed train_data here, so the "test" metrics were
# computed on the 60k training images (the pasted output's per-class totals
# sum to 60000; MNIST's real test set has 10000). Use the held-out set.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                          num_workers=num_workers)
'''定义网络架构'''
class classifier(nn.Module):
    """Three-layer fully-connected network for 10-way MNIST classification.

    Architecture: 784 -> 512 -> 256 -> 10 with ReLU activations and
    dropout (p=0.2) after each hidden layer. ``forward`` returns raw
    logits suitable for ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 512)   # input images are 28x28 pixels
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)
        self.dropout = nn.Dropout(0.2)       # regularization against overfitting

    def forward(self, x):
        """Flatten the batch to (N, 784) and run it through the MLP."""
        flat = x.view(-1, 28 * 28)
        hidden1 = self.dropout(F.relu(self.fc1(flat)))
        hidden2 = self.dropout(F.relu(self.fc2(hidden1)))
        return self.fc3(hidden2)
'''训练'''
model = classifier()
# CrossEntropyLoss expects raw logits and integer class labels.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.002)
epochs = 32
model.train()  # make the mode explicit: dropout must be active while training
for e in range(epochs):
    training_loss = 0.0
    for images, labels in train_loader:
        optimizer.zero_grad()              # 梯度清零 (clear stale gradients)
        output = model(images)
        loss = criterion(output, labels)   # 计算损失
        loss.backward()                    # 反向传播
        optimizer.step()                   # 更新参数
        # loss.item() is the batch-MEAN loss; weighting by the batch size
        # makes the epoch average correct even if the last batch is smaller.
        training_loss += loss.item() * images.size(0)
    training_loss = training_loss / len(train_loader.dataset)
    print('Epoch:{}\t Training Loss:{:.6f}'.format(e + 1, training_loss))
'''计算在测试集上的表现'''
# BUG FIX: the original evaluated with the network still in training mode,
# so dropout randomly zeroed activations during testing. Switch to eval
# mode and disable gradient tracking for speed/memory.
model.eval()
test_loss = 0.0
class_correct = [0.0] * 10
class_total = [0.0] * 10
with torch.no_grad():
    for images, labels in test_loader:
        output = model(images)
        loss = criterion(output, labels)
        test_loss += loss.item() * images.size(0)
        # pred holds the index of the max logit per sample (the predicted class).
        _, pred = torch.max(output, 1)
        # 1 where the prediction matches the label, else 0. Keeping this as a
        # 1-D tensor (no np.squeeze) also works for a final batch of size 1.
        correct = pred.eq(labels.view_as(pred))
        # BUG FIX: iterate over the actual batch length, not the configured
        # batch_size -- the last batch may be smaller (10000 % 32 != 0),
        # which would raise IndexError in the original loop.
        for i in range(len(labels)):
            label = labels[i].item()
            class_correct[label] += correct[i].item()
            class_total[label] += 1
test_loss = test_loss / len(test_loader.dataset)
print('Test Loss:{:.6f}'.format(test_loss))
for i in range(10):
    if class_total[i] > 0:
        print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (
            str(i), 100 * class_correct[i] / class_total[i],
            np.sum(class_correct[i]), np.sum(class_total[i])))
    else:
        # BUG FIX: original wrote '$5s' (literal text, not a format spec) and
        # referenced an undefined name `classes`; format the digit itself.
        print('Test Accuracy of %5s: N/A (no training examples)' % str(i))
print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (
    100. * np.sum(class_correct) / np.sum(class_total),
    np.sum(class_correct), np.sum(class_total)))
下面是训练时的输出
Epoch:1 Training Loss:0.714305
Epoch:2 Training Loss:0.562318
Epoch:3 Training Loss:0.485287
Epoch:4 Training Loss:0.442476
Epoch:5 Training Loss:0.410877
Epoch:6 Training Loss:0.387815
Epoch:7 Training Loss:0.369312
Epoch:8 Training Loss:0.352989
Epoch:9 Training Loss:0.340237
Epoch:10 Training Loss:0.325559
Epoch:11 Training Loss:0.315037
Epoch:12 Training Loss:0.305838
Epoch:13 Training Loss:0.294698
Epoch:14 Training Loss:0.284443
Epoch:15 Training Loss:0.276668
Epoch:16 Training Loss:0.268770
Epoch:17 Training Loss:0.260248
Epoch:18 Training Loss:0.254518
Epoch:19 Training Loss:0.247402
Epoch:20 Training Loss:0.238849
Epoch:21 Training Loss:0.233701
Epoch:22 Training Loss:0.227878
Epoch:23 Training Loss:0.222812
Epoch:24 Training Loss:0.216647
Epoch:25 Training Loss:0.210697
Epoch:26 Training Loss:0.206513
Epoch:27 Training Loss:0.201959
Epoch:28 Training Loss:0.197595
Epoch:29 Training Loss:0.192502
Epoch:30 Training Loss:0.187927
Epoch:31 Training Loss:0.185286
Epoch:32 Training Loss:0.181954
以及测试集上的表现
Test Loss:0.177074
Test Accuracy of 0: 97% (5781/5923)
Test Accuracy of 1: 97% (6590/6742)
Test Accuracy of 2: 94% (5604/5958)
Test Accuracy of 3: 92% (5691/6131)
Test Accuracy of 4: 94% (5547/5842)
Test Accuracy of 5: 93% (5044/5421)
Test Accuracy of 6: 96% (5739/5918)
Test Accuracy of 7: 95% (5988/6265)
Test Accuracy of 8: 92% (5425/5851)
Test Accuracy of 9: 92% (5511/5949)
Test Accuracy (Overall): 94% (56920/60000)