我们知道,机器学习模型的效果好坏很大程度上取决于超参的选取。人肉调参需要依赖经验与直觉,且花费大量精力。PBT(Population based training)是DeepMind在论文《Population Based Training of Neural Networks》中提出的一种异步的自动超参数调节优化方法。以往的自动调节超参方法可分为两类:parallel search和sequential optimization。前者并行执行很多不同超参的优化任务,优点是可以并行利用计算资源更快找到最优解;后者需要利用之前的信息来进行下一步的超参优化,因此只能串行执行,但一般能得到更好的解。PBT完美地结合两种方法,兼具两者优点。它被应用于一些领域取得了不错的效果。如DeepMind的论文《Human-level performance in first-person multiplayer games with population-based deep reinforcement learning》将之用于第一人称多人游戏使AI达到人类水平。还有今年UC Berkeley的论文《Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules》中用PBT来自动学习data augmentation策略,在几个benchmark上达到了不错的精度。另外,最近自动驾驶公司Waymo也称将PBT应用于识别任务,与手工调参相比可以提高精度和加快训练速度。
PBT开局与parallel search类似,会并行训练一批随机初始化的模型。过程中它会周期性地将表现好的模型替换表现不好的模型(exploitation),同时再加上随机扰动(主要是为了exploration)。PBT与其它方法的一个重要不同是它在训练的过程中对超参进行调节,因此可以更快地发现超参和优异的schedule。论文《Population Based Training of Neural Networks》中的示意图非常清楚地示意了整个过程,及与其它方法的区别:
PBT是一种很通用的方法,可以用于很多场景,其一般套路如下:
- Step:对模型训练一步。至于一步是一次iteration还是一个epoch还是其它可以根据需要指定。
- Eval:在验证集上做评估。
- Ready: 选取群体中的一个模型来进行下面的exploit和explore操作(即perturbation)。这个模型一般是上次做过该操作后经过指定的时间(或迭代次数等)。
- Exploit: 将那些经过评估比较烂的模型用那些比较牛叉的模型替代。
- Explore: 对上一步产生的复制体模型加随机扰动,如加上随机值或重采样。
Ray中实现了PBT算法。Ray中关于PBT有三个example:一个是learning rate搜索pbt_example.py,另一个是强化学习算法PPO的超参数搜索pbt_ppo_example.py。还有一个是pbt_tune_cifar10_with_keras.py。我们来看下最简单的pbt_example.py
。其中的PBTBenchmarkExample
类继承自Trainable
类,它是一个toy的模拟环境,假设在模型训练过程中最优的learning rate是变化的,是accuracy的函数。目标是找到learning rate的schedule。它的核心函数是_train()
,这里会模拟最优的learning rate。
然后看主函数,首先通过ray.init()
初始化ray,然后创建PopulationBasedTraining
对象,接着通过run()函数开始超参搜索过程。
pbt = PopulationBasedTraining(
time_attr="training_iteration",
metric="mean_accuracy",
mode="max",
perturbation_interval=20,
hyperparam_mutations={
# distribution for resampling
"lr": lambda: random.uniform(0.0001, 0.02),
# allow perturbations within this set of categorical values
"some_other_factor": [1, 2],
})
run(
PBTBenchmarkExample,
name="pbt_test",
scheduler=pbt,
reuse_actors=True,
verbose=False,
**{
"stop": {
"training_iteration": 2000,
},
"num_samples": 4,
"config": {
"lr": 0.0001,
# note: this parameter is perturbed but has no effect on
# the model training in this example
"some_other_factor": 1,
},
})
先看第一步,PopulationBasedTraining
的实现在python/ray/tune/schedulers/pbt.py
中。它继承自FIFOScheduler
类。构造函数中几个主要参数:
time_attr
: 用于定义训练时长的测度,要求单调递增,比如training_iteration
。metric
: 训练结果衡量目标。mode
: 上面metric属性是越高越好,还是越低越好。perturbation_interval
: 模型会以time_attr
为间隔来进行perturbation。hyperparam_mutations
: 需要变异的超参。它是一个dict,对于每个key对应list或者function。如果没设这个,就需要在custom_explore_fn
中指定。quantile_fraction
: 决定按多大比例将表现好的头部模型克隆到尾部模型。resample_probability
: 当对超参进行exploration时从原分布中重新采样的概率,否则会根据现有的值调整。custom_explore_fn
: 自定义的exploration函数。
第二步中run()
函数实现在ray/python/ray/tune/tune.py
中:
def run(run_or_experiment, name=None, ...):
trial_executor = traial_executor or RayTrialExecutor(...)
experiment = run_or_experiment
if not isinstance(run_or_experiment, Experiment):
if not isinstance(run_or_experiment, Experiment):
experiment = Experiment(...)
...
runner = TrialRunner(
search_alg=search_alg or BasicVariantGenerator(),
scheduler=scheduler or FIFOScheduler(),
local_checkpoint_dir=experiment.checkpoint_dir,
remote_checkpoint_dir=experiment.remote_checkpoint_dir,
sync_to_cloud=sync_to_cloud,
checkpoint_period=global_checkpoint_period,
resume=resume,
launch_web_server=with_server,
server_port=server_port,
verbose=bool(verbose > 1),
trial_executor=trial_executor)
runner.add_experiment(experiment)
...
while not runner.is_finished():
runner.step()
...
wait_for_sync()
...
return ExperimentAnalysis(runner.checkpoint_file, trials=trials)
第一个参数run_or_experiment
是要训练的目标任务,参数scheduler
就是上面创建的PopulationBasedTraining
,负责超参搜索时的调度。
其中几个关键类关系如下图:
SearchAlgorithm
的实现类BasicVariantGenerator
会根据给定的Experiment
产生参数变体。每个待训练的参数变体会创建相应的Trial
对象。Trial
有PENDING, RUNNING, PAUSED, TERMINATED, ERROR几种状态。它会开始于PENDING状态,开始训练后转为RUNNING状态,出错了就到ERROR状态,成功的话就是TERMINATED状态。训练中还可能被TrialScheduler
暂停(转入PAUSED状态)并释放资源。
TrialRunner
是最核心的数据结构,它管理一系列的Trial
对象,并且执行一个事件循环,将这些任务通过TrialExecutor
的实现类RayTrialExecutor
提交到Ray cluster运行。RayTrialExecutor
会负责资源的管理。这里通过Ray分布执行的主要是Trainable
的实现类(上例中就是PBTBenchmarkExample
)中的_train()
函数。RayTrialExecutor
对象中的_running
维护了正在运行的Trial
。在循环中,TrialRunner
会通过TrialScheduler
的实现类PopulationBasedTraining
来进行调度。它的choose_trial_to_run()
函数从trial_runner
的queue中拿出状态为PENDING或者PAUSED的trial,并且选取离上次做perturbation最久的一个保证尽可能公平。
run
函数主要做以下几步:
- 创建
RayTrailExecutor
对象(如果没有传入trial_executor
的话)。 - 如果目标任务不是以
Experiment
对象形式给出,会按照给定的其它参数构建Experiment
对象。 - 创建
TrialRunner
对象,它基于Ray来调度事件循环。- 创建搜索算法对象(如果没给),默认为
BasicVariantGenerator
(实现在basic_variant.py
)。它主要用于产生新的参数变体。 - 创建执行实验的调度器(如果没给),默认为
FIFOScheduler
。上例中给定了PopulationBasedTraining
,所以这里就不需要创建了。 - 创建
TrialRunner
对象(实现在trial_runner.py
)。并上面创建的Experiment
对象通过add_experiment()
函数加到TrialRunner
对象中。
- 创建搜索算法对象(如果没给),默认为
- 进入主循环,通过
TrialRunner
的is_finished()
函数判断是否结束。如果没有,就调用TrialRunner
的step()
函数执行一步。step()
函数的主要工作下面再细说。 - 收尾工作。如通过
wait_for_sync()
函数同步远端目标,记录没有正常结束的trial,返回分析信息。
其中比较关键的是step()
函数,其主要流程如下:
当一个Trial
训练结束返回结果时,TrialRunner
会调用PopulationBasedTraining
的on_trial_result()
函数。这里就是PBT的精华了。结合文章开关的PBT一般套路,主要步骤如下:
- 如果离上次pertubation的时间还没到指定间隔,则返回让该
Trial
继续训练。 - 调用
_quantiles()
函数按设定的比例__quantile_fraction
得到所有Trial
中表现好的头部和表现不好的尾部。 - 如果当前trial是比较牛的那一批,那赶紧存成checkpoint,等着被其它trial克隆学习。
- 如果很不幸地,当前trial属于比较差的那一批,那就从牛的那批中随机挑一个(为
trial_to_clone
),然后调用_exploit()
函数。该函数会调用explore()
函数对trial_to_clone
进行扰动,然后将它的参数设置和checkpoint设置到当前trial。这样,当前trial就“洗心革面”,重新出发了。 - 如果
TrialRunner
中有PENDING和PAUSED状态的trial,则请求暂停当前trial,让出资源。否则的话就继续训练着。
最后,总结下主要模块间的大体流程: