使用Pytorch构建一个新算法时,通常包含如下几步:
注册数据集:CustomDataset是MMDetection在原始的Dataset基础上的再次封装,其__getitem__()方法会根据训练和测试模式分别重定向到prepare_train_img()和prepare_test_img()函数。用户以继承CustomDataset类的方式构建自己的数据集时,需要重写load_annotations()和get_ann_info()函数,定义数据和标签的加载及遍历方式。完成数据集类的定义后,还需要使用DATASETS.register_module()进行模块注册。
注册模型:模型构建的方式和Pytorch类似,都是新建一个Module的子类然后重写forward()函数。唯一的区别在于MMDetection中需要继承BaseModule而不是Module,BaseModule是Module的子类,MMLab中的任何模型都必须继承此类。另外,MMDetection将一个完整的模型拆分为backbone、neck和head三部分进行管理,所以用户需要按照这种方式,将算法模型拆解成3个类,分别使用BACKBONES.register_module()、NECKS.register_module()和HEADS.register_module()完成模块注册。
构建配置文件:配置文件用于配置算法各个组件的运行参数,大体上可以包含四个部分:datasets、models、schedules和runtime。完成相应模块的定义和注册后,在配置文件中配置好相应的运行参数,然后MMDetection就会通过Registry类读取并解析配置文件,完成模块的实例化。另外,配置文件可以通过_base_字段实现继承功能,以提高代码复用率。
训练和验证:在完成各模块的代码实现、模块的注册、配置文件的编写后,就可以使用./tools/train.py和./tools/test.py对模型进行训练和验证,不需要用户编写额外的代码。
训练流程
按照数据流过程,训练流程可以简单总结为:
1.给定任何一个数据集,首先需要构建 Dataset 类,用于迭代输出数据
2.在迭代输出数据的时候需要通过数据 Pipeline 对数据进行各种处理,最典型的处理流是训练中的数据增强操作,测试中的数据预处理等等
3.通过 Sampler 采样器可以控制 Dataset 输出的数据顺序,最常用的是随机采样器 RandomSampler。由于 Dataset 中输出的图片大小不一样,为了尽可能减少后续组成 batch 时 pad 的像素个数,MMDetection 引入了分组采样器 GroupSampler 和 DistributedGroupSampler,相当于在 RandomSampler 基础上额外新增了根据图片宽高比进行 group 功能
4.将 Sampler 和 Dataset 都输入给 DataLoader,然后通过 DataLoader 输出已组成 batch 的数据,作为 Model 的输入
5.对于任何一个 Model,为了方便处理数据流以及分布式需求,MMDetection 引入了两个 Model 的上层封装:单机版本 MMDataParallel、分布式(单机多卡或多机多卡)版本 MMDistributedDataParallel
Model 运行后会输出 loss 以及其他一些信息,会通过 logger 进行保存或者可视化
6.为了更好地解耦, 方便地获取各个组件之间依赖和灵活扩展,MMDetection 引入了 Runner 类进行全生命周期管理,并且通过 Hook 方便的获取、修改和拦截任何生命周期数据流,扩展非常便捷
Config类文件
MMDetection使用MMCV库中Config类完成对配置文件的解析。Config 类用于操作配置文件,它支持从多种文件格式中加载配置,包括python, json和yaml。 它提供了类似字典对象的接口来获取和设置值。
Config实现配置文件到模型,需要走两步:
导入.py配置文件到dict,2)通过dict构造class
读取配置文件
一般使用Config.fromfile(filename)来读取配置文件(也可以直接传入一个dict),返回一个Config类:
from mmcv import Config
cfg = Config.fromfile('../configs/test_config.py')
fromfile()函数源码如下,其核心函数是_file2dict()。_file2dict()会根据文本顺序,按照key = value的格式解析配置文件,得到一个名为cfg_dict的字典,如果存在_base_字段,还会对_base_包含的每个文件路径再调用一次_file2dict()函数,将文件中包含的配置参数加入到cfg_dict中,实现配置文件的继承功能。需要注意的是,_file2dict()内部会对_base_中不同文件包含的键值进行校验,不同base文件中不允许出现重复的键值,否则MMCV不知道以哪个base文件为准。
@staticmethod
def fromfile(filename,
use_predefined_variables=True,
import_custom_modules=True):
cfg_dict, cfg_text = Config._file2dict(filename,
use_predefined_variables)
# import_modules_from_strings()是根据字符串列表对应的模块
if import_custom_modules and cfg_dict.get('custom_imports', None):
import_modules_from_strings(**cfg_dict['custom_imports'])
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
另外有两点需要补充一下,其一是构造Config对象的时候,会将python的dict数据类型转换为ConfigDict类型进行处理。ConfigDict是第三方库addict中Dict的子类,因为python原生的dict类型不支持.属性的访问方式,特别是dict内部嵌套了多层dict的时候,如果按照key的访问方式,代码写起来非常低效,而Dict类通过重写__getattr__()的方式实现了.属性的访问方式。所以继承了Dict的ConfigDict也支持使用.属性的方式访问字典中的各个成员值。
from mmcv import ConfigDict
model = ConfigDict(dict(backbone=dict(type='ResNet', depth=50)))
print(model.backbone.type) # 输出 'ResNet'
修改配置参数
知道MMCV解析配置文件的内部逻辑后,如何修改配置参数的值自然也清楚了。因为_file2dict()是根据文本顺序构建字典,所以后写的键值可以覆盖原来的值 (如果变量类型是list,会将list进行全部替换,无法实现某一个item的修改)。以修改优化器为例,原来的继承的优化器是SGD,学习率为0.02:
# 原来继承的优化器
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
现在想要将由_base_继承的学习率调整为0.001,可以直接在当前配置文件中增加一行:
# 修改学习率
optimizer = dict(lr=0.001)
这样只会修改optimizer键值中的lr参数,其他参数不受影响,当前优化器就配置就变成了:
# 修改学习率后的SGD
optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0001)
如果想要现在想要换一个新的优化器,但两个优化器的参数不兼容,需要删掉原来的键值,用一组全新的键值代替,这时可以通过配置_delete_=True来实现:
# 将原来的SGD替换成AdamW
optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, weight_decay=0.0001)
然后就完成了优化器的替换,当前配置文件的优化器参数如下:
# 当前优化器变为AdamW
optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0001)
查看配置文件
调用Config.fromfile()后会返回一个Config对象,其中除了上面提到的解析配置文件得到的ConfigDict类型的_cfg_dict外,还包含了text和pretty_text这两个成员变量。
text存储的是各个配置文件(包含_base_中继承的文件)中的原始文本信息,会标识配置文件的路径。
pretty_text是_cfg_dict字典内容的格式化文本,MMCV内部是借助Google的YAPF库来对字典对象进行格式化,使其输出符合人们的阅读习惯,直接print(cfg.pretty_text)即可查看完整配置文件信息,和MMDetection的mmdetection/tools/misc/print_config.py的效果是一样的。
另外,还可以通过cfg.dump(filepath)将pretty_text存储到文件中方便查看。
Registry类文件(MMCV 核心组件分析(五): Registry - 知乎 (zhihu.com))
Registry 功能和用法
在 OpenMMLab 中,Registry 类可以提供一种完全相似的对外装饰函数来管理构建不同的组件,例如 backbones、head 和 necks 等等,Registry 类内部其实维护的是一个全局 key-value 对。通过 Registry 类,用户可以通过字符串方式实例化任何想要的模块。
简单来说,注册机制就是维护几张查询表,key是模块的名称,value是模块的句柄,每张查询表都管理一批功能相似的不同模块。我们每新建一个模块,都要根据模块实现的功能将对应的key-value查询对保存到对应的查询表中,这个保存的过程就称为“注册”。当我们想要调用某个模块时,只需要根据模块名称从查询表中找到对应的模块句柄,然后就能完成模块初始化或方法调用等操作。MMCV通过Registry类来实现字符串(key)到类(value)的映射。
模块的注册通过Registry的成员函数register_module()来实现,register_module()内部又会调用另一个私有函数_register_module(),模块注册的核心功能其实是在_register_module()中实现的。核心代码也很简单,就是将传入的module_name和module_class保存到字典self._module_dict中。
#class Register中主要使用的函数,用于注册
def register_module(self, name=None, force=False, module=None):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
>>> backbones = Registry('backbone')
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
name (str | None): The module name to be registered. If not
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class or function to be registered.
"""
if not isinstance(force, bool):
raise TypeError(f'force must be a boolean, but got {type(force)}')
# NOTE: This is a walkaround to be compatible with the old api,
# while it may introduce unexpected bugs.
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
# raise the error ahead of time
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
raise TypeError(
'name must be either of None, an instance of str or a sequence'
f' of str, but got {type(name)}')
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(module=module, module_name=name, force=force)
return module
# use it as a decorator: @x.register_module()
def _register(module):
self._register_module(module=module, module_name=name, force=force)
return module
return _register
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
self._module_dict[module_name] = module_class
值得注意的是:上述功能实现的核心是 python 中装饰器用法。
在我们通过字符串获取到一个模块的句柄后,可以通过self.build_func函数句柄来实例化这个模块。build_func可以人为指定,也可以从父类继承,一般来说都是默认使用build_from_cfg()函数,即使用配置参数cfg来初始化该模块。配置参数cfg是一个字典,里面的type字段是模块名称的字符串,其他字段则对应模块构造函数的输入参数。
def build_from_cfg(cfg, registry, default_args=None):
args = cfg.copy()
# 将cfg以外的外部传入参数也合并到args中
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
# 获取模块名称
obj_type = args.pop('type')
if isinstance(obj_type, str):
# get函数返回registry._module_dict中obj_type对应的模块句柄
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(f'{obj_type} is not in the {registry.name} registry')
elif inspect.isclass(obj_type):
# type值是模块本身
obj_cls = obj_type
else:
raise TypeError(f'type must be a str or valid type, but got {type(obj_type)}')
# 模块初始化, 返回模块实例
try:
return obj_cls(**args)
except Exception as e:
raise type(e)(f'{obj_cls.__name__}: {e}')
考虑到registry参数需要指向当前注册器本身,我们一般是调用Registry类的build()方法而不是self.build_func。
def build(self, *args, **kwargs):
return self.build_func(*args, **kwargs, registry=self)
下面是一个小例子,模拟了网络模型的注册和调用过程。注意一下,我们打印Registry对象时,实际上打印的是self._module_dict中的values。
# 实例化一个注册器用来管理模型
MODELS = Registry('myModels')
# 方式1: 在类的创建过程中, 使用函数装饰器进行注册(推荐)
@MODELS.register_module()
class ResNet(object):
def __init__(self, depth):
self.depth = depth
print('Initialize ResNet{}'.format(depth))
# 方式2: 完成类的创建后, 再显式调用register_module进行注册(不推荐)
class FPN(object):
def __init__(self, in_channel):
self.in_channel= in_channel
print('Initialize FPN{}'.format(in_channel))
MODELS.register_module(name='FPN', module=FPN)
print(MODELS)
""" 打印结果为:
Registry(name=myModels, items={'ResNet': <class '__main__.ResNet'>, 'FPN': <class '__main__.FPN'>})
"""
# 配置参数, 一般cfg从配置文件中获取
backbone_cfg = dict(type='ResNet', depth=101)
neck_cfg = dict(type='FPN', in_channel=256)
# 实例化模型(将配置参数传给模型的构造函数), 得到实例化对象
my_backbone = MODELS.build(backbone_cfg)
my_neck = MODELS.build(neck_cfg)
print(my_backbone, my_neck)
""" 打印结果为:
Initialize ResNet101
Initialize FPN256
<__main__.ResNet object at 0x000001E68E99E198> <__main__.FPN object at 0x000001E695044B38>
"""
Runner和Hook机制
Runner又称执行器,负责模型训练过程的调度,主要目的是让用户使用更少的代码以及灵活可配置的方式开启训练。换句话说,MMCV将整个训练过程封装起来了,并使用Runner进行管理和配置。高度封装虽然减少了代码量,但如何对内部流程进行自定义的修改(比如动态调整学习率等)?这时就需要用到Hook机制。
Hook是能够改变程序执行流程的一种技术统称。通俗的说,Hook可以理解为一种触发器,在程序预定义的位置执行预定义的函数。MMCV已经在几个常用的位置预留了接口函数(称为回调函数),如下图所示。MMCV已经实现了一些常用的Hook函数,同时用户也可以增加自己的Hook函数,非常方便。当程序执行到指定位置时,就会进入到回调函数中,执行相应的功能,执行结束后再接着执行主流程。
对应训练代码上:
# 开始运行时调用
before_run()
while self.epoch < self._max_epochs:
# 开始 epoch 迭代前调用
before_train_epoch()
for i, data_batch in enumerate(self.data_loader):
# 开始 iter 迭代前调用
before_train_iter()
self.model.train_step()
# 经过一次迭代后调用
after_train_iter()
# 经过一个 epoch 迭代后调用
after_train_epoch()
# 运行完成前调用
after_run()
Runner封装了OpenMMLab体系下各个框架的训练和验证流程,负责管理训练/验证过程的整个生命周期;通过预定义的回调函数,用户可以插入定制化Hook,实现各种各样定制化的需求。
Runner类
Runner分为EpochBasedRunner和IterBasedRunner,顾名思义,前者以epoch的方式管理流程,后者以iter的方式管理流程,它们都是BaseRunner的子类。BaseRunner的任何子类都需要实现run()、train()、val()和save_checkpoint()四个方法,这也是Runner的核心方法。这里以EpochBasedRunner为例对上述四个函数进行分析,为了使代码结构看起来更清晰,删去了和核心功能无关的代码。
构造函数
EpochBasedRunner和IterBasedRunner都是BaseRunner的子类,继承了BaseRunner的构造函数。runner默认调用model类中的train_step()和val_step()进行训练和验证,如果指定了batch_processor,则会调用batch_processor对data_loader中的数据进行处理。
class BaseRunner(metaclass=ABCMeta):
def __init__(self,
model, # [torch.nn.Module] 要运行的模型
batch_processor=None, # 过时用法, 通过实现模型中的train_step()和val_step()方法替代
optimizer=None, # [torch.optim.Optimizer] 优化器, 可以是一个也可以是一组通过dict配置的优化器
work_dir=None, # [str] 保存检查点和Log的目录
logger=None, # [logging.Logger] 训练中使用的日志记录器
meta=None, # [dict] 一些信息, 这些信息会在logger hook中记录
max_iters=None, # [int] 训练epoch数
max_epochs=None): # [int] 训练迭代次数
run()函数
run()是runner类的主调函数,会根据workflow指定的工作流,对data_loaders中的数据进行处理。目前MMCV支持训练和验证两种工作流,对于EpochBasedRunner而言,workflow配置为[('train', 2),('val', 1)]表示先训练2个epoch,然后验证一个epoch;[('train', 1)]表示只进行训练,不进行验证。如果是IterBasedRunner,[('train', 2),('val', 1)]则表示先训练2个iter,然后验证一个iter。
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
while self.epoch < self._max_epochs:
for i, flow in enumerate(workflow):
mode, epochs = flow
# 根据工作流确定当前是运行train()还是val(), getattr返回对应的函数句柄
epoch_runner = getattr(self, mode)
for _ in range(epochs):
if mode == 'train' and self.epoch >= self._max_epochs:
break
# 运行train()或val()
epoch_runner(data_loaders[i], **kwargs)
train()和val()函数
train()和val()函数循环调用run_iter()完成一个epoch流程。函数开头的self.model.train()和self.model.eval()实际上调用的是torch.nn.module.Module的成员函数,将当前模块设置为训练模式或验证模式,两种不同模式下batchnorm、dropout等层的操作会有区别。然后由于测试过程不需要梯度回传,所以val函数加了一个装饰器@torch.no_grad()。
def train(self, data_loader, **kwargs):
# 将模块设置为训练模式
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
for i, data_batch in enumerate(self.data_loader):
self.run_iter(data_batch, train_mode=True, **kwargs)
self._iter += 1
self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
# 将模块设置为验证模式
self.model.eval()
self.mode = 'val'
self.data_loader = data_loader
for i, data_batch in enumerate(self.data_loader):
self.run_iter(data_batch, train_mode=False)
train()和val()的核心函数是run_iter(),根据train_mode参数调用model.train_step()或model.val_step(),这两个函数最终会执行我们自己模型的forward()函数,返回loss值。
def run_iter(self, data_batch, train_mode, **kwargs):
if self.batch_processor is not None:
outputs = self.batch_processor(self.model, data_batch, train_mode=train_mode, **kwargs)
elif train_mode:
outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
else:
outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
self.outputs = outputs
save_checkpoint()函数
save_checkpoint()函数调用torch.save将检查点以下列格式保存。
checkpoint = {
'meta': dict(), # 环境信息(比如epoch_num, iter_num)
'state_dict': dict(), # 模型的state_dict()
'optimizer': dict()) # 优化器的state_dict()
}
Hook类
MMCV在./mmcv/runner/hooks/hook.py中定义了Hook的基类以及Hook的注册器HOOKS。作为基类,Hook本身没有实现具体的函数,只是提供了before_run、after_run等6个接口函数,其他所有的Hooks都通过继承Hook类并重写相应的函数完整指定功能。
from mmcv.utils import Registry
HOOKS = Registry('hook')
class Hook:
def before_run(self, runner):
pass
def after_run(self, runner):
pass
def before_epoch(self, runner):
pass
def after_epoch(self, runner):
pass
def before_iter(self, runner):
pass
def after_iter(self, runner):
pass
MMCV已经实现了部分常用的Hooks,如下图所示。默认Hook不需要用户自行注册,通过配置文件配置对应的参数即可;定制Hook则需要用户手动注册进去。
Hook也是一个模块,使用时需要定义、注册、调用3个步骤。
定义
MMCV实现的Hook都在./mmcv/runner/hooks目录下,这里以CheckpointHook为例介绍一下怎么新建一个Hook。
首先从hook.py中导入注册器HOOKS以及基类Hook。然后新建一个名为CheckpointHook类继承Hook基类,由于Hook基类没有定义构造函数,这里首先必须自己定义__init__函数,然后根据Hook需要实现的功能,重写Hook基类中的一种或几种方法。比如MMCV会在每次训练开始前打印checkpoint的保存路径,会在每次循环结束后或每个epoch执行完成后保存checkpoint,因此CheckpointHook类重写了before_run、after_train_iter和after_train_epoch这3个方法。
from .hook import HOOKS, Hook
@HOOKS.register_module()
class CheckpointHook(Hook):
def __init__(self,
interval=-1,
by_epoch=True,
save_optimizer=True,
out_dir=None,
max_keep_ckpts=-1,
save_last=True,
sync_buffer=False,
file_client_args=None,
**kwargs):
...
def before_run(self, runner):
...
def after_train_iter(self, runner):
...
def after_train_epoch(self, runner):
...
注册
对于MMCV的默认Hook,在执行runner.run()前会调用BaseRunner类中的register_training_hooks方法进行注册:
def register_training_hooks(self,
lr_config,
optimizer_config=None,
checkpoint_config=None,
log_config=None,
momentum_config=None,
timer_config=dict(type='IterTimerHook'),
custom_hooks_config=None):
"""Register default and custom hooks for training.
Default and custom hooks include:
+----------------------+-------------------------+
| Hooks | Priority |
+======================+=========================+
| LrUpdaterHook | VERY_HIGH (10) |
+----------------------+-------------------------+
| MomentumUpdaterHook | HIGH (30) |
+----------------------+-------------------------+
| OptimizerStepperHook | ABOVE_NORMAL (40) |
+----------------------+-------------------------+
| CheckpointSaverHook | NORMAL (50) |
+----------------------+-------------------------+
| IterTimerHook | LOW (70) |
+----------------------+-------------------------+
| LoggerHook(s) | VERY_LOW (90) |
+----------------------+-------------------------+
| CustomHook(s) | defaults to NORMAL (50) |
+----------------------+-------------------------+
If custom hooks have same priority with default hooks, custom hooks
will be triggered after default hooks.
"""
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_timer_hook(timer_config)
self.register_logger_hooks(log_config)
self.register_custom_hooks(custom_hooks_config)
具体到单个注册函数,比如register_checkpoint_hook(),hook作为一个模块,还是使用build_from_cfg进行实例获取,然后调用BaseRunner类的register_hook()进行注册,这样所有Hook实例就都被纳入到runner中的一个list中。
def register_checkpoint_hook(self, checkpoint_config):
hook = mmcv.build_from_cfg(checkpoint_config, HOOKS)
self.register_hook(hook, priority='NORMAL')
def register_hook(self, hook, priority='NORMAL'):
priority = get_priority(priority)
hook.priority = priority
# 按照priority大小插入当前hook列表
inserted = False
for i in range(len(self._hooks) - 1, -1, -1):
if priority >= self._hooks[i].priority:
self._hooks.insert(i + 1, hook)
inserted = True
break
if not inserted:
self._hooks.insert(0, hook)
调用
在runner执行过程中,会在特定的程序位点通过call_hook()函数调用相应的Hook。
def train(self, data_loader, **kwargs):
self.model.train()
self.mode = 'train'
self.data_loader = data_loader
self._max_iters = self._max_epochs * len(self.data_loader)
self.call_hook('before_train_epoch')
time.sleep(2) # Prevent possible deadlock during epoch transition
for i, data_batch in enumerate(self.data_loader):
self._inner_iter = i
self.call_hook('before_train_iter')
self.run_iter(data_batch, train_mode=True, **kwargs)
self.call_hook('after_train_iter')
self._iter += 1
self.call_hook('after_train_epoch')
self._epoch += 1
前面调用register_hook()注册Hook的时候,会根据优先级将Hook加入到self._hooks这个列表中,在执行call_hook()时候,使用for循环就可以很简单的实现按照优先级依次调用指定的Hook了。
def call_hook(self, fn_name):
for hook in self._hooks:
getattr(hook, fn_name)(self)
reference
(72条消息) MMDetection框架入门教程(完全版)_Maples丶丶的博客-CSDN博客_mmdetection
轻松掌握 MMDetection 整体构建流程(二) - 知乎 (zhihu.com)
(82条消息) MMDetection框架入门教程(三):配置文件详细解析_Maples丶丶的博客-CSDN博客_mmdetection _base_
(87条消息) MMDetection框架入门教程(五):Runner和Hook详细解析_mmdetection的runner_Maples丶丶的博客-CSDN博客