import os
import cv2
from sklearn.utils import shuffle
import numpy as np
from six.moves import cPickle as pickle
# Class sub-directory names under train_path; an image's label is the
# index of its class name in this list.
CLASS_NAME = ["dogs", "cats"]

# img_size: side length (pixels) every image is resized to.
# validation: fraction of the shuffled samples held out for validation.
img_size, validation = 64, 0.2

# Root directory that contains one sub-directory per class name.
train_path = 'D:/anicode/spyderworkspace/catanddog_class/training_data/'
def get_dataset(path, class_names=None, image_size=None):
    """Load every image under ``path/<class name>/`` as a scaled float32 array.

    Each class directory is listed in sorted order. Every file that OpenCV can
    decode is resized to ``image_size`` x ``image_size`` pixels, converted to
    float32 and multiplied by 1/255. Its label is the index of its class name
    in ``class_names``. Files that ``cv2.imread`` cannot decode (it returns
    ``None``, e.g. for ``Thumbs.db``) are reported and skipped instead of
    crashing ``cv2.resize``.

    Args:
        path: Root directory that contains one sub-directory per class.
        class_names: Class sub-directory names; defaults to ``CLASS_NAME``.
        image_size: Target side length in pixels; defaults to ``img_size``.

    Returns:
        ``(images, labels)`` as numpy arrays. ``images`` stacks the resized
        float32 images and ``labels`` holds the matching class indices.
    """
    # Resolve defaults at call time, as the original code did by reading the
    # module globals inside the body.
    if class_names is None:
        class_names = CLASS_NAME
    if image_size is None:
        image_size = img_size

    images = []
    labels = []
    for label, name in enumerate(class_names):
        class_dir = os.path.join(path, name)
        # os.listdir returns entries in arbitrary order; sort for a
        # reproducible load order.
        for file_name in sorted(os.listdir(class_dir)):
            file_path = os.path.join(class_dir, file_name)
            img = cv2.imread(file_path)
            if img is None:
                # Not a decodable image (stray system file, corrupt download, ...).
                print('Skipping unreadable file:', file_path)
                continue
            # Pass interpolation by keyword: positionally, the extra
            # arguments of cv2.resize are dst, fx, fy.
            img = cv2.resize(img, (image_size, image_size),
                             interpolation=cv2.INTER_LINEAR)
            images.append(img.astype(np.float32) * (1.0 / 255.0))
            labels.append(label)
    return np.array(images), np.array(labels)
def save_pickle(filename, random_state=None):
    """Load the training images, shuffle them, split off validation data and pickle it.

    After shuffling, the first ``int(validation * N)`` samples form the
    validation split and the remaining samples form the training split. The
    pickle holds a dict with the keys ``'train_images'``, ``'train_label'``,
    ``'validation_images'`` and ``'validation_label'``.

    Args:
        filename: Path of the pickle file to write.
        random_state: Optional seed forwarded to ``sklearn.utils.shuffle``.
            The default ``None`` keeps the original non-reproducible shuffle.
    """
    # Load all images and labels from the training directory.
    train_dataset, train_labels = get_dataset(train_path)
    # Shuffle images and labels together so each pair stays aligned.
    train_dataset, train_labels = shuffle(train_dataset, train_labels,
                                          random_state=random_state)

    validation_size = int(validation * train_dataset.shape[0])
    save = {
        'train_images': train_dataset[validation_size:],
        'train_label': train_labels[validation_size:],
        'validation_images': train_dataset[:validation_size],
        'validation_label': train_labels[:validation_size],
    }
    # The context manager closes the file even if pickle.dump raises.
    with open(filename, 'wb') as f:
        pickle.dump(save, f, pickle.HIGHEST_PROTOCOL)
def main():
    """Build the dataset pickle, writing ``animal.pickle`` relative to the working directory."""
    output_name = 'animal.pickle'
    save_pickle(output_name)


if __name__ == '__main__':
    main()
# Read the stored train/validation splits back and report their shapes.
# NOTE(review): this runs at import time as well as when the script is run.
# It reads an absolute path, while main() writes 'animal.pickle' relative to
# the current working directory, so the two only agree when the script is
# run from catanddog_class/. Confirm this is intended.
with open('D:/anicode/spyderworkspace/catanddog_class/animal.pickle', 'rb') as pk:
    pick = pickle.load(pk)

train_images, train_label = pick['train_images'], pick['train_label']
validation_images, validation_label = (pick['validation_images'],
                                       pick['validation_label'])
# Drop the dict now that its arrays are bound to module-level names.
del pick

# Number of samples in each split.
train_size, valid_size = train_images.shape[0], validation_images.shape[0]

print("Training Set:", train_images.shape, train_label.shape)
print("Validation Set:", validation_images.shape, validation_label.shape)
# Kaggle cats-vs-dogs dataset download: https://pan.baidu.com/s/10DcQNn_LybUDqffdSzzKEg  (password: sua2)