文章目录
1. 为什么要有 Hook?
SessionRunHook
用来扩展那些将 session
封装起来的高级 API 的 session.run
的行为。
2. Hook 有什么用?
SessionRunHook
对于追踪训练过程、报告进度、实现提前停止等非常有用。
SessionRunHook
以观察者模式运行。SessionRunHook
的设计中有几个非常重要的时间点:
session
使用前session.run()
调用之前session.run()
调用之后session
关闭前
SessionRunHook
封装了一些可重用、可组合的计算,并且可以顺便完成 session.run()
的调用。利用 Hook,我们可以为 run()
调用添加任何的 ops或tensor/feeds;并且在 run()
调用完成后获得请求的输出。Hook 可以利用 hook.begin()
方法向图中添加 ops,但请注意:在 begin()
方法被调用后,计算图就 finalized 了。
3. TF 内置了哪些 Hook?
TensorFlow 中已经内置了一些 Hook:
StopAtStepHook
:根据 global_step 来停止训练。CheckpointSaverHook
:保存 checkpoint。LoggingTensorHook
:以日志的形式输出一个或多个 tensor 的值。NanTensorHook
:如果给定的Tensor
包含 Nan,就停止训练。SummarySaverHook
:保存 summaries 到一个 summary writer。
4. TF 怎么自定义 Hook?
上节,我们已经介绍了预制 Hook,使用其可以实现一些常见功能。如果这些 Hook 不能满足你的需求,那么自定义 Hook 是比较好的选择。
下面是自定义 Hook 的编写模板:
class ExampleHook(tf.train.SessionRunHook):
def begin(self):
# You can add ops to the graph here.
print('Starting the session.')
self.your_tensor = ...
def after_create_session(self, session, coord):
# When this is called, the graph is finalized and
# ops can no longer be added to the graph.
print('Session created.')
def before_run(self, run_context):
print('Before calling session.run().')
return SessionRunArgs(self.your_tensor)
def after_run(self, run_context, run_values): # run_values 为 sess.run 的结果
print('Done running one step. The value of my tensor: %s',
run_values.results)
if you-need-to-stop-loop:
run_context.request_stop()
def end(self, session):
print('Done with the session.')
5. 怎么使用 Hook?
在那些将 session
封装起来的高阶 API 中,我们可以使用 Hook 来扩展这些这些 API 的 session.run()
的行为。
首先,我们梳理一下将 session
封装起来的高阶 API 有哪些?这些 API 包括,但不限于:
tf.train.MonitoredTrainingSession
:tf.estimator.Estimator
:tf.contrib.slim
:
5.1 怎么在 MonitoredTrainingSession
中使用 Hook
with tf.train.MonitoredTrainingSession(hooks=your_hooks, ...) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(your_fetches)
5.2 怎么在 Estimator
中使用 Hook
在 tf.estimator.Estimator
的 train
、evaluate
、predict
方法中都可以使用 Hook。
下面是这些方法的 API:
# 训练
# 这里的 est 是一个 Estimator 实例
est.train(input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None)
# 评估
est.evaluate(input_fn,
steps=None,
hooks=None,
checkpoint_path=None,
name=None)
# 预测
est.predict(input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None,
yield_single_examples=True)
5.3 怎么在 slim
中使用 Hook
Slim 是 TensorFlow 中一个非常优秀的高阶 API,其可以极大地简化模型的构建、训练、评估。
未完待续。。。。
6. Hook 是怎么运作的?
通过自定义 Hook 的过程,我们了解到一个 Hook 包括 begin
、after_create_session
、before_run
、after_run
、end
五个方法。
下面的伪代码演示了 Hook 的运行过程:
# 伪代码
call hooks.begin()
sess = tf.Session()
call hooks.after_create_session()
while not stop is requested:
call hooks.before_run()
try:
results = sess.run(merged_fetches, feed_dict=merged_feeds)
except (errors.OutOfRangeError, StopIteration):
break
call hooks.after_run()
call hooks.end()
sess.close()
注意:如果 sess.run()
引发 OutOfRangeError
、StopIteration
或其它异常,那么 hooks.after_run()
和 hooks.end()
将不会被执行。
7. 内置 Hook 的研究
预制的 Hook 比较多,这里我们以 tf.train.StopAtStepHook
为例,来看看内置 Hook 是怎么编写的。
# tf.train.StopAtStepHook 的定义
class StopAtStepHook(tf.train.SessionRunHook):
"""Hook that requests stop at a specified step."""
def __init__(self, num_steps=None, last_step=None):
"""Initializes a `StopAtStepHook`.
This hook requests stop after either a number of steps have been
executed or a last step has been reached. Only one of the two options can be
specified.
if `num_steps` is specified, it indicates the number of steps to execute
after `begin()` is called. If instead `last_step` is specified, it
indicates the last step we want to execute, as passed to the `after_run()`
call.
Args:
num_steps: Number of steps to execute.
last_step: Step after which to stop.
Raises:
ValueError: If one of the arguments is invalid.
"""
if num_steps is None and last_step is None:
raise ValueError("One of num_steps or last_step must be specified.")
if num_steps is not None and last_step is not None:
raise ValueError("Only one of num_steps or last_step can be specified.")
self._num_steps = num_steps
self._last_step = last_step
def begin(self):
self._global_step_tensor = tf.train.get_or_create_global_step()
if self._global_step_tensor is None:
raise RuntimeError("Global step should be created to use StopAtStepHook.")
def after_create_session(self, session, coord):
if self._last_step is None:
global_step = session.run(self._global_step_tensor)
self._last_step = global_step + self._num_steps
def before_run(self, run_context): # pylint: disable=unused-argument
return tf.train.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values):
global_step = run_values.results + 1
if global_step >= self._last_step:
# Check latest global step to ensure that the targeted last step is
# reached. global_step read tensor is the value of global step
# before running the operation. We're not sure whether current session.run
# incremented the global_step or not. Here we're checking it.
step = run_context.session.run(self._global_step_tensor)
if step >= self._last_step:
run_context.request_stop()