MNIST 手写数字数据集识别(PyTorch)

Mnist数据集官网:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

MNIST 数据集一共有 7 万张图片,其中 6 万张是训练集,1 万张是测试集。每张图片都是一幅 28 × 28 像素的手写数字(0–9)灰度图,采用黑底白字的形式:归一化之后,黑底用 0 表示,白字用 0–1 之间的浮点数表示,数值越接近 1,颜色越白。

训练数据集:train-images-idx3-ubyte.gz (9.45 MB,包含60,000个样本)。
训练数据集标签:train-labels-idx1-ubyte.gz(28.2 KB,包含60,000个标签)。
测试数据集:t10k-images-idx3-ubyte.gz(1.57 MB ,包含10,000个样本)。
测试数据集标签:t10k-labels-idx1-ubyte.gz(4.43 KB,包含10,000个样本的标签)。

训练数据为 28x28像素,单通道的图片。

标签为:0-9 的整数(即图片对应的数字类别)

import os
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import torch.nn.functional as F
import numpy as np

# Training hyperparameters.
learning_rate = 1e-4
keep_prob_rate = 0.7  # intended dropout keep probability (nn.Dropout itself expects the drop probability)
max_epoch = 3
BATCH_SIZE = 50

# Download MNIST only when the local ./mnist/ directory is missing or empty.
DOWNLOAD_MNIST = not os.path.exists('./mnist/') or not os.listdir('./mnist/')

train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)
# Each batch yielded by the loader holds BATCH_SIZE images and their labels.
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)

# Evaluation subset: the first 500 test images and labels.
# `data` / `targets` replace the deprecated `test_data` / `test_labels` accessors, and the
# removed `Variable(..., volatile=True)` wrapper is dropped. Slicing before the float
# conversion avoids converting the whole test set just to keep 500 samples.
# unsqueeze(1) inserts the channel axis; dividing by 255 rescales the raw pixel values the
# same way ToTensor() does for the training images.
test_x = test_data.data[:500].unsqueeze(1).float() / 255.
test_y = test_data.targets[:500].numpy()  # 500 integer labels in 0-9

class CNN(nn.Module):
    """Two-stage convolutional classifier producing 10 class scores per image.

    Layout: conv(7x7, 32) -> ReLU -> max-pool(2) -> conv(5x5, 64) -> ReLU ->
    max-pool(2) -> linear(7*7*64 -> 1024) -> ReLU -> dropout -> linear(1024 -> 10).
    The first linear layer fixes the expected input to single-channel 28x28 images.

    Args:
        keep_prob: Probability of *keeping* a hidden unit in the dropout layer.
            When omitted, the module-level ``keep_prob_rate`` is used.
    """

    def __init__(self, keep_prob=None):
        super(CNN, self).__init__()
        if keep_prob is None:
            keep_prob = keep_prob_rate
        self.conv1 = nn.Sequential(
            # 7x7 kernel, stride 1, padding 3 ("same" padding):
            # O = (I - K + 2P) / S + 1 = (28 - 7 + 6) / 1 + 1 = 28
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=7,
                stride=1,
                padding=3,
            ),                # [N, 1, 28, 28] -> [N, 32, 28, 28]
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> [N, 32, 14, 14]
        )
        self.conv2 = nn.Sequential(
            # 5x5 kernel, stride 1, padding 2 ("same" padding): the 14x14 map is
            # padded to 18x18 and the convolution brings it back to 14x14.
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=5,
                stride=1,
                padding=2,
            ),                # -> [N, 64, 14, 14]
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> [N, 64, 7, 7]
        )
        self.out1 = nn.Linear(7 * 7 * 64, 1024, bias=True)  # first fully connected layer
        # nn.Dropout takes the probability of ZEROING a unit, i.e. 1 - keep probability.
        self.dropout = nn.Dropout(1.0 - keep_prob)
        self.out2 = nn.Linear(1024, 10, bias=True)

    def forward(self, x):
        """Return raw class scores (logits) of shape [N, 10] for a batch ``x``.

        No softmax is applied here: ``nn.CrossEntropyLoss`` applies log-softmax
        internally, so feeding it probabilities would squash the gradients.
        The arg-max over the logits equals the arg-max over the probabilities;
        call ``F.softmax(logits, dim=1)`` when probabilities are needed.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten per sample; keeping the batch axis explicit means a wrongly sized
        # input fails in the linear layer instead of silently changing the batch size.
        x = x.view(x.size(0), -1)
        hidden = F.relu(self.out1(x))
        hidden = self.dropout(hidden)
        return self.out2(hidden)


def test(cnn):
    """Return the accuracy of ``cnn`` on the module-level ``test_x`` / ``test_y`` subset.

    The model is switched to eval mode (so dropout is disabled) and run without
    gradient tracking; its previous train/eval mode is restored afterwards so a
    surrounding training loop is not affected.

    Side effect: the predicted labels are stored in the module-level
    ``prediction`` variable as a NumPy array.

    Args:
        cnn: Model whose output has one row per sample and one column per class.

    Returns:
        Fraction of samples in ``test_y`` that were predicted correctly.
    """
    global prediction
    was_training = cnn.training
    cnn.eval()
    try:
        with torch.no_grad():
            y_pre = cnn(test_x)
    finally:
        cnn.train(was_training)
    # torch.max along dim 1 returns (row maxima, their indices); the index is the class.
    _, pre_index = torch.max(y_pre, 1)
    prediction = pre_index.view(-1).cpu().numpy()  # tensor -> ndarray
    # Average over the actual number of labels instead of a hard-coded 500.
    return float(np.mean(prediction == test_y))

def train(cnn):
    """Optimise ``cnn`` on ``train_loader`` with Adam and cross-entropy loss.

    Runs ``max_epoch`` passes over the loader and prints the accuracy reported
    by ``test`` every 20th step of each epoch (step 0 excluded).
    """
    criterion = nn.CrossEntropyLoss()
    adam = torch.optim.Adam(cnn.parameters(), lr=learning_rate)
    for _ in range(max_epoch):
        for batch_idx, (images, labels) in enumerate(train_loader):
            batch_loss = criterion(cnn(images), labels)
            # Clear stale gradients, back-propagate, then update the weights.
            adam.zero_grad()
            batch_loss.backward()
            adam.step()

            if batch_idx == 0 or batch_idx % 20:
                continue
            print("=" * 10, batch_idx, "=" * 5, "=" * 5, "test accuracy is ", test(cnn), "=" * 10)
if __name__ == "__main__":
    # Script entry point: build a fresh network and run the training loop on it.
    cnn = CNN()
    train(cnn)

输出:

猜你喜欢

转载自blog.csdn.net/qq_42018521/article/details/130366968