回调函数使用
回调函数是一个函数的合集,会在训练的阶段中所使用。你可以使用回调函数来查看训练模型的内在状态和统计。你可以传递一个列表的回调函数(作为 callbacks
关键字参数)到 Sequential
或 Model
类型的 .fit()
方法。在训练时,相应的回调函数的方法就会被在各自的阶段被调用。
模型生成并保存以后每次只需要调用保存好的模型,不需要重新训练
Callback
keras.callbacks.Callback()
用来组建新的回调函数的抽象基类。
属性
- params: 字典。训练参数, (例如,verbosity, batch size, number of epochs...)。
- model:
keras.models.Model
的实例。 指代被训练模型。
被回调函数作为参数的 logs
字典,它会含有于当前批量或训练轮相关数据的键。
目前,Sequential
模型类的 .fit()
方法会在传入到回调函数的 logs
里面包含以下的数据:
- on_epoch_end: 包括
acc
和loss
的日志, 也可以选择性的包括val_loss
(如果在fit
中启用验证),和val_acc
(如果启用验证和监测精确值)。 - on_batch_begin: 包括
size
的日志,在当前批量内的样本数量。 - on_batch_end: 包括
loss
的日志,也可以选择性的包括acc
(如果启用监测精确值)。
ModelCheckpoint
keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)
在每个训练期之后保存模型。
filepath
可以包括命名格式选项,可以由 epoch
的值和 logs
的键(由 on_epoch_end
参数传递)来填充。
例如:如果 filepath
是 weights.{epoch:02d}-{val_loss:.2f}.hdf5
, 那么模型被保存的的文件名就会有训练轮数和验证损失。
参数
- filepath: 字符串,保存模型的路径。
- monitor: 被监测的数据。
- verbose: 详细信息模式,0 或者 1 。
- save_best_only: 如果
save_best_only=True
, 被监测数据的最佳模型就不会被覆盖。 - mode: {auto, min, max} 的其中之一。 如果
save_best_only=True
,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于val_acc
,模式就会是max
,而对于val_loss
,模式就需要是min
,等等。 在auto
模式中,方向会自动从被监测的数据的名字中判断出来。 - save_weights_only: 如果 True,那么只有模型的权重会被保存 (
model.save_weights(filepath)
), 否则的话,整个模型会被保存 (model.save(filepath)
)。 - period: 每个检查点之间的间隔(训练轮数)。
CSVLogger
keras.callbacks.CSVLogger(filename, separator=',', append=False)
把训练轮结果数据流到 csv 文件的回调函数。
支持所有可以被作为字符串表示的值,包括 1D 可迭代数据,例如,np.ndarray。
例子
csv_logger = CSVLogger('training.log')
model.fit(X_train, Y_train, callbacks=[csv_logger])
参数
- filename: csv 文件的文件名,例如 'run/log.csv'。
- separator: 用来隔离 csv 文件中元素的字符串。
- append: True:如果文件存在则增加(可以被用于继续训练)。False:覆盖存在的文件。
EarlyStopping
keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None, restore_best_weights=False)
当被监测的数量不再提升,则停止训练。
参数
- monitor: 被监测的数据。
- min_delta: 在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。
- patience: 没有进步的训练轮数,在这之后训练就会被停止。
- verbose: 详细信息模式。
- mode: {auto, min, max} 其中之一。 在
min
模式中, 当被监测的数据停止下降,训练就会停止;在max
模式中,当被监测的数据停止上升,训练就会停止;在auto
模式中,方向会自动从被监测的数据的名字中判断出来。 - baseline: 要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止。
- restore_best_weights: 是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重。
ReduceLROnPlateau
keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', min_delta=0.0001, cooldown=0, min_lr=0)
当标准评估停止提升时,降低学习速率。
当学习停止时,模型总是会受益于降低 2-10 倍的学习速率。 这个回调函数监测一个数据并且当这个数据在一定「有耐心」的训练轮之后还没有进步, 那么学习速率就会被降低。
例子
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
patience=5, min_lr=0.001)
model.fit(X_train, Y_train, callbacks=[reduce_lr])
参数
- monitor: 被监测的数据。
- factor: 学习速率被降低的因数。新的学习速率 = 学习速率 * 因数
- patience: 没有进步的训练轮数,在这之后训练速率会被降低。
- verbose: 整数。0:安静,1:更新信息。
- mode: {auto, min, max} 其中之一。如果是
min
模式,学习速率会被降低如果被监测的数据已经停止下降; 在max
模式,学习塑料会被降低如果被监测的数据已经停止上升; 在auto
模式,方向会被从被监测的数据中自动推断出来。 - min_delta: 对于测量新的最优化的阀值,只关注巨大的改变。
- cooldown: 在学习速率被降低之后,重新恢复正常操作之前等待的训练轮数量。
- min_lr: 学习速率的下边界。