"""Read and parse the following four MNIST IDX files:

'train-images-idx3-ubyte'
'train-labels-idx1-ubyte'
't10k-images-idx3-ubyte'
't10k-labels-idx1-ubyte'
"""
import numpy as np
import os
class Mnist(object):
    """MNIST dataset loaded fully into memory from raw IDX files.

    The training and test splits are concatenated, shuffled with a fixed
    seed and exposed as:

    * ``data``   -- float64 array of shape ``(N, 28, 28, 1)`` scaled to [0, 1]
    * ``data_y`` -- float64 one-hot array of shape ``(N, 10)``

    Args:
        data_dir: Directory containing the four IDX files. Defaults to
            ``./data/mnist`` when ``None``.
    """

    # The image files carry a 16-byte header and the label files an 8-byte
    # header before the uint8 payload starts.
    _IMAGE_HEADER_BYTES = 16
    _LABEL_HEADER_BYTES = 8
    _IMAGE_SIDE = 28
    _NUM_CLASSES = 10
    _SHUFFLE_SEED = 666

    def __init__(self, data_dir=None):
        self.dataname = "Mnist"
        self.dims = 28 * 28
        self.shape = [28, 28, 1]
        self.image_size = 28
        self.data, self.data_y = self.load_mnist(data_dir)

    @staticmethod
    def _read_idx(path, header_bytes):
        """Return ``(header, payload)`` uint8 arrays read from the file at *path*."""
        # Passing the path lets NumPy open the file in binary mode and close
        # it again, so no file handle is leaked.
        raw = np.fromfile(path, dtype=np.uint8)
        return raw[:header_bytes], raw[header_bytes:]

    def load_mnist(self, data_dir=None):
        """Load, merge, shuffle and normalise the MNIST train and test splits.

        Args:
            data_dir: Directory containing the four IDX files. Defaults to
                ``./data/mnist`` when ``None``.

        Returns:
            A tuple ``(images, labels)`` where ``images`` is a float64 array
            of shape ``(N, 28, 28, 1)`` with values in [0, 1] and ``labels``
            is a float64 one-hot array of shape ``(N, 10)``.

        Raises:
            ValueError: If an image payload is not a whole number of 28x28
                images, or the image and label counts differ.
            IndexError: If a label value is outside ``0..9``.
        """
        if data_dir is None:
            data_dir = os.path.join("./data", "mnist")

        side = self._IMAGE_SIDE
        # -1 infers the sample count from the payload instead of hard-coding
        # 60000 / 10000, so truncated or alternative IDX files load as well.
        image_shape = (-1, side, side, 1)

        header, pixels = self._read_idx(
            os.path.join(data_dir, 'train-images-idx3-ubyte'),
            self._IMAGE_HEADER_BYTES)
        trX = pixels.reshape(image_shape).astype(np.float64)
        print(header)

        header, trY = self._read_idx(
            os.path.join(data_dir, 'train-labels-idx1-ubyte'),
            self._LABEL_HEADER_BYTES)
        print(header)

        _, pixels = self._read_idx(
            os.path.join(data_dir, 't10k-images-idx3-ubyte'),
            self._IMAGE_HEADER_BYTES)
        teX = pixels.reshape(image_shape).astype(np.float64)

        _, teY = self._read_idx(
            os.path.join(data_dir, 't10k-labels-idx1-ubyte'),
            self._LABEL_HEADER_BYTES)

        X = np.concatenate((trX, teX), axis=0)
        y = np.concatenate((trY, teY), axis=0)
        if len(X) != len(y):
            raise ValueError(
                "image/label count mismatch: %d images vs %d labels"
                % (len(X), len(y)))

        # Re-seeding before each shuffle makes both arrays undergo the same
        # permutation, which keeps every image aligned with its label.
        np.random.seed(self._SHUFFLE_SEED)
        np.random.shuffle(X)
        np.random.seed(self._SHUFFLE_SEED)
        np.random.shuffle(y)

        # Vectorised one-hot encoding: one assignment instead of a Python loop.
        y_vec = np.zeros((len(y), self._NUM_CLASSES), dtype=np.float64)
        y_vec[np.arange(len(y)), y.astype(np.intp)] = 1.0

        return X / 255., y_vec
if __name__ == "__main__":
    # Script entry point: build the dataset once and bind the image and
    # one-hot label arrays to module-level names.
    mn_object = Mnist()
    x, y = mn_object.data, mn_object.data_y