torch.utils.data
PyTorch reads training data through the torch.utils.data module. The module has 13 members; the two used most often are:
- class torch.utils.data.Dataset
- class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
1. class torch.utils.data.Dataset
An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override __len__, which provides the size of the dataset, and __getitem__, which supports integer indexing in the range from 0 to len(self) exclusive.
- __len__ returns the size of the dataset;
- __getitem__ returns the sample at a given index.
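As a minimal sketch of these two methods (the SquaresDataset name and its contents are made up for illustration), a subclass only needs:

```python
import torch
from torch.utils.data import Dataset

class SquaresDataset(Dataset):
    """Toy dataset: sample i is the pair (i, i*i) as tensors."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Size of the dataset
        return self.n

    def __getitem__(self, index):
        # Map an index in [0, len(self)) to one (input, target) sample
        x = torch.tensor(float(index))
        return x, x * x

ds = SquaresDataset(5)
print(len(ds))   # 5
print(ds[3])     # (tensor(3.), tensor(9.))
```

Any object with these two methods can be handed directly to a DataLoader.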
2. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
Data loader. Combines a dataset and a sampler, and provides single- or multi-process iterators over the dataset.
Parameters:
- dataset (Dataset) – dataset from which to load the data.
- batch_size (int, optional) – how many samples per batch to load (default: 1).
- shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
- sampler (Sampler, optional) – defines the strategy to draw samples from the dataset. If specified, shuffle must be False.
- batch_sampler (Sampler, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
- num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process (default: 0).
- collate_fn (callable, optional) – merges a list of samples to form a mini-batch.
- pin_memory (bool, optional) – if True, the data loader will copy tensors into CUDA pinned memory before returning them.
- drop_last (bool, optional) – set to True to drop the last incomplete batch if the dataset size is not divisible by the batch size. If False and the size of the dataset is not divisible by the batch size, the last batch will be smaller (default: False).
- timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative (default: 0).
- worker_init_fn (callable, optional) – if not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading (default: None).
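A small sketch of how batch_size and drop_last interact (the dataset here is an illustrative TensorDataset of 10 one-feature samples): with batch_size=3, ten samples split into batches of 3, 3, 3, 1, and drop_last=True discards the final incomplete batch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.arange(10).float().unsqueeze(1)  # 10 samples, 1 feature each
ds = TensorDataset(data)

# Keep the last, smaller batch (default behaviour)
sizes_keep = [b[0].shape[0] for b in DataLoader(ds, batch_size=3, drop_last=False)]
# Drop the last batch because 10 is not divisible by 3
sizes_drop = [b[0].shape[0] for b in DataLoader(ds, batch_size=3, drop_last=True)]

print(sizes_keep)  # [3, 3, 3, 1]
print(sizes_drop)  # [3, 3, 3]
```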
Example
First, define a new class ImageList that subclasses Dataset as required:
import torch.utils.data as data
from PIL import Image
import os
import os.path

def default_loader(path):
    # Load the image and convert it to single-channel grayscale
    img = Image.open(path).convert('L')
    return img

def default_list_reader(fileList):
    # Each line of the list file is expected to be "imgPath label"
    imgList = []
    with open(fileList, 'r') as file:
        for line in file.readlines():
            imgPath, label = line.strip().split(' ')
            imgList.append((imgPath, int(label)))
    return imgList

class ImageList(data.Dataset):
    def __init__(self, root, fileList, transform=None, list_reader=default_list_reader, loader=default_loader):
        self.root = root
        self.imgList = list_reader(fileList)
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        imgPath, target = self.imgList[index]
        img = self.loader(os.path.join(self.root, imgPath))
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.imgList)
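The list file that default_list_reader parses holds one "path label" pair per line. A quick self-contained check (the reader is restated here so the snippet runs on its own, and the file contents are made up):

```python
import os
import tempfile

def default_list_reader(fileList):
    # Parse one "imgPath label" pair per line, as in ImageList above
    imgList = []
    with open(fileList, 'r') as file:
        for line in file.readlines():
            imgPath, label = line.strip().split(' ')
            imgList.append((imgPath, int(label)))
    return imgList

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write('a/1.jpg 0\nb/2.jpg 1\n')
    name = f.name

pairs = default_list_reader(name)
print(pairs)  # [('a/1.jpg', 0), ('b/2.jpg', 1)]
os.remove(name)
```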
Then use ImageList to build your own Dataset and hand it to a DataLoader; ImageList(...) constructs the Dataset instance:
import torch
from torchvision import transforms

# load images; root_path, train_list, batch_size and workers are defined elsewhere
train_loader = torch.utils.data.DataLoader(
    ImageList(root=root_path, fileList=train_list,
              transform=transforms.Compose([
                  transforms.Grayscale(),
                  transforms.RandomCrop(128),
                  transforms.RandomHorizontalFlip(),
                  transforms.ToTensor(),
                  # transforms.Normalize([255.0], [0])
              ])),
    batch_size=batch_size, shuffle=True,
    num_workers=workers, pin_memory=True)
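The resulting loader is simply iterated in the training loop. A runnable sketch with a TensorDataset standing in for ImageList (so no image files are needed on disk; the shapes mimic the 128x128 grayscale crops above):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for ImageList: 8 fake single-channel 128x128 "images" with labels
imgs = torch.randn(8, 1, 128, 128)
labels = torch.randint(0, 2, (8,))
loader = DataLoader(TensorDataset(imgs, labels), batch_size=4, shuffle=True)

shapes = []
for batch_idx, (img, target) in enumerate(loader):
    # Each batch stacks batch_size samples along a new leading dimension
    shapes.append(tuple(img.shape))
print(shapes)  # [(4, 1, 128, 128), (4, 1, 128, 128)]
```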