以mnist为例:
下载数据集mnist,存放路径:/home/xxx/Downloads/Mnist/,有四个压缩文件,train-images, train-labels, test-images, test-labels。
tensorflow:
导入:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/home/xxx/Downloads/Mnist/",one_hot = True)
打印看一下:
tpytorch:
导入:
导入之前先修改mnist.py脚本,如图:
然后导入
from torchvision import datasets
打印看一下:
总结:
包含了60000张的训练图像和10000张的测试图像,每张像素为:784=28*28!
可以看出每个像素是8Byte,即0~255