数据集的基本结构
可以参考官方文档 web documentation。主要有三个类:Dataset, Sampler and DataLoader。
-
Dataset:
代表数据集的抽象类;所有其他数据集都应该继承它。所有的子类都应该覆盖 `__len__`(提供数据集的大小)和 `__getitem__`(支持范围从0到len(self)的整型索引)。 -
Sampler:
所有采样器的基准类;每个采样器子类必须提供 `__iter__` 方法,提供一种遍历数据集元素的索引的方法,以及一个返回迭代器长度的 `__len__` 方法。 -
DataLoader:
组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
简单的数据集类:
# Root directories for the training images and their label masks.
# NOTE(review): assumed relative to the project working directory — confirm.
train_images_path = "./data/train_images"
train_labels_path = "./data/train_labels"
class RSDataset(Dataset):
    """Remote-sensing segmentation dataset.

    Loads images from ``input_root``; the matching mask path is derived by
    replacing ``train_images`` with ``train_labels`` in the image path.
    In ``train`` mode a joint augmentation pipeline is applied to image and
    mask; in any other mode the raw pair is returned.

    Args:
        input_root: Directory containing the input images.
        mode: ``"train"`` enables augmentation; anything else skips it.
        debug: If True, restrict the dataset to the first 500 images.
    """

    def __init__(self, input_root, mode="train", debug=False):
        super().__init__()
        self.input_root = input_root
        self.mode = mode
        # Sort for a deterministic ordering across runs/platforms.
        self.input_ids = sorted(os.listdir(self.input_root))
        if debug:
            # Small subset for quick iteration.
            self.input_ids = self.input_ids[:500]
        # Mask pipeline: collapse to a single channel, then convert to tensor.
        self.mask_transform = transforms.Compose([
            transforms.Lambda(to_monochrome),
            transforms.Lambda(to_tensor),
        ])
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        # Joint (image, mask) augmentations, applied only in train mode.
        self.transform = DualCompose([
            RandomFlip(),
            RandomRotate90(),
            Rotate(),
            Shift(),
        ])

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensor pair at ``idx``.

        The mask is read from the sibling ``train_labels`` directory and
        scaled from [0, 255] to [0, 1]; only its first channel is kept.
        """
        image_name = os.path.join(self.input_root, self.input_ids[idx])
        image = np.array(cv2.imread(image_name), dtype=np.float32)
        mask = np.array(cv2.imread(image_name.replace("train_images", "train_labels"))) / 255
        if self.mode == "train":
            # Augment image and mask jointly so they stay spatially aligned.
            image, mask = self.transform(image, mask)
        # Masks are assumed single-class: keep only the first channel.
        mask1 = mask[:, :, 0]
        return self.image_transform(image), self.mask_transform(mask1)
# Split the dataset into training and validation sets.
def build_loader(input_img_folder="./data/train_images",
                 batch_size=16,
                 num_workers=4):
    """Build training and validation DataLoaders over a random 85/15 split.

    Args:
        input_img_folder: Directory of input images. The label path is
            derived inside ``RSDataset``, so only the image folder is needed.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per loader.

    Returns:
        Tuple ``(train_loader, valid_loader)``.
    """
    # Shuffle the image indices with a fixed seed so the split is reproducible.
    num_train = len(os.listdir(input_img_folder))
    indices = list(range(num_train))
    seed(128381)
    indices = sample(indices, len(indices))  # random permutation of all indices
    split = int(np.floor(0.15 * num_train))  # 15% held out for validation
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    # BUGFIX: RSDataset's signature is (input_root, mode, debug). The original
    # passed the labels path as a second positional argument, which bound to
    # ``mode`` and collided with the keyword ``mode=`` (TypeError). The dataset
    # derives the label path itself, so only the image folder is passed. Also
    # honor the ``input_img_folder`` parameter instead of a hardcoded path.
    train_dataset = RSDataset(input_img_folder, mode="train")
    val_dataset = RSDataset(input_img_folder, mode="valid")

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=True,
    )
    valid_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=True,
    )
    return train_loader, valid_loader