MNIST手写数据集测试
(个人实践笔记,如有纰漏,烦请指出)
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
# Hyperparameters.
input_size = 28 * 28  # image size of MNIST data (28x28 pixels, flattened)
num_classes = 10      # digits 0-9
num_epochs = 10
batch_size = 100
# NOTE(review): not referenced by the SGD optimizer, which hard-codes its own
# step size (0.051) -- keep the two in sync or wire this constant through.
learning_rate = 1e-3

# MNIST dataset, read from a local copy (nothing is downloaded).
train_dataset = dsets.MNIST(root='../../data_sets/mnist',        # dataset root directory
                            train=True,                          # training split
                            transform=transforms.ToTensor(),     # convert each image to a tensor
                            download=False)                      # do not download from the network
test_dataset = dsets.MNIST(root='../../data_sets/mnist',         # dataset root directory
                           train=False,                          # test split
                           transform=transforms.ToTensor(),      # convert each image to a tensor
                           download=False)                       # do not download from the network

# Data loaders.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)  # reshuffle the training data every epoch
# Evaluation does not depend on sample order, so keep it fixed for reproducible runs.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)
class Net(torch.nn.Module):
    """Fully connected classifier with one ReLU hidden layer.

    Args:
        n_input: number of input features (the script passes 28*28 for
            flattened MNIST images).
        n_hideen: width of the hidden layer. The misspelled name is kept so
            existing keyword callers (``Net(..., n_hideen=...)``) still work.
        n_output: number of output scores (one per class).
    """

    def __init__(self, n_input, n_hideen, n_output):
        super().__init__()
        self.hidden = torch.nn.Linear(n_input, n_hideen)
        self.output = torch.nn.Linear(n_hideen, n_output)

    def forward(self, x):
        """Return raw, unnormalized class scores for input ``x``.

        The last dimension of ``x`` must equal ``n_input``. No softmax is
        applied: the script feeds these scores to ``CrossEntropyLoss``,
        which expects logits.
        """
        # torch.relu rather than F.relu: ``F`` (torch.nn.functional) is never
        # imported in this file, so the F.relu call raised NameError.
        x = torch.relu(self.hidden(x))
        return self.output(x)
# Define the network, the optimizer and the loss.
net = Net(n_input=input_size, n_hideen=100, n_output=num_classes)
# Step size tuned by hand (0.052 was noted as an alternative); the module-level
# ``learning_rate`` constant is not used here.
optimizer = torch.optim.SGD(net.parameters(), lr=0.0510)
loss_func = torch.nn.CrossEntropyLoss()

# Train for ``num_epochs`` full passes over the training set. A single pass
# leaves the model under-trained, and its accuracy then varies a lot between runs.
net.train()
for epoch in range(num_epochs):
    # enumerate yields (batch index, (images, labels)) from the loader.
    for i, (images, labels) in enumerate(train_loader):
        # Flatten each image into a vector of ``input_size`` features for the Linear layer.
        images = images.reshape(-1, input_size)
        out = net(images)
        loss = loss_func(out, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print('current loss = %.5f' % loss.item())

# Evaluate on the test set; no gradients are needed.
net.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.reshape(-1, input_size)
        output_t = net(images)
        predicted = output_t.argmax(dim=1)  # index of the highest score per sample
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))
运行结果:
current loss = 2.30939
current loss = 1.12731
current loss = 0.61168
current loss = 0.53309
current loss = 0.53291
current loss = 0.51907
Accuracy of the network on the 10000 test images: 90.12 %
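运行结果很不稳定,测试准确率在 60% 到 90% 之间波动,自己分析原因可能如下:
1. 未使用batch_size。(更正:batch_size = 100 其实已经传给了 DataLoader,这一条并不成立。更可能的原因是,产生上述结果的训练循环只遍历了一遍训练集,定义的 num_epochs = 10 并没有被用到。仅训练一个 epoch 时模型往往尚未收敛,结果对随机初始化和数据打乱顺序比较敏感。)
2. 全连接网络,仅一个隐藏层,未加卷积层等。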