以下代码是跟着龙龙老师的视频敲的，其中有几行参考了 https://www.cnblogs.com/liualexsone/p/11355217.html 这位朋友的教程，写得很好，谢谢！
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from pytorch__lesson.pytorch_mnist.utils import plot_curve, plot_image, one_hot
import matplotlib.pyplot as plt

# Step 1: load the MNIST dataset.
batch_size = 512

# ToTensor() followed by Normalize(mean, std) on the single channel; the same
# constants are undone by plot_image() in utils when displaying samples.
# One Compose object is shared by both datasets (the transforms are stateless).
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=transform),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/', train=False, download=True,
                               transform=transform),
    batch_size=batch_size, shuffle=False)

# Uncomment to preview a few training samples:
# x, y = next(iter(train_loader))
# print(x.shape, y.shape, x.min(), x.max())
# plot_image(x, y, 'image sample')


# Step 2: build the network.
class Net(nn.Module):
    """Three-layer fully connected classifier: 784 -> 256 -> 64 -> 10.

    Takes flattened 28x28 images of shape (batch, 784) and returns raw
    (un-activated) scores of shape (batch, 10), one per digit class.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # No activation on the output layer: the scores are regressed
        # directly onto one-hot targets with MSELoss below.
        return self.fc3(x)


def evaluate(model, loader):
    """Return ``(correct, total)`` classification counts of *model* over *loader*.

    Runs without gradient tracking and in eval mode, then restores train mode.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            output = model(x.view(x.size(0), 28 * 28))
            correct += (output.argmax(dim=1) == y).sum().item()
            total += y.size(0)
    model.train()
    return correct, total


net = Net()
# NOTE(review): momentum=0.99 is unusually high (0.9 is the common choice);
# confirm this value is intended.
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.99)
# The loss module must be instantiated (declared) before it can be called.
losses_func = nn.MSELoss()

# Step 3: train, reporting training accuracy per epoch.
for epoch in range(5):
    correct = 0
    total = 0
    for batch_index, (x, y) in enumerate(train_loader):
        # Flatten each image into a 784-dim vector for fc1.
        x = x.view(x.size(0), 28 * 28)
        output = net(x)
        onehot_y = one_hot(y)
        loss = losses_func(output, onehot_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accuracy of the predictions made *before* this step's update.
        correct += (output.argmax(dim=1) == y).sum().item()
        total += y.size(0)
        # if batch_index % 10 == 0:
        #     print(epoch, batch_index, loss.item())
    print("Epoch {} training accuracy: {}%".format(epoch, 100 * correct / total))

# Step 4: evaluate on the held-out test set, which is never used for training.
test_correct, test_total = evaluate(net, test_loader)
print("The accuracy of total {} images: {}%".format(test_total, 100 * test_correct / test_total))
下面是 utils.py 的代码：
import torch
from matplotlib import pyplot as plt


def plot_curve(data):
    """Plot a sequence of scalar values against their step index.

    Args:
        data: Sequence of numbers; ``data[i]`` is drawn at x = i.
    """
    plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()


def plot_image(img, label, name):
    """Show the first six images of a batch in a 2x3 grid, titled with labels.

    Args:
        img: Image batch indexed as ``img[i][0]`` -> a 2-D image channel.
            Assumed to hold at least six images normalised with
            mean 0.1307 / std 0.3081; that normalisation is undone for display.
        label: Per-image labels; ``label[i].item()`` is shown in each title.
        name: Prefix for every subplot title.
    """
    plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    # Lay out once, after every subplot and title exists.
    plt.tight_layout()
    plt.show()


def one_hot(label, depth=10):
    """Encode integer class labels as one-hot rows.

    Args:
        label: Integer tensor (or sequence) of class indices in ``[0, depth)``;
            dimension 0 is the batch.
        depth: Number of classes, i.e. the width of each one-hot row.

    Returns:
        Float tensor of shape ``(len(label), depth)`` on the same device as
        ``label``, with 1 at ``[i, label[i]]`` and 0 elsewhere.
    """
    # as_tensor accepts any integer dtype (or a Python sequence) and avoids
    # the legacy torch.LongTensor constructor, which rejects non-int64 input.
    idx = torch.as_tensor(label, dtype=torch.long).view(-1, 1)
    # Allocate on the labels' device so GPU labels work with scatter_.
    out = torch.zeros(idx.size(0), depth, device=idx.device)
    out.scatter_(dim=1, index=idx, value=1)
    return out
独热码（one-hot）的实现方法我在上一篇博文中介绍过。