Keras 没有内置与 PyTorch `MultiStepLR` 对应的多步调整学习率的调度器，这里提供一个自己实现的版本：
1.代码
from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.keras import backend as K
import numpy as np
import argparse
# Command-line options for the multi-step learning-rate schedule.
parser = argparse.ArgumentParser()
# type=int with nargs='+' parses "--lr_decay_epochs 2 5 7" into [2, 5, 7].
# The previous type=list split a single argument string into characters
# (e.g. "2,5,7" -> ['2', ',', '5', ',', '7']), which never matched an epoch.
parser.add_argument('--lr_decay_epochs', type=int, nargs='+', default=[2, 5, 7],
                    help="For MultiFactorScheduler step")
parser.add_argument('--lr_decay_factor', type=float, default=0.1)
# parse_known_args tolerates options this snippet does not declare.
args, _ = parser.parse_known_args()
def get_lr_scheduler(args):
    """Build the multi-step learning-rate callback configured by ``args``.

    ``args`` must provide ``lr_decay_epochs`` and ``lr_decay_factor``; it is
    handed unchanged to ``MultiStepLR``.
    """
    return MultiStepLR(args=args)
class MultiStepLR(Callback):
    """Multiply the learning rate by a constant factor at selected epochs.

    At the beginning of every epoch whose index appears in
    ``args.lr_decay_epochs``, the optimizer's current learning rate is
    multiplied by ``args.lr_decay_factor``. An index listed several times
    applies the factor once per occurrence.

    Arguments:
        args: namespace (e.g. from argparse) providing ``lr_decay_epochs``
            (iterable of epoch indices, compared against the ``epoch``
            argument of ``on_epoch_begin``) and ``lr_decay_factor`` (float).
        verbose: int. 0: quiet, 1: print a message whenever the learning
            rate is actually changed.
    """

    def __init__(self, args, verbose=0):
        super(MultiStepLR, self).__init__()
        self.args = args
        self.steps = args.lr_decay_epochs
        self.factor = args.lr_decay_factor
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        """Apply the schedule for ``epoch`` and print the resulting LR.

        Raises:
            ValueError: if the optimizer has no ``lr`` attribute, or if
                ``schedule`` does not return a float.
        """
        optimizer = self.model.optimizer
        if not hasattr(optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # Remember the pre-schedule value so the verbose message is only
        # emitted on epochs where the rate really changes.
        old_lr = K.get_value(optimizer.lr)
        lr = self.schedule(epoch)
        if not isinstance(lr, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        K.set_value(optimizer.lr, lr)
        # Print with 7 decimals, then drop trailing zeros and a dangling
        # decimal point (0.0010000 -> 0.001, 1.0000000 -> 1).
        formatted = "learning rate: {:.7f}".format(K.get_value(optimizer.lr))
        print(formatted.rstrip('0').rstrip('.'))
        if self.verbose > 0 and lr != old_lr:
            # Epoch is printed 1-based, matching the original message.
            print('\nEpoch %05d: MultiStepLR reducing learning '
                  'rate to %s.' % (epoch + 1, lr))

    def schedule(self, epoch):
        """Return the learning rate to use for ``epoch``.

        Starts from the optimizer's current learning rate and multiplies it
        by ``self.factor`` once for every entry of ``self.steps`` equal to
        ``epoch``; otherwise the current rate is returned unchanged.
        """
        lr = K.get_value(self.model.optimizer.lr)
        for step in self.steps:
            if step == epoch:
                lr = lr * self.factor
        return lr
2. 调用：先将 lr_scheduler 追加（append）到 callbacks 列表中，再通过 fit_generator 的 callbacks 参数传入该列表。
# Create the multi-step LR callback and make it part of the callback list.
lr_scheduler = get_lr_scheduler(args=args)
callbacks = [lr_scheduler]
...
# NOTE(review): args.batch_size, args.num_workers and args.epochs are not
# declared by the parser in this snippet; they are assumed to come from the
# reader's own argument setup (elided above).
model.fit_generator(
    train_generator,
    steps_per_epoch=train_generator.samples // args.batch_size,
    validation_data=test_generator,
    validation_steps=test_generator.samples // args.batch_size,
    workers=args.num_workers,
    callbacks=callbacks,  # your callbacks, including lr_scheduler
    epochs=args.epochs,
)
大家可以拿去用~