import os
import glob
import random
import csv
import tensorflow as tf
def load_pokemon(root, mode='train'):
    """Build the class table for a dataset directory and return one split.

    Every sub-directory of ``root`` is treated as one class. Classes are
    numbered 0, 1, 2, ... in sorted directory-name order. The image paths
    and labels come from ``load_csv(root, 'images.csv', name2label)`` and
    are divided by position into 60% / 20% / 20% portions.

    Args:
        root: Dataset root directory holding one sub-directory per class.
        mode: ``'train'`` selects the first 60% of the entries, ``'val'``
            the next 20%. Any other value selects the last 20% (the test
            portion); unknown modes are not rejected.

    Returns:
        A tuple ``(images, labels, name2label)`` where ``images`` is the
        list of image paths in the selected split, ``labels`` the matching
        list of integer labels, and ``name2label`` the dict mapping each
        class directory name to its integer label.

    Raises:
        OSError: Propagated from ``os.listdir`` when ``root`` cannot be
            listed.
    """
    name2label = {}
    for name in sorted(os.listdir(root)):
        # Plain files in the root (such as the CSV index) are not classes.
        if not os.path.isdir(os.path.join(root, name)):
            continue
        # The current table size is the next free index: 0, 1, 2, ...
        name2label[name] = len(name2label)

    images, labels = load_csv(root, 'images.csv', name2label)

    # Both lists have the same length, so one pair of cut points serves
    # for images and labels alike.
    total = len(images)
    train_end = int(0.6 * total)
    val_end = int(0.8 * total)

    if mode == 'train':
        part = slice(None, train_end)
    elif mode == 'val':
        part = slice(train_end, val_end)
    else:
        part = slice(val_end, None)

    return images[part], labels[part], name2label
def load_csv(root, filename, name2label):
    """Return the image paths and integer labels listed in ``root/filename``.

    When the CSV file does not exist yet, it is created first: every
    ``*.png``, ``*.jpg`` and ``*.jpeg`` file found directly inside
    ``root/<name>`` for each class name in ``name2label`` is collected, the
    list is shuffled with ``random.shuffle``, and one ``path,label`` row is
    written per image. The file is then read back, so an existing CSV is
    reused unchanged on later calls.

    Args:
        root: Dataset root directory; also where the CSV file lives.
        filename: Name of the CSV file inside ``root``.
        name2label: Dict mapping class directory names to integer labels.

    Returns:
        A tuple ``(images, labels)``: the list of image paths and the list
        of integer labels, in the row order of the CSV file.

    Raises:
        ValueError: If a non-empty row does not have exactly two fields or
            its label is not an integer.
    """
    csv_path = os.path.join(root, filename)

    if not os.path.exists(csv_path):
        images = []
        for name in name2label:
            class_dir = os.path.join(root, name)
            for pattern in ('*.png', '*.jpg', '*.jpeg'):
                images += glob.glob(os.path.join(class_dir, pattern))
        print(len(images), images)
        random.shuffle(images)

        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            for img in images:
                # The class name is the directory directly containing the
                # image, e.g. "bulbasaur" in root/bulbasaur/0001.png.
                name = os.path.basename(os.path.dirname(img))
                writer.writerow([img, name2label[name]])
        print('written into csv file:', filename)

    images, labels = [], []
    # newline='' lets the csv module handle line endings itself.
    with open(csv_path, newline='') as f:
        for row in csv.reader(f):
            # Blank lines (e.g. from a file written without newline='')
            # come back as empty rows; skip them instead of crashing.
            if not row:
                continue
            img, label = row
            images.append(img)
            labels.append(int(label))
    return images, labels
# Load the training split once and print each result, with its length for
# the two lists, as a quick sanity check.
images, labels, table = load_pokemon('pokeman', mode='train')
for _title, _values in (('images:', images), ('labels:', labels)):
    print(_title, len(_values), _values)
print('table:', table)
def preprocess(x, y):
    """Load the image stored at path ``x`` and prepare an (image, label) pair.

    The file is read and decoded into 3 channels, resized to 244x244,
    randomly flipped left/right, cast to float32, divided by 255 and
    finally passed through ``normalize``.

    Args:
        x: Path of the image file to read.
        y: Label belonging to that image.

    Returns:
        A tuple ``(x, y)``: the processed image tensor and the label
        converted to a tensor.
    """
    x = tf.io.read_file(x)
    # NOTE(review): the file lists also contain .png files; per the
    # TensorFlow docs decode_jpeg accepts PNG input as well. Confirm for the
    # TensorFlow version in use.
    x = tf.image.decode_jpeg(x, channels=3)
    x = tf.image.resize(x, [244, 244])

    # Data augmentation.
    x = tf.image.random_flip_left_right(x)
    # NOTE(review): the crop size equals the resize size above, so this
    # crop cannot remove any pixels. If a real random crop is intended,
    # either the resize or the crop size needs to change (this would alter
    # the output shape, so it is left as is here).
    x = tf.image.random_crop(x, [244, 244, 3])

    # Convert to a float tensor and divide by 255 before standardizing.
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalize(x)
    y = tf.convert_to_tensor(y)
    return x, y
# Three-element mean / standard-deviation constants used as the default
# arguments of normalize() and denormalize().
# NOTE(review): these look like the widely used per-channel RGB ImageNet
# statistics; confirm they are appropriate for this dataset.
img_mean = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
img_std = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)
def normalize(x, mean=img_mean, std=img_std):
    """Standardize ``x`` by computing ``(x - mean) / std``.

    Args:
        x: Value to standardize.
        mean: Value subtracted from ``x``; defaults to ``img_mean``.
        std: Value the difference is divided by; defaults to ``img_std``.

    Returns:
        The standardized value ``(x - mean) / std``.
    """
    return (x - mean) / std
def denormalize(x, mean=img_mean, std=img_std):
    """Undo ``normalize`` by computing ``x * std + mean``.

    Args:
        x: Standardized value.
        mean: Value added back after scaling; defaults to ``img_mean``.
        std: Value ``x`` is multiplied by; defaults to ``img_std``.

    Returns:
        The de-standardized value ``x * std + mean``.
    """
    return x * std + mean
# Number of samples per batch for all three datasets.
batchsz = 128

# Training Dataset: shuffled with a buffer of 1000, preprocessed, batched.
images, labels, table = load_pokemon('pokeman', mode='train')
db_train = tf.data.Dataset.from_tensor_slices((images, labels))
db_train = db_train.shuffle(1000)
db_train = db_train.map(preprocess)
db_train = db_train.batch(batchsz)

# Validation Dataset: preprocessed and batched, no shuffling.
images2, labels2, table = load_pokemon('pokeman', mode='val')
db_val = (
    tf.data.Dataset.from_tensor_slices((images2, labels2))
    .map(preprocess)
    .batch(batchsz)
)

# Test Dataset: preprocessed and batched, no shuffling.
images3, labels3, table = load_pokemon('pokeman', mode='test')
db_test = (
    tf.data.Dataset.from_tensor_slices((images3, labels3))
    .map(preprocess)
    .batch(batchsz)
)
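# Creating datasets with TensorFlow
# Reposted from blog.csdn.net/weixin_40539952/article/details/103434368
# (Navigation residue from the source web page: "You may also like",
# "Today's recommendations", "Weekly ranking".)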