Preface
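detectron2 has specific requirements for the format of the datasets it reads. For example, VOC- and COCO-format datasets have to be converted into the dataset_dict form. Stated abstractly that probably doesn't mean much, so let's go through it step by step.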
1. The DatasetCatalog and MetadataCatalog objects
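The code analyzed here lives mainly under detectron2/data.
As shown in the figure above, the datasets directory holds coco.py, pascal_voc.py and so on. samplers generates the indices that the dataloader uses later, and transforms is the image augmentation part. The hardest file to follow is catalog.py. It defines the two classes named in the heading, and at first glance the source is baffling. Here it is: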
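```python
from collections import UserDict


class _DatasetCatalog(UserDict):
    def register(self, name, func):
        # Store the loader function under `name`; nothing is loaded yet.
        assert callable(func), "You must register a function with `DatasetCatalog.register`!"
        assert name not in self, "Dataset '{}' is already registered!".format(name)
        self[name] = func

    def get(self, name):
        try:
            f = self[name]
        except KeyError as e:
            raise KeyError(
                "Dataset '{}' is not registered! Available datasets are: {}".format(
                    name, ", ".join(list(self.keys()))
                )
            ) from e
        # Call the registered function and return its result (the dataset dicts).
        return f()


# The module exposes a single global instance that the rest of detectron2 uses.
DatasetCatalog = _DatasetCatalog()
```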
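This is quite abstract. Look at the register method first: what gets registered is a function, func. The get method then looks up that func by name and finally returns f(), which means the function is executed at that moment. OK, now let's find where register is called.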
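The sketch below isolates the lazy behavior with a minimal, self-contained mimic of the two methods above. It does not import detectron2. ToyDatasetCatalog and load_toy_dataset are hypothetical names made up for this sketch. The point is that register only stores the function, and the loader body runs on every get call.

```python
from collections import UserDict


class ToyDatasetCatalog(UserDict):
    """Simplified stand-in for detectron2's _DatasetCatalog (not the real class)."""

    def register(self, name, func):
        assert callable(func), "You must register a function!"
        assert name not in self, f"Dataset '{name}' is already registered!"
        self[name] = func  # store the function itself; nothing is loaded yet

    def get(self, name):
        try:
            f = self[name]
        except KeyError as e:
            raise KeyError(
                f"Dataset '{name}' is not registered! Available: {', '.join(self.keys())}"
            ) from e
        return f()  # the loader runs only now, on every get() call


calls = []


def load_toy_dataset():
    # Hypothetical loader: records each call and returns one fake record.
    calls.append("load")
    return [{"file_name": "img0.jpg", "image_id": 0}]


catalog = ToyDatasetCatalog()
catalog.register("toy_train", load_toy_dataset)
print(len(calls))  # 0: registering does not load anything
records = catalog.get("toy_train")
print(len(calls), records[0]["file_name"])  # 1 img0.jpg
```

Because get re-runs the loader each time, every extra get call pays the loading cost again. In this toy version, `calls` grows by one per get.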
The datasets directory contains a file named builtin.py. Here is its tail:
```python
import os

# Runs once, when this module (builtin.py) is imported.
# The register_all_* helpers are defined or imported earlier in the module.
if __name__.endswith(".builtin"):
    # Assume pre-defined datasets live in `./datasets`.
    _root = os.getenv("DETECTRON2_DATASETS", "datasets")
    register_all_coco(_root)
    register_all_lvis(_root)
    register_all_cityscapes(_root)
    register_all_cityscapes_panoptic(_root)
    register_all_pascal_voc(_root)
    register_all_ade20k(_root)
    register_all_mot(_root)
    register_all_crowdhuman(_root)
```
Take the VOC dataset as an example (pascal_voc.py):
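```python
from detectron2.data import DatasetCatalog, MetadataCatalog

# CLASS_NAMES and load_voc_instances are defined earlier in pascal_voc.py.
def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES):
    # Register a lazy loader: load_voc_instances only runs when DatasetCatalog.get(name) is called.
    DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names))
    # Attach metadata (class names, directory, year, split) under the same name.
    MetadataCatalog.get(name).set(
        thing_classes=list(class_names), dirname=dirname, year=year, split=split
    )
```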
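OK, now it is clear. DatasetCatalog registers a name together with a lambda that loads that dataset's annotations. At this point the lambda is only stored. The annotations are read later, when DatasetCatalog.get(name) calls it. In other words, through register, DatasetCatalog holds an annotation-loading function for most of the built-in datasets.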
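The lambda captures dirname, split and class_names from the arguments of register_pascal_voc. As far as I can tell from the source, register_all_pascal_voc calls register_pascal_voc once per split, so each call gets its own closure. The sketch below is plain Python, and fake_load_voc_instances and register_split are hypothetical stand-ins. It contrasts that pattern with creating the lambda directly inside a loop, where Python's late binding makes every entry see the loop variable's last value.

```python
def fake_load_voc_instances(dirname, split):
    # Hypothetical stand-in for load_voc_instances: just echoes its arguments.
    return f"{dirname}:{split}"


registry = {}


def register_split(name, dirname, split):
    # Same shape as register_pascal_voc: the lambda closes over this
    # call's own `dirname` and `split` parameters.
    registry[name] = lambda: fake_load_voc_instances(dirname, split)


for split in ["train", "val", "test"]:
    register_split(f"voc_2007_{split}", "VOC2007", split)

# Pitfall for comparison: a lambda created directly in the loop looks up
# `s` when it is *called*, so every entry sees the last value of the loop.
naive = {}
for s in ["train", "val", "test"]:
    naive[f"voc_2007_{s}"] = lambda: fake_load_voc_instances("VOC2007", s)

print(registry["voc_2007_train"]())  # VOC2007:train
print(naive["voc_2007_train"]())     # VOC2007:test
```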
All of these registrations run at import time when the program starts. The final DatasetCatalog therefore looks like this: each dataset name maps to one lambda function.
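MetadataCatalog stores per-dataset metadata, for example thing_classes, dirname, year and split in the VOC code above. I won't go into it here; honestly, I haven't studied it closely.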
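For reference, MetadataCatalog.get behaves differently from DatasetCatalog.get. An unknown name creates a fresh, empty metadata object instead of raising. The toy classes below, ToyMetadata and ToyMetadataCatalog (hypothetical names), sketch only that get-or-create behavior plus set(). The real Metadata class also performs extra checks, such as refusing to overwrite an existing attribute with a different value, which this sketch leaves out.

```python
from collections import UserDict


class ToyMetadata:
    """Simplified stand-in for detectron2's Metadata: a bag of attributes."""

    def __init__(self, name):
        self.name = name

    def set(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)
        return self  # allows chaining, e.g. get(name).set(...).set(...)


class ToyMetadataCatalog(UserDict):
    """Simplified stand-in for detectron2's _MetadataCatalog."""

    def get(self, name):
        # Unlike DatasetCatalog.get, an unknown name is not an error:
        # an empty metadata object is created on first access.
        if name not in self:
            self[name] = ToyMetadata(name)
        return self[name]


meta_catalog = ToyMetadataCatalog()
meta_catalog.get("voc_2007_train").set(thing_classes=["aeroplane", "bicycle"], year=2007)
meta = meta_catalog.get("voc_2007_train")  # same object as above
print(meta.thing_classes, meta.year)  # ['aeroplane', 'bicycle'] 2007
```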
2. Building the dataset
1. Reading the dataset through DatasetCatalog
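In my earlier post on how d2 builds the dataloader, the decorator took a side trip along the way to produce a dataset_dict. This post explains that step in detail.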
In data/build.py, the dataset information is loaded by the following code:
```python
def get_detection_dataset_dicts(
    dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
):
    # e.g. dataset_names = ("coco_2017_train",); one DatasetCatalog.get() call per name
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    ...
```
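Here dataset_name is 'coco_2017_train'. DatasetCatalog.get looks up the lambda registered under that name (see section 1) and calls it. After this line, dataset_dicts holds one list per dataset. Each list contains per-image dicts with the image path (file_name) and its annotations (bbox and so on).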
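The sketch below shows roughly what comes out of this step, under simplified assumptions. Two hypothetical loaders return records shaped like detectron2's standard dataset dicts (file_name, image_id, height, width, annotations). toy_get_detection_dataset_dicts, also hypothetical, calls one loader per name and then flattens the per-dataset lists into a single list, which is what the real function does next. The filter_empty branch is a simplified version of the real filter, which drops images that contain only crowd annotations.

```python
import itertools


# Hypothetical loaders returning records in detectron2's "standard dataset dict" layout.
def load_train_a():
    return [
        {
            "file_name": "a/0001.jpg",
            "image_id": 1,
            "height": 480,
            "width": 640,
            # bbox in absolute XYXY pixels; real records also carry a "bbox_mode" key
            "annotations": [{"bbox": [10.0, 20.0, 110.0, 220.0], "category_id": 0, "iscrowd": 0}],
        },
        {"file_name": "a/0002.jpg", "image_id": 2, "height": 480, "width": 640, "annotations": []},
    ]


def load_train_b():
    return [
        {
            "file_name": "b/0001.jpg",
            "image_id": 3,
            "height": 600,
            "width": 400,
            "annotations": [{"bbox": [0.0, 0.0, 50.0, 50.0], "category_id": 1, "iscrowd": 0}],
        }
    ]


toy_registry = {"train_a": load_train_a, "train_b": load_train_b}


def toy_get_detection_dataset_dicts(dataset_names, filter_empty=True):
    # One list of records per dataset name: the loaders run here.
    per_dataset = [toy_registry[name]() for name in dataset_names]
    # Merge into a single flat list of per-image dicts.
    dicts = list(itertools.chain.from_iterable(per_dataset))
    if filter_empty:
        # Simplified filter: drop images that have no non-crowd annotation.
        dicts = [d for d in dicts if any(a.get("iscrowd", 0) == 0 for a in d["annotations"])]
    return dicts


dataset = toy_get_detection_dataset_dicts(["train_a", "train_b"])
print([d["file_name"] for d in dataset])  # ['a/0001.jpg', 'b/0001.jpg']
```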
2. Wrapping dataset_dict with a mapper
Once dataset_dicts has been obtained, the logic in build.py builds the mapper next, and then the sampler.
```python
import logging

from detectron2.data.dataset_mapper import DatasetMapper
from detectron2.data.samplers import RepeatFactorTrainingSampler, TrainingSampler


def _train_loader_from_config(cfg, *, mapper=None, dataset=None, sampler=None):
    if dataset is None:
        dataset = get_detection_dataset_dicts(  # read the dataset_dicts
            cfg.DATASETS.TRAIN,
            filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
            min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
            if cfg.MODEL.KEYPOINT_ON
            else 0,
            proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None,
        )

    if mapper is None:
        mapper = DatasetMapper(cfg, True)  # build the default mapper (is_train=True)

    if sampler is None:  # build a sampler
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = TrainingSampler(len(dataset))
        elif sampler_name == "RepeatFactorTrainingSampler":
            repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
                dataset, cfg.DATALOADER.REPEAT_THRESHOLD
            )
            sampler = RepeatFactorTrainingSampler(repeat_factors)
        else:
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "dataset": dataset,
        "sampler": sampler,
        "mapper": mapper,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
```
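RepeatFactorTrainingSampler is the sampler for long-tailed datasets (LVIS-style repeat-factor sampling). As I understand it, each category c gets a factor max(1, sqrt(t / f(c))). Here f(c) is the fraction of images containing c, and t is REPEAT_THRESHOLD. An image takes the largest factor among its categories, and fractional factors are rounded stochastically each epoch. The pure-Python sketch below illustrates that computation. toy_repeat_factors and expand_indices are hypothetical names, and the real method returns a torch tensor rather than a list.

```python
import math
import random
from collections import defaultdict


def toy_repeat_factors(dataset_dicts, repeat_thresh):
    # Count in how many images each category appears.
    image_count = defaultdict(int)
    for d in dataset_dicts:
        for cat_id in {a["category_id"] for a in d["annotations"]}:
            image_count[cat_id] += 1
    num_images = len(dataset_dicts)
    # Per-category factor: max(1, sqrt(t / f(c))), with f(c) = image fraction.
    cat_repeat = {
        c: max(1.0, math.sqrt(repeat_thresh / (n / num_images))) for c, n in image_count.items()
    }
    # An image is repeated as often as its rarest category requires.
    return [
        max((cat_repeat[a["category_id"]] for a in d["annotations"]), default=1.0)
        for d in dataset_dicts
    ]


def expand_indices(repeat_factors, rng):
    # Stochastic rounding: keep the integer part, add one more copy with
    # probability equal to the fractional part.
    indices = []
    for idx, r in enumerate(repeat_factors):
        copies = int(r) + (rng.random() < r - int(r))
        indices.extend([idx] * copies)
    return indices


dicts = [
    {"annotations": [{"category_id": 0}, {"category_id": 1}]},  # contains the rare class 1
    {"annotations": [{"category_id": 0}]},
    {"annotations": [{"category_id": 0}]},
    {"annotations": [{"category_id": 0}]},
]
factors = toy_repeat_factors(dicts, repeat_thresh=1.0)
print(factors)  # [2.0, 1.0, 1.0, 1.0]
print(expand_indices(factors, random.Random(0)))  # [0, 0, 1, 2, 3]
```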
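Below is dataset_mapper.py (it is long, so only the outline is shown). What matters is the __call__ method. As the class docstring says, it turns a dataset_dict into the format a d2 model accepts. Inside __call__ the flow is: read the image from file_name → apply augmentation → return the image data and annotations.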
```python
from typing import List, Optional, Union

import numpy as np

from detectron2.config import configurable
from detectron2.data import transforms as T


class DatasetMapper:
    """
    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        ...  # (body omitted)

    def __call__(self, dataset_dict):
        ...  # (body omitted) read image -> augment -> convert to Tensor / Instances
```
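Since __call__ is elided above, here is a toy mapper that follows the same three steps on plain Python data. It is not detectron2 code. ToyFlipMapper and fake_reader are hypothetical, the "image" is a nested list, and the only augmentation is a horizontal flip. It does mirror two key ideas. First, the geometric transform applied to the image must also be applied to the boxes. Second, the input dict is deep-copied first (DatasetMapper does this too), so the original records are not modified.

```python
import copy


class ToyFlipMapper:
    """Hypothetical mapper showing the shape of DatasetMapper.__call__ (not d2 code)."""

    def __init__(self, image_reader):
        self.image_reader = image_reader  # stands in for reading "file_name" from disk

    def __call__(self, dataset_dict):
        # Copy first: the mapper mutates the dict, and the original record
        # must stay unchanged for the next epoch.
        dataset_dict = copy.deepcopy(dataset_dict)
        image = self.image_reader(dataset_dict["file_name"])  # list of rows
        width = len(image[0])
        # 1) augment the image: horizontal flip
        image = [row[::-1] for row in image]
        # 2) apply the same geometric transform to the boxes (XYXY)
        for ann in dataset_dict.get("annotations", []):
            x0, y0, x1, y1 = ann["bbox"]
            ann["bbox"] = [width - x1, y0, width - x0, y1]
        # 3) return image + annotations (d2 builds a tensor and Instances here)
        dataset_dict["image"] = image
        return dataset_dict


def fake_reader(file_name):
    # Hypothetical 2x4 "image" whose pixel values are column indices.
    return [[0, 1, 2, 3], [0, 1, 2, 3]]


record = {"file_name": "img.jpg", "annotations": [{"bbox": [0, 0, 1, 2], "category_id": 0}]}
out = ToyFlipMapper(fake_reader)(record)
print(out["image"][0], out["annotations"][0]["bbox"])  # [3, 2, 1, 0] [3, 0, 4, 2]
print(record["annotations"][0]["bbox"])  # [0, 0, 1, 2]  (original untouched)
```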
3. Building the dataset
OK, with the mapper and sampler in place, we can build the dataset and its __getitem__. The entry point is in build.py:
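```python
import torch.utils.data

from detectron2.config import configurable
from detectron2.data.common import DatasetFromList, MapDataset
from detectron2.data.samplers import TrainingSampler


# _train_loader_from_config and build_batch_data_loader are defined in the same build.py.
@configurable(from_config=_train_loader_from_config)
def build_detection_train_loader(
    dataset, *, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0
):
    if isinstance(dataset, list):
        dataset = DatasetFromList(dataset, copy=False)  # list[dict] -> indexable Dataset
    if mapper is not None:
        dataset = MapDataset(dataset, mapper)  # mapper is applied lazily in __getitem__
    if sampler is None:
        sampler = TrainingSampler(len(dataset))
    assert isinstance(sampler, torch.utils.data.sampler.Sampler)
    return build_batch_data_loader(
        dataset,
        sampler,
        total_batch_size,
        aspect_ratio_grouping=aspect_ratio_grouping,
        num_workers=num_workers,
    )
```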
The dataset is therefore wrapped by two classes, DatasetFromList and MapDataset. Both live in data/common.py. Here is MapDataset:
```python
import logging
import random

import torch.utils.data as data

from detectron2.utils.serialize import PicklableWrapper


class MapDataset(data.Dataset):
    """
    Map a function over the elements in a dataset.

    Args:
        dataset: a dataset where map function is applied.
        map_func: a callable which maps the element in dataset. map_func is
            responsible for error handling, when error happens, it needs to
            return None so the MapDataset will randomly use other
            elements from the dataset.
    """

    def __init__(self, dataset, map_func):
        self._dataset = dataset
        self._map_func = PicklableWrapper(map_func)  # wrap so that a lambda will work

        self._rng = random.Random(42)
        self._fallback_candidates = set(range(len(dataset)))

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        retry_count = 0
        cur_idx = int(idx)

        while True:
            data = self._map_func(self._dataset[cur_idx])
            if data is not None:
                self._fallback_candidates.add(cur_idx)
                return data

            # _map_func fails for this idx, use a random new index from the pool
            retry_count += 1
            self._fallback_candidates.discard(cur_idx)
            # random.sample() rejects sets on Python 3.11+, so sample from a sorted list
            cur_idx = self._rng.sample(sorted(self._fallback_candidates), k=1)[0]

            if retry_count >= 3:
                logger = logging.getLogger(__name__)
                logger.warning(
                    "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
                        idx, retry_count
                    )
                )
```
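Putting the two wrappers together: DatasetFromList turns the list of dicts into an indexable dataset, MapDataset applies the mapper, and the sampler only supplies indices. The toy classes below are hypothetical stand-ins without the retry logic. They show that the mapper runs lazily, once per __getitem__ call, and not when the dataset is constructed.

```python
class ToyListDataset:
    """Stand-in for DatasetFromList: indexable wrapper around a list of dicts."""

    def __init__(self, items):
        self._items = items

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        return self._items[idx]


class ToyMapDataset:
    """Stand-in for MapDataset without the retry/fallback logic."""

    def __init__(self, dataset, map_func):
        self._dataset = dataset
        self._map_func = map_func

    def __len__(self):
        return len(self._dataset)

    def __getitem__(self, idx):
        # The mapper runs lazily, once per access, not when the dataset is built.
        return self._map_func(self._dataset[idx])


mapped_calls = []


def tag_mapper(d):
    # Hypothetical mapper: records which image it saw and marks the output.
    mapped_calls.append(d["image_id"])
    return {**d, "mapped": True}


dataset = ToyMapDataset(ToyListDataset([{"image_id": 7}, {"image_id": 8}]), tag_mapper)
print(len(dataset), mapped_calls)  # 2 []
print(dataset[1])  # {'image_id': 8, 'mapped': True}
print(mapped_calls)  # [8]
```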
Finally, with the dataset and the sampler in hand, the dataloader can be built.
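build_batch_data_loader also handles aspect_ratio_grouping. When it is True, detectron2 puts images with similar aspect ratios into the same batch, which reduces padding. The sketch below is a simplified, single-process mimic of that idea. group_by_aspect_ratio is a hypothetical name. Indices come from a sampler, each image goes into a landscape (w > h) or portrait bucket, and a batch is emitted as soon as a bucket is full. Partially filled buckets are never emitted, which is harmless when the sampler is infinite, as in training.

```python
def group_by_aspect_ratio(dataset_dicts, indices, batch_size):
    """Simplified sketch of aspect-ratio grouping (not detectron2's code).

    Yields batches whose images are all landscape (w > h) or all portrait.
    """
    buckets = {0: [], 1: []}
    for idx in indices:  # `indices` plays the role of the sampler's output
        d = dataset_dicts[idx]
        bucket_id = 0 if d["width"] > d["height"] else 1
        bucket = buckets[bucket_id]
        bucket.append(d)
        if len(bucket) == batch_size:
            yield bucket[:]
            bucket.clear()


dicts = [
    {"image_id": 0, "width": 640, "height": 480},
    {"image_id": 1, "width": 480, "height": 640},
    {"image_id": 2, "width": 800, "height": 600},
    {"image_id": 3, "width": 600, "height": 800},
]
for batch in group_by_aspect_ratio(dicts, indices=[0, 1, 2, 3], batch_size=2):
    print([d["image_id"] for d in batch])
# [0, 2]
# [1, 3]
```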
Summary
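I haven't explained the code line by line; the focus here is the overall construction flow. Future posts will cover building the model, the optimizer, and so on.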