使用pickle模块存储猫狗数据

import os
import cv2
from sklearn.utils import shuffle
import numpy as np
from six.moves import cPickle as pickle

# Class sub-directory names. An image's label is its index in this list,
# so "dogs" -> 0 and "cats" -> 1.
CLASS_NAME = ["dogs", "cats"]
# Side length in pixels; every image is resized to img_size x img_size.
img_size = 64
# Fraction of the shuffled data held out as the validation split.
validation = 0.2
# Root of the training data; expected to contain one sub-directory per
# entry of CLASS_NAME.
train_path = 'D:/anicode/spyderworkspace/catanddog_class/training_data/'

def get_dataset(path):
    """Load every image under ``path/<class name>/`` into memory.

    Each sub-directory named in CLASS_NAME is scanned. Every file is read with
    OpenCV, resized to ``img_size`` x ``img_size``, converted to float32 and
    scaled from [0, 255] to [0, 1]. An image's label is the index of its class
    in CLASS_NAME.

    Entries OpenCV cannot decode are skipped with a warning, instead of
    crashing later inside ``cv2.resize``. ``cv2.imread`` returns ``None`` for
    them; examples are stray non-image files and sub-directories.

    Args:
        path: Root data directory containing one sub-directory per class.
            A trailing separator is optional.

    Returns:
        A tuple ``(images, labels)`` of NumPy arrays. ``images`` stacks the
        processed images and ``labels`` holds the matching integer labels.
    """
    import warnings

    images = []
    labels = []
    for label, name in enumerate(CLASS_NAME):
        class_dir = os.path.join(path, name)
        # os.listdir returns entries in arbitrary order; sort them so the
        # load order (before shuffling) is reproducible.
        for file_name in sorted(os.listdir(class_dir)):
            file_path = os.path.join(class_dir, file_name)
            img = cv2.imread(file_path)
            if img is None:
                warnings.warn('Skipping unreadable image: %s' % file_path)
                continue
            # Pass interpolation by keyword: positionally, the 5th argument
            # of cv2.resize is fy, not interpolation.
            img = cv2.resize(img, (img_size, img_size),
                             interpolation=cv2.INTER_LINEAR)
            img = img.astype(np.float32)
            img = np.multiply(img, 1.0 / 255.0)
            images.append(img)
            labels.append(label)
    return np.array(images), np.array(labels)

def save_pickle(filename, random_state=None):
    """Load the training data, split it, and pickle the splits to ``filename``.

    Images from ``train_path`` are shuffled together with their labels. The
    first ``int(validation * N)`` samples become the validation split and the
    rest the training split.

    Args:
        filename: Destination path of the pickle file. An existing file is
            overwritten.
        random_state: Optional seed or RandomState forwarded to
            ``sklearn.utils.shuffle``. The default ``None`` gives a different
            shuffle on every run, as before.

    The pickle holds a dict with the keys 'train_images', 'train_label',
    'validation_images' and 'validation_label'.
    """
    # Fetch the data.
    train_dataset, train_labels = get_dataset(train_path)
    # Shuffle images and labels with the same permutation so pairs stay aligned.
    train_dataset, train_labels = shuffle(train_dataset, train_labels,
                                          random_state=random_state)

    # Carve the validation split off the front of the shuffled data.
    validation_size = int(validation * train_dataset.shape[0])
    save = {
        'train_images': train_dataset[validation_size:],
        'train_label': train_labels[validation_size:],
        'validation_images': train_dataset[:validation_size],
        'validation_label': train_labels[:validation_size],
    }
    # `with` guarantees the handle is closed even if pickling raises.
    with open(filename, 'wb') as f:
        pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
    
def main():
    """Build the dataset splits and write them to 'animal.pickle' (relative path)."""
    output_file = 'animal.pickle'
    save_pickle(filename=output_file)


if __name__ == '__main__':
    main()

# Reload the pickled splits and report their shapes.
# NOTE(review): this section is outside the `__main__` guard, so it also runs
# whenever the module is imported.
# NOTE(review): it reads an absolute path, whereas main() writes
# 'animal.pickle' relative to the current working directory. The two only
# coincide when the script is run from that directory; confirm this is
# intended.
# NOTE(review): pickle.load can execute arbitrary code, so only load files
# you produced yourself.
with open('D:/anicode/spyderworkspace/catanddog_class/animal.pickle', 'rb') as pk:
    pick = pickle.load(pk)

train_images = pick['train_images']
train_label = pick['train_label']
validation_images = pick['validation_images']
validation_label = pick['validation_label']
train_size, valid_size = train_images.shape[0], validation_images.shape[0]
del pick

print("Training Set:", train_images.shape, train_label.shape)
print("Validation Set:", validation_images.shape, validation_label.shape)

Kaggle 猫狗数据集 链接: https://pan.baidu.com/s/10DcQNn_LybUDqffdSzzKEg 提取码: sua2

猜你喜欢

转载自blog.csdn.net/sinat_38998284/article/details/80991584