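PyTorch: Data Loading and Processing
Preface
Based on the PyTorch DATA LOADING AND PROCESSING TUTORIAL.
This post focuses on dataset handling. It shows how to build each piece yourself by writing your own classes, and at the end introduces the ready-made utilities that ship with PyTorch and torchvision, which save a lot of work.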
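Imports
The code below relies on os, pandas (to read the annotation CSV), scikit-image (to read and resize images), NumPy, torch, and torchvision.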
Dataset class
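torch.utils.data.Dataset is an abstract class representing a dataset.
A custom dataset should inherit from Dataset and override the following methods:
- __len__, so that len(dataset) returns the size of the dataset
- __getitem__, so that dataset[i] returns the i-th sample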
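A minimal sketch of this protocol, using a hypothetical SquaresDataset that builds each sample on demand instead of reading files. Raising IndexError for out-of-range indices also lets a plain for loop over the dataset stop at the end, because Python falls back to calling __getitem__ with 0, 1, 2, ... when a class defines no __iter__.

```python
import torch
from torch.utils.data import Dataset


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is {'x': i, 'y': i ** 2}."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # len(dataset) -> number of samples
        return self.n

    def __getitem__(self, idx):
        # dataset[idx] -> one sample, computed lazily when requested
        if idx < 0 or idx >= self.n:
            raise IndexError(idx)  # this sketch only supports non-negative indices
        x = torch.tensor(float(idx))
        return {'x': x, 'y': x ** 2}


ds = SquaresDataset(5)
print(len(ds))             # 5
print(ds[3]['y'].item())   # 9.0
```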
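Example: the annotation CSV is read once in __init__, while each image is read only when __getitem__ asks for it, so the images never have to be held in memory all at once.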
```python
import os

import pandas as pd
from skimage import io
from torch.utils.data import Dataset


class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        # column 0 holds the image file name; the remaining columns are x, y pairs
        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
        # as_matrix() was removed in pandas 1.0; to_numpy() replaces it
        landmarks = self.landmarks_frame.iloc[idx, 1:].to_numpy()
        landmarks = landmarks.astype('float').reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample
```
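A small illustration of how __getitem__ turns one CSV row into an (N, 2) landmark array. It assumes the layout the code above implies (column 0 is the image file name, followed by x, y pairs for each landmark) and uses the standard csv module on a made-up two-row file instead of pandas, so records[idx][1:] here plays the role of iloc[idx, 1:].

```python
import csv
import io

import numpy as np

# Made-up annotation file in the assumed layout: image name, then x/y pairs
csv_text = """image_name,part_0_x,part_0_y,part_1_x,part_1_y
a.jpg,10,20,30,40
b.jpg,11,21,31,41
"""

rows = list(csv.reader(io.StringIO(csv_text)))
header, records = rows[0], rows[1:]

idx = 1
name = records[idx][0]
# flat [x0, y0, x1, y1] -> [[x0, y0], [x1, y1]], like reshape(-1, 2) in __getitem__
landmarks = np.asarray(records[idx][1:], dtype=float).reshape(-1, 2)
print(name)                # b.jpg
print(landmarks.tolist())  # [[11.0, 21.0], [31.0, 41.0]]
```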
Transforms
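Most neural networks expect input images of a fixed size, so we need some preprocessing code. For example:
- Rescale: scale the image
- RandomCrop: crop the image at a random position (a form of data augmentation)
- ToTensor: convert a numpy image into a torch image (this also swaps the axes)
Write each transform as a class that defines __call__, so that an instance can be called like a function. Parameters such as the output size are then passed once to __init__ instead of on every call.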
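A minimal sketch of the callable-class pattern, using a hypothetical ScaleIntensity transform that is not part of the tutorial: the parameter is fixed once in __init__, and the instance is then applied to each sample dict through __call__.

```python
import numpy as np


class ScaleIntensity(object):
    """Hypothetical transform: divide pixel values by a fixed maximum."""

    def __init__(self, max_value=255.0):
        # the parameter is stored once, when the transform is constructed
        self.max_value = max_value

    def __call__(self, sample):
        # the instance is then called like a function on each sample dict
        image, landmarks = sample['image'], sample['landmarks']
        image = image.astype(np.float64) / self.max_value
        return {'image': image, 'landmarks': landmarks}


scale = ScaleIntensity(255.0)
sample = {'image': np.full((2, 2, 3), 255, dtype=np.uint8),
          'landmarks': np.zeros((1, 2))}
out = scale(sample)
print(out['image'].max())  # 1.0
```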
Sample code:
```python
import numpy as np
import torch
from skimage import transform


class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # h and w are swapped for landmarks because for images,
        # x and y axes are axis 1 and 0 respectively
        landmarks = landmarks * [new_w / w, new_h / h]

        return {'image': img, 'landmarks': landmarks}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # randint's upper bound is exclusive: +1 keeps the last valid offset
        # and avoids a ValueError when the image is exactly the crop size
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        # shift landmarks into the cropped image's coordinate frame
        landmarks = landmarks - [left, top]

        return {'image': image, 'landmarks': landmarks}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, landmarks = sample['image'], sample['landmarks']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'landmarks': torch.from_numpy(landmarks)}
```
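A worked example of the size arithmetic in Rescale and RandomCrop, using only NumPy so it runs without scikit-image. rescale_size is a hypothetical helper that copies Rescale's size logic; the 500 x 300 image and the landmark coordinates are made-up numbers.

```python
import numpy as np


def rescale_size(h, w, output_size):
    """Hypothetical helper: the same size logic as Rescale.__call__."""
    if isinstance(output_size, int):
        # the shorter edge becomes output_size, keeping the aspect ratio
        if h > w:
            new_h, new_w = output_size * h / w, output_size
        else:
            new_h, new_w = output_size, output_size * w / h
    else:
        new_h, new_w = output_size
    return int(new_h), int(new_w)


# Rescale(256) on a made-up 500 x 300 (h x w) image
h, w = 500, 300
new_h, new_w = rescale_size(h, w, 256)
print(new_h, new_w)  # 426 256

# landmarks are (x, y) pairs: x scales with the width, y with the height
landmarks = np.array([[150.0, 250.0]])
scaled = landmarks * [new_w / w, new_h / h]
print(scaled)

# RandomCrop(224) on the rescaled image: offsets come from [0, new_h - 224]
rng = np.random.default_rng(0)
image = np.zeros((new_h, new_w, 3))
top = rng.integers(0, new_h - 224 + 1)   # upper bound is exclusive
left = rng.integers(0, new_w - 224 + 1)
crop = image[top:top + 224, left:left + 224]
print(crop.shape)  # (224, 224, 3)
```

Whatever offsets are drawn, the crop always has the requested size, because the largest possible offset leaves exactly 224 rows and columns.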
Compose transforms
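torchvision.transforms.Compose is a simple callable class that chains a list of transforms and applies them in order.
For example:
```python
from torchvision import transforms
composed = transforms.Compose([Rescale(256),
                               RandomCrop(224)])
```
Calling composed(sample) then applies Rescale(256) followed by RandomCrop(224). Note that these custom transforms take the whole sample dict ({'image': ..., 'landmarks': ...}), not a bare image.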
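A minimal sketch of the idea behind Compose (not torchvision's actual source): it stores a list of callables and feeds each one's output into the next. add_one and double are hypothetical stand-ins for transforms such as Rescale and RandomCrop.

```python
class Compose(object):
    """Sketch of the Compose idea: apply each transform in order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample):
        for t in self.transforms:
            sample = t(sample)  # output of one transform is input to the next
        return sample


def add_one(x):
    return x + 1


def double(x):
    return x * 2


pipeline = Compose([add_one, double])
print(pipeline(3))  # 8
```

Order matters: Compose([double, add_one])(3) gives 7 instead of 8, just as rescaling before cropping differs from cropping before rescaling.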
Iterating through the dataset
torch.utils.data.DataLoader
provides the following features:
- Batching the data
- Shuffling the data
- Loading the data in parallel using multiprocessing workers
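Usage (transformed_dataset stands for a FaceLandmarksDataset constructed with a composed transform such as Compose([Rescale(256), RandomCrop(224), ToTensor()]), so every sample has the same size and can be stacked into a batch):
```python
from torch.utils.data import DataLoader
dataloader = DataLoader(transformed_dataset, batch_size=4, shuffle=True, num_workers=4)
```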
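A self-contained sketch of what the DataLoader yields, using a hypothetical in-memory ToyDataset whose samples are dicts shaped like face-landmark samples after ToTensor (a 3 x 8 x 8 image and 5 x 2 landmarks; the sizes are arbitrary). The default collate function stacks each dict field along a new batch dimension, and with 10 samples and batch_size=4 the last batch holds the remaining 2. num_workers=0 keeps loading in the main process; with num_workers > 0, guard the entry point with if __name__ == '__main__': on platforms that start workers with spawn (e.g. Windows).

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical dataset of n dict samples with random tensors."""

    def __init__(self, n=10):
        g = torch.Generator().manual_seed(0)
        self.images = torch.rand(n, 3, 8, 8, generator=g)
        self.landmarks = torch.rand(n, 5, 2, generator=g)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {'image': self.images[idx], 'landmarks': self.landmarks[idx]}


def main():
    # num_workers=0: load in the main process; raise it for parallel loading
    loader = DataLoader(ToyDataset(), batch_size=4, shuffle=True, num_workers=0)
    shapes = []
    for batch in loader:
        # each dict field is stacked: image -> (B, 3, 8, 8), landmarks -> (B, 5, 2)
        shapes.append((tuple(batch['image'].shape),
                       tuple(batch['landmarks'].shape)))
    return shapes


if __name__ == '__main__':
    for s in main():
        print(s)
```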
torchvision
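So far we have seen how to write and use a dataset, transforms, and a dataloader.
The torchvision package provides a number of common datasets and transforms, so you may not need to write custom classes at all.
For example: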
```python
import torch
from torchvision import transforms, datasets

data_transform = transforms.Compose([
    # these are all built-in transforms from torchvision.transforms
    transforms.RandomResizedCrop(224),  # replaces the deprecated RandomSizedCrop
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # ImageNet per-channel mean and standard deviation
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
# ImageFolder: a generic torchvision.datasets class for images stored as root/<class_name>/<image>
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)
```
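ImageFolder expects one subdirectory per class under root and labels each image by the folder it sits in. The sketch below is a simplified, standard-library illustration of that mapping, not torchvision's implementation: it assumes class indices follow the sorted folder names (which is how ImageFolder builds its class_to_idx attribute) and, unlike ImageFolder, it only looks at files directly inside each class folder. index_image_folder, the extension list, and the ants/bees file names are hypothetical.

```python
import os
import tempfile

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')  # assumption: a subset of accepted extensions


def index_image_folder(root):
    """Hypothetical helper: (class_to_idx, [(path, class_index), ...]) for root/<class>/<file>."""
    # each subdirectory is one class; sorting the names fixes the label order
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for fname in sorted(os.listdir(class_dir)):
            if fname.lower().endswith(IMG_EXTENSIONS):
                samples.append((os.path.join(class_dir, fname), class_to_idx[name]))
    return class_to_idx, samples


with tempfile.TemporaryDirectory() as root:
    # build a tiny made-up layout with empty placeholder files
    for rel in ('bees/b1.jpg', 'ants/a1.jpg', 'ants/a2.jpg', 'ants/notes.txt'):
        path = os.path.join(root, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, 'wb').close()
    class_to_idx, samples = index_image_folder(root)
    print(class_to_idx)                     # {'ants': 0, 'bees': 1}
    print([label for _, label in samples])  # [0, 0, 1]
```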