下面是完整代码：
import os
import torch
import numpy as np
import scipy.misc as m
from PIL import Image
from torch.utils import data
from dataloaders.utils import recursive_glob, decode_segmap
from mypath import Path
class CityscapesSegmentation(data.Dataset):
    """Cityscapes fine-annotation segmentation dataset.

    Images are read from ``<root>/leftImg8bit/<split>/<city>/*.png`` and
    labels from ``<root>/gtFine/<split>/<city>/*_gtFine_labelIds.png``.
    Each item is ``{'image': RGB PIL image, 'label': PIL label image}``.
    Label pixels hold train ids ``0..n_classes-1`` or ``ignore_index``.
    The item is optionally passed through ``transform``.
    """

    def __init__(self, root=Path.db_root_dir('cityscapes'), split="train", transform=None):
        """
        Args:
            root: dataset root directory.
            split: sub-directory name under leftImg8bit/ and gtFine/ (e.g. "train").
            transform: optional callable applied to the sample dict.

        Raises:
            FileNotFoundError: if no ``.png`` images exist for ``split``.
        """
        self.root = root
        self.split = split
        self.transform = transform
        self.files = {}
        self.n_classes = 19

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(self.root, 'gtFine', self.split)

        self.files[split] = recursive_glob(rootdir=self.images_base, suffix='.png')

        # labelIds that are not trained on (16 entries). Labels are read as
        # uint8, so the -1 entry can never match. encode_segmap maps every id
        # that is not in valid_classes to ignore_index anyway.
        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        # labelIds that are trained on (19 entries), in train-id order.
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        # 20 names: index 0 is 'unlabelled'. NOTE(review): this looks offset by
        # one from the train ids (class_names[k + 1] presumably names train
        # id k). Confirm before using it for per-class reporting.
        self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence',
                            'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
                            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                            'motorcycle', 'bicycle']
        self.ignore_index = 255
        # labelId -> train id (0..18).
        self.class_map = dict(zip(self.valid_classes, range(self.n_classes)))

        if not self.files[split]:
            # FileNotFoundError subclasses Exception, so existing handlers still catch it.
            raise FileNotFoundError("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        img_path = self.files[self.split][index].rstrip()
        # The label lives in gtFine/<split>/<city>/. Its name swaps the trailing
        # 'leftImg8bit.png' (15 chars) for 'gtFine_labelIds.png'.
        # dirname() is used instead of split(os.sep) so mixed or doubled
        # separators still yield the city directory.
        city = os.path.basename(os.path.dirname(img_path))
        lbl_path = os.path.join(self.annotations_base,
                                city,
                                os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png')

        # Context managers close the underlying files immediately instead of
        # leaving them to the GC (important with many DataLoader workers).
        # convert() and np.array() both load the pixel data before the file closes.
        with Image.open(img_path) as raw_img:
            _img = raw_img.convert('RGB')
        with Image.open(lbl_path) as raw_lbl:
            _tmp = np.array(raw_lbl, dtype=np.uint8)

        _tmp = self.encode_segmap(_tmp)
        _target = Image.fromarray(_tmp)

        sample = {'image': _img, 'label': _target}

        if self.transform:  # paired augmentation and tensor conversion
            sample = self.transform(sample)
        return sample

    def encode_segmap(self, mask):
        """Convert Cityscapes labelIds to train ids in place.

        Args:
            mask: integer array of labelIds (modified in place).

        Returns:
            ``mask``. Pixels whose id is in ``valid_classes`` get their train
            id (0..n_classes-1); every other pixel gets ``ignore_index``.
        """
        # Build the result in a separate buffer so the mapping does not
        # depend on the order of the class lists, and so unknown ids cannot
        # leak through as if they were train ids.
        train_ids = np.full(mask.shape, self.ignore_index, dtype=np.uint8)
        for label_id, train_id in self.class_map.items():
            train_ids[mask == label_id] = train_id
        # Write back to keep the original in-place contract.
        mask[...] = train_ids
        return mask
if __name__ == '__main__':
    # Visual sanity check: draw augmented training samples next to their
    # decoded label maps for the first two batches.
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt  # to show image

    composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.RandomScale((0.5, 0.75)),
        tr.RandomCrop((512, 1024)),
        tr.RandomRotate(5),
        tr.ToTensor()])

    cityscapes_train = CityscapesSegmentation(split='train',
                                              transform=composed_transforms_tr)

    dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        # Convert each batch to numpy once rather than once per sample.
        imgs = sample['image'].numpy()  # N x 3 x H x W
        gts = sample['label'].numpy()   # N x 1 x H x W (ToTensor adds the channel axis)
        for img, gt in zip(imgs, gts):
            # 1 x H x W -> H x W integer label map for decode_segmap.
            tmp = np.squeeze(gt.astype(np.uint8), axis=0)
            segmap = decode_segmap(tmp, dataset='cityscapes')
            img_tmp = np.transpose(img, axes=[1, 2, 0]).astype(np.uint8)  # H x W x 3
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break

    plt.show(block=True)
其中使用的数据变换为：
composed_transforms_tr = transforms.Compose(
    [
        tr.RandomHorizontalFlip(),    # left-right flip of image and label with p=0.5
        tr.RandomScale((0.5, 0.75)),  # shared random rescale factor in [0.5, 0.75]
        tr.RandomCrop((512, 1024)),   # (h, w) crop at one random location
        tr.RandomRotate(5),           # shared random rotation in [-5, 5) degrees
        tr.ToTensor(),                # PIL pair -> float32 torch tensors
    ]
)
上述图像变换（即数据增强）的实现代码如下：
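前四个变换（RandomHorizontalFlip、RandomScale、RandomCrop、RandomRotate）的输入和输出都是 PIL 图像。注意：`__getitem__` 中经 `convert('RGB')` 和 `Image.fromarray` 得到的已经是 `PIL.Image.Image`，而不是 `PIL.PngImagePlugin.PngImageFile`。这四个变换不改变数据类型（uint8），也不改变结构（原图为 h x w x 3，标签图为 h x w），但会改变图像尺寸。原图使用双线性插值（BILINEAR），像素值会发生变化；标签图使用最近邻插值（NEAREST），不会产生新的类别值。不过，旋转以及 RandomCrop 的 padding 会引入新的填充像素。
直到第五个也就是最后一个变换 ToTensor，数据的类型和结构才发生变化：对原图，先转为 float32 的 numpy 数组，再把 h x w x 3 转置为 3 x h x w；对标签图，先转为 float32 并增加一个通道维得到 1 x h x w，再把值为 255（忽略类）的像素改为 0；最后两者都转换为 torch 的 float 张量。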
class RandomHorizontalFlip(object):
    """Mirror the image and its label left-right together, with probability 0.5."""

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        should_flip = random.random() < 0.5
        if not should_flip:
            return {'image': image, 'label': label}
        # Flip both with the same operation so pixels stay aligned with their labels.
        return {'image': image.transpose(Image.FLIP_LEFT_RIGHT),
                'label': label.transpose(Image.FLIP_LEFT_RIGHT)}
class RandomScale(object):
    """Resize image and label by one random factor shared by both.

    Args:
        limit: ``(low, high)`` bounds; the factor is ``random.uniform(low, high)``.
    """

    def __init__(self, limit):
        self.limit = limit

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        assert image.size == label.size
        factor = random.uniform(self.limit[0], self.limit[1])
        # PIL sizes are (width, height); int() truncates the scaled size.
        new_size = (int(factor * image.size[0]), int(factor * image.size[1]))
        resized_image = image.resize(new_size, Image.BILINEAR)
        # Nearest-neighbour keeps the label map free of interpolated values.
        resized_label = label.resize(new_size, Image.NEAREST)
        return {'image': resized_image, 'label': resized_label}
class RandomCrop(object):
    """Crop image and label at the same random location to a fixed size.

    Args:
        size: target ``(h, w)``; a single number gives a square crop.
        padding: border in pixels added to every side before cropping.
        mask_fill: label value for padded pixels. Defaults to 255, the
            ``ignore_index`` the Cityscapes dataset uses for void pixels.
            The old hard-coded 0 is a real train id, so padding was learned
            as that class.
    """

    def __init__(self, size, padding=0, mask_fill=255):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size  # (h, w)
        self.padding = padding
        self.mask_fill = mask_fill

    def __call__(self, sample):
        img, mask = sample['image'], sample['label']
        if self.padding > 0:
            img = ImageOps.expand(img, border=self.padding, fill=0)
            mask = ImageOps.expand(mask, border=self.padding, fill=self.mask_fill)

        assert img.size == mask.size
        w, h = img.size       # PIL reports (width, height)
        th, tw = self.size    # target (height, width)
        if w == tw and h == th:
            return {'image': img, 'label': mask}
        if w < tw or h < th:
            # Too small to crop: resize straight to the target size
            # (the aspect ratio is not preserved).
            img = img.resize((tw, th), Image.BILINEAR)
            mask = mask.resize((tw, th), Image.NEAREST)
            return {'image': img, 'label': mask}

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        box = (x1, y1, x1 + tw, y1 + th)
        return {'image': img.crop(box), 'label': mask.crop(box)}
class RandomRotate(object):
    """Rotate image and label by the same random angle in ``[-degree, degree)``.

    Args:
        degree: maximum absolute rotation angle, in degrees.
        mask_fill: label value for pixels uncovered by the rotation. Defaults
            to 255, the ``ignore_index`` the Cityscapes dataset uses for void
            pixels. The previous implicit fill of 0 is a real train id.
            Pass ``None`` to restore PIL's default (black) fill.
    """

    def __init__(self, degree, mask_fill=255):
        self.degree = degree
        self.mask_fill = mask_fill

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        rotate_degree = random.random() * 2 * self.degree - self.degree
        img = img.rotate(rotate_degree, Image.BILINEAR)
        # NEAREST keeps label values discrete. fillcolor needs Pillow >= 5.2.
        mask = mask.rotate(rotate_degree, Image.NEAREST, fillcolor=self.mask_fill)
        return {'image': img, 'label': mask}
class ToTensor(object):
    """Convert an image/label sample into float32 torch tensors.

    ``sample['image']`` (H x W x C, PIL image or array) becomes a C x H x W
    tensor. ``sample['label']`` (H x W) becomes a 1 x H x W tensor. Values
    are not normalized.

    Args:
        ignore_index: label value treated as "ignore" (default 255).
        remap_ignore_to: value written over ``ignore_index`` pixels. It
            defaults to 0 to keep the historical behavior. Note that with
            CityscapesSegmentation, 0 is a real train id, so ignored pixels
            then get trained as that class. Pass ``None`` to leave them as
            ``ignore_index``, e.g. when the loss uses ``ignore_index=255``.
    """

    def __init__(self, ignore_index=255, remap_ignore_to=0):
        self.ignore_index = ignore_index
        self.remap_ignore_to = remap_ignore_to

    def __call__(self, sample):
        # numpy image: H x W x C  ->  torch image: C x H x W
        img = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
        # Label map H x W -> 1 x H x W (explicit channel axis).
        mask = np.array(sample['label']).astype(np.float32)[np.newaxis]
        if self.remap_ignore_to is not None:
            mask[mask == self.ignore_index] = self.remap_ignore_to

        img = torch.from_numpy(img).float()
        mask = torch.from_numpy(mask).float()

        return {'image': img, 'label': mask}