手写字识别——Tensorflow

一、环境

        本系列文章主要基于windows7,Anaconda。

        Anaconda是个很有用的工具,安装各种库文件都非常方便,除了网络卡顿导致安装失败,目前都没发现其他问题

二、写在前面

        本文主要基于TensorFlow中文社区的一系列文章进行学习和记录,对mnist数据集和神经网络相关原理进行介绍。

                        http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html

            

        上图显示了,神经网络训练的基本步骤,接下来将按照图中的几步来讲解。

三、mnist数据集

        MNIST是在机器学习领域中的一个经典问题。该问题解决的是把28x28像素的灰度手写数字图片识别为相应的数字,其中数字的范围从0到9.

        该数据集主要包含了以下四个部分:

文件 内容
train-images-idx3-ubyte.gz 训练集图片 - 55000 张 训练图片, 5000 张 验证图片
train-labels-idx1-ubyte.gz 训练集图片对应的数字标签
t10k-images-idx3-ubyte.gz 测试集图片 - 10000 张 图片
t10k-labels-idx1-ubyte.gz 测试集图片对应的数字标签

        可以通过以下的代码获取mnist数据集:

from tensorflow.examples.tutorials.mnist import input_data      #导入模块
mnist = input_data.read_data_sets("mnist_dd/", one_hot=True)    #下载数据集,读取数据

        获取数据集之后,我们就可以进行开始神经网络的构建了。以下

        在构建网络之前,我们先来看一下这几个压缩文件中到底有什么文件呢,解压完之后我们得到的是一个.idx3-ubyte文件,用notepad++打开,开到的是一系列的16进制数据。如下图所示(t10k-images.idx3-ubyte):

     

        前面的16byte(每4个byte)分别表示:                    

参数 十六进制 十进制
魔数 0x00000803 2051
图片数 0x00002710 10000
行像素点 0x0000001c 28
列像素点 0x0000001c 28

        后面的数据是按照每张图片28*28个像素点照顺序排列的。

        我们可以用下面的代码去读一下数据并显示

# -*- coding: utf-8 -*-
"""
Created on Wed Jun 13 14:38:18 2018
@author: ZCH
"""
import numpy as np   
import struct  
import matplotlib.pyplot as plt   
  
filename = 'MNIST_data/t10k-images.idx3-ubyte'  
binfile = open(filename,'rb')#以二进制方式打开  
buf = binfile.read() 

index = 0  
magic, numImages, numRows, numColums = struct.unpack_from('>IIII',buf,index)#读取4个32 int 大端存取 
print (magic,' ',numImages,' ',numRows,' ',numColums  )
index += struct.calcsize('>IIII')  
  
im = struct.unpack_from('>784B',buf,index)#每张图是28*28=784Byte,这里只显示第一张图  
index += struct.calcsize('>784B' )  

im = np.array(im)  
im = im.reshape(28,28)  
print( im ) 
  
fig = plt.figure()  
plt.imshow(im,cmap = 'binary')#黑白显示  
plt.show()  

        运行结果如下:

        白色的为0,黑色的为255

        代码讲解:      

    struct.unpack_from('>IIII',buf,index)

        调用struct模块对数据进行解析,'I'表示四个字节,四个则表示一次性读取16字节;‘B’表示一个字节,‘784B’表示读取784个字节;‘>’表示以大端格式存储数据,‘<’表示以小端格式存储数据;buf指文件内容;index指读取文件的起始位。

        我们也可以用以下的方式将图片数据读取出来,并且保存为图片。

# -*- coding: utf-8 -*-
"""
Created on Fri Jun 29 14:38:57 2018
@author: ZCH
"""
from PIL import Image
import struct  
  
filename = 'MNIST_data/t10k-images.idx3-ubyte'  
def readfile(file):
    fd = open(filename,'rb')#以二进制方式打开  
    buf = fd.read() 
    fd.close()

    index = 0  
    magic, numImages, numRows, numColums = struct.unpack_from('>IIII',buf,index)#读取4个32 int 大端存取 
    print (magic,' ',numImages,' ',numRows,' ',numColums  )
    index += struct.calcsize('>IIII')  
    for i in range(numImages):
        image = Image.new('L',(numColums,numRows))
        for x in range(numRows):
            for y in range(numColums):
                image.putpixel((y,x),int(struct.unpack_from('>B',buf,index)[0]))
                index += struct.calcsize('>B') 
        image.save('test/'+str(i)+'.png')
        
readfile(filename)

猜你喜欢

转载自blog.csdn.net/Smile_Smilling/article/details/80865603