语义分割中的数据生成器dataloader(pytorch版)

数据集的基本结构

可以参考官方文档 web documantation。主要有三个类:Dataset, Sampler and DataLoader。

  • Dataset:
    代表数据集的抽象类;所有其他数据集都应该继承它。所有的子类都应该覆盖len(提供数据集的大小)和getitem(支持范围从0到len(self)的整形索引)。

  • Sampler:
    所有采样器的基准类;每个采样器子类必须提供iter方法,提供一种遍历数据集元素的索引的方法,以及一个返回迭代器长度的len方法。

  • DataLoader:
    组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。

简单的数据集类:

train_images_path = "./data/train_images"
train_labels_path = "./data/train_labels"

class RSDataset(Dataset):
    def __init__(self, input_root, mode="train", debug = False):
        super().__init__()
        self.input_root = input_root
        self.mode = mode
        if debug == False:
            self.input_ids = sorted(img for img in os.listdir(self.input_root))
        else:
            self.input_ids = sorted(img for img in os.listdir(self.input_root))[:500]
        
        self.mask_transform = transforms.Compose([
            transforms.Lambda(to_monochrome),
            transforms.Lambda(to_tensor),
        ])
            
        self.image_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.transform = DualCompose([
                RandomFlip(),
                RandomRotate90(),
                Rotate(),
                Shift(),
            ])
        
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        # at this point all transformations are applied and we expect to work with raw tensors
        imageName = os.path.join(self.input_root,self.input_ids[idx])
        image = np.array(cv2.imread(imageName), dtype=np.float32)
        mask = np.array(cv2.imread(imageName.replace("train_images", "train_labels")))/255
        h, w, c = image.shape
        mask1 = np.zeros((h, w), dtype=int)

        if self.mode == "train":
            image, mask  =  self.transform(image, mask)
            mask1 = mask[:,:,0]
            return self.image_transform(image), self.mask_transform(mask1)
        else:
            mask1 = mask[:,:,0]
            return self.image_transform(image), self.mask_transform(mask1)


###划分训练集和验证集
def build_loader(input_img_folder = "./data/train_images",
                 batch_size = 16,
                 num_workers = 4):
    # Get correct indices
    num_train = len(sorted(img for img in os.listdir(input_img_folder)))
    indices = list(range(num_train))
    seed(128381)
    indices = sample(indices, len(indices))
    split = int(np.floor(0.15 * num_train))

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    #set up datasets
    train_dataset = RSDataset(
        "./data/train_images",
        "./data/train_labels",
        mode = "train",
    )

    val_dataset = RSDataset(
        "./data/train_images",
        "./data/train_labels",
        mode="valid",
    )

    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=True
    )

    valid_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=True
    )

    return train_loader, valid_loader
发布了33 篇原创文章 · 获赞 3 · 访问量 5547

猜你喜欢

转载自blog.csdn.net/weixin_42990464/article/details/104197638