经查看官方文档将常用回调函数做以下小结,目的是了解每个回调函数的作用与参数用法。
上图是tf2.0的全部回调函数,在这里介绍常用的4个回调函数:EarlyStopping,tensorboard,ModelCheckpoint,history。
1、tf.keras.callbacks.EarlyStopping
目的/作用:当监控的值停止变化时,提前结束训练。
定义:
tf.keras.callbacks.EarlyStopping(
monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto',
baseline=None, restore_best_weights=False
)
由上面的代码段可以得知,当未自己手动设置monitor时,默认监控的是验证集的loss(val_loss)。
常用参数介绍:
monitor:监控的值。
min_delta:监视值的最小变化,即,绝对变化小于min_delta的情况,将视为没有变化。
patience:在多少个epoch,监控的值没有变化后,将停止训练。(也就是连续多少个epoch,监控值的绝对变化小于min_delta)
示例:
callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
# This callback will stop the training when there is no improvement in
# the validation loss for three consecutive epochs.
model.fit(data, labels, epochs=100, callbacks=[callback],
validation_data=(val_data, val_labels))
2、tf.keras.callbacks.TensorBoard
作用:tensorflow的可视化工具
定义:
tf.keras.callbacks.TensorBoard(
log_dir='logs', histogram_freq=0, write_graph=True, write_images=False,
update_freq='epoch', profile_batch=2, embeddings_freq=0,
embeddings_metadata=None, **kwargs
)
常用参数:
log_dir:将TensorBoard解析的日志文件保存到的目录路径。
其余用到再补充
示例:
logdir = os.path.join("callbacks")
if not os.path.exists(logdir):
os.mkdir(logdir)
callbacks = [
keras.callbacks.TensorBoard(logdir),]
history = model.fit(x_train_scaled, y_train, epochs=100,
validation_data=(x_valid_scaled, y_valid),
callbacks = callbacks)
tensorboard显示:
扫描二维码关注公众号,回复:
11176180 查看本文章
3、tf.keras.callbacks.ModelCheckpoint
作用:在每一次epoch后保存模型
定义:
tf.keras.callbacks.ModelCheckpoint(
filepath, monitor='val_loss', verbose=0, save_best_only=False,
save_weights_only=False, mode='auto', save_freq='epoch', **kwargs
)
常用参数:
filepath:字符串,保存模型文件的路径。
示例:
logdir = os.path.join("callbacks")
output_model_file = os.path.join(logdir,
"fashion_mnist_model.h5")
callbacks = [
keras.callbacks.ModelCheckpoint(output_model_file,
save_best_only = True),#保存最好的模型,默认保存最近的
]
history = model.fit(x_train_scaled, y_train, epochs=100,
validation_data=(x_valid_scaled, y_valid),
callbacks = callbacks)
4、tf.keras.callbacks.History
这个回调函数会自动应用到每一个keras模型,History对象通过模型的fit方法得到返回。
history = model.fit(x_train_scaled, y_train, epochs=100,
validation_data=(x_valid_scaled, y_valid),
callbacks = callbacks)