TensorFlow 之 SessionRunHook

1. 为什么要有 Hook?

SessionRunHook 用来扩展那些将 session 封装起来的高级 API 的 session.run 的行为。

2. Hook 有什么用?

SessionRunHook 对于追踪训练过程、报告进度、实现提前停止等非常有用。

SessionRunHook 以观察者模式运行。SessionRunHook 的设计中有几个非常重要的时间点:

  • session 使用前
  • session.run() 调用之前
  • session.run() 调用之后
  • session 关闭前

SessionRunHook 封装了一些可重用、可组合的计算,并且可以顺便完成 session.run() 的调用。利用 Hook,我们可以为 run() 调用添加任何的 ops或tensor/feeds;并且在 run() 调用完成后获得请求的输出。Hook 可以利用 hook.begin() 方法向图中添加 ops,但请注意:在 begin() 方法被调用后,计算图就 finalized 了。

3. TF 内置了哪些 Hook?

TensorFlow 中已经内置了一些 Hook:

  • StopAtStepHook:根据 global_step 来停止训练。
  • CheckpointSaverHook:保存 checkpoint。
  • LoggingTensorHook:以日志的形式输出一个或多个 tensor 的值。
  • NanTensorHook:如果给定的 Tensor 包含 Nan,就停止训练。
  • SummarySaverHook:保存 summaries 到一个 summary writer。

4. TF 怎么自定义 Hook?

上节,我们已经介绍了预制 Hook,使用其可以实现一些常见功能。如果这些 Hook 不能满足你的需求,那么自定义 Hook 是比较好的选择。

下面是自定义 Hook 的编写模板:

class ExampleHook(tf.train.SessionRunHook):
  def begin(self):
    # You can add ops to the graph here.
    print('Starting the session.')
    self.your_tensor = ...
  def after_create_session(self, session, coord):
    # When this is called, the graph is finalized and
    # ops can no longer be added to the graph.
    print('Session created.')
  def before_run(self, run_context):
    print('Before calling session.run().')
    return SessionRunArgs(self.your_tensor)
  def after_run(self, run_context, run_values): # run_values 为 sess.run 的结果
    print('Done running one step. The value of my tensor: %s',
          run_values.results)
    if you-need-to-stop-loop:
      run_context.request_stop()
  def end(self, session):
    print('Done with the session.')

5. 怎么使用 Hook?

在那些将 session 封装起来的高阶 API 中,我们可以使用 Hook 来扩展这些这些 API 的 session.run() 的行为。

首先,我们梳理一下将 session 封装起来的高阶 API 有哪些?这些 API 包括,但不限于:

  • tf.train.MonitoredTrainingSession
  • tf.estimator.Estimator
  • tf.contrib.slim

5.1 怎么在 MonitoredTrainingSession 中使用 Hook

with tf.train.MonitoredTrainingSession(hooks=your_hooks, ...) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(your_fetches)

5.2 怎么在 Estimator 中使用 Hook

tf.estimator.Estimatortrainevaluatepredict 方法中都可以使用 Hook。

下面是这些方法的 API:

# 训练
# 这里的 est 是一个 Estimator 实例
est.train(input_fn, 
          hooks=None, 
          steps=None, 
          max_steps=None, 
          saving_listeners=None)
# 评估
est.evaluate(input_fn, 
             steps=None, 
             hooks=None, 
             checkpoint_path=None, 
             name=None)
# 预测
est.predict(input_fn, 
            predict_keys=None, 
            hooks=None, 
            checkpoint_path=None, 
            yield_single_examples=True)

5.3 怎么在 slim 中使用 Hook

Slim 是 TensorFlow 中一个非常优秀的高阶 API,其可以极大地简化模型的构建、训练、评估。

未完待续。。。。

6. Hook 是怎么运作的?

通过自定义 Hook 的过程,我们了解到一个 Hook 包括 beginafter_create_sessionbefore_runafter_runend 五个方法。

下面的伪代码演示了 Hook 的运行过程:

# 伪代码
call hooks.begin()
sess = tf.Session()
call hooks.after_create_session()
while not stop is requested:
  call hooks.before_run()
  try:
    results = sess.run(merged_fetches, feed_dict=merged_feeds)
  except (errors.OutOfRangeError, StopIteration):
    break
  call hooks.after_run()
call hooks.end()
sess.close()

注意:如果 sess.run() 引发 OutOfRangeErrorStopIteration 或其它异常,那么 hooks.after_run()hooks.end() 将不会被执行。

7. 内置 Hook 的研究

预制的 Hook 比较多,这里我们以 tf.train.StopAtStepHook 为例,来看看内置 Hook 是怎么编写的。

# tf.train.StopAtStepHook 的定义
class StopAtStepHook(tf.train.SessionRunHook):
  """Hook that requests stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None):
    """Initializes a `StopAtStepHook`.
    This hook requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options can be
    specified.
    if `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `after_run()`
    call.
    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.
    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self):
    self._global_step_tensor = tf.train.get_or_create_global_step()
    if self._global_step_tensor is None:
      raise RuntimeError("Global step should be created to use StopAtStepHook.")

  def after_create_session(self, session, coord):
    if self._last_step is None:
      global_step = session.run(self._global_step_tensor)
      self._last_step = global_step + self._num_steps

  def before_run(self, run_context):  # pylint: disable=unused-argument
    return tf.train.SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    global_step = run_values.results + 1
    if global_step >= self._last_step:
      # Check latest global step to ensure that the targeted last step is
      # reached. global_step read tensor is the value of global step
      # before running the operation. We're not sure whether current session.run
      # incremented the global_step or not. Here we're checking it.

      step = run_context.session.run(self._global_step_tensor)
      if step >= self._last_step:
        run_context.request_stop()

8. 参考文献

  1. SessionRunHook 源码:link
  2. tf.train.SessionRunHook() 类详解:link
  3. Hook? tf.train.SessionRunHook()介绍【精】:link

猜你喜欢

转载自blog.csdn.net/u014061630/article/details/82998116