tensorflow SessionRunHook MonitoredTrainingSession

Hook? tf.train.SessionRunHook()介绍【精】

https://blog.csdn.net/mrr1ght/article/details/81011280

tf.train.MonitoredTrainingSession()解析【精】

https://blog.csdn.net/mrr1ght/article/details/81006343

一个tf Estimator Summary Hook 函数

https://blog.csdn.net/weixin_43378751/article/details/86724265


class SessionRunHook(object):
  """Hook to extend calls to MonitoredSession.run()."""
 
  def begin(self):
    """再创建会话之前调用
    调用begin()时,default graph会被创建,
    可在此处向default graph增加新op,begin()调用后,default graph不能再被修改
    """
    pass
 
  def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    """tf.Session被创建后调用
    调用后会指示所有的Hooks有一个新的会话被创建
    Args:
      session: A TensorFlow Session that has been created.
      coord: A Coordinator object which keeps track of all threads.
    """
    pass
 
  def before_run(self, run_context):  # pylint: disable=unused-argument
    """调用在每个sess.run()执行之前
    可以返回一个tf.train.SessRunArgs(op/tensor),在即将运行的会话中加入这些op/tensor;
    加入的op/tensor会和sess.run()中已定义的op/tensor合并,然后一起执行;
    Args:
      run_context: A `SessionRunContext` object.
    Returns:
      None or a `SessionRunArgs` object.
    """
    return None
  def after_run(self,
                run_context,  # pylint: disable=unused-argument
                run_values):  # pylint: disable=unused-argument
    """调用在每个sess.run()之后
    参数run_values是befor_run()中要求的op/tensor的返回值;
    可以调用run_context.qeruest_stop()用于停止迭代
    sess.run抛出任何异常after_run不会被调用
    Args:
      run_context: A `SessionRunContext` object.
      run_values: A SessionRunValues object.
    """
    pass
 
  def end(self, session):  # pylint: disable=unused-argument
    """在会话结束时调用
    end()常被用于Hook想要执行最后的操作,如保存最后一个checkpoint
    如果sess.run()抛出除了代表迭代结束的OutOfRange/StopIteration异常外,
    end()不会被调用
    Args:
      session: A TensorFlow Session that will be soon closed.
    """
    pass
————————————————
版权声明:本文为CSDN博主「朔方_」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/mrr1ght/article/details/81011280

class _LoggerHook(tf.train.SessionRunHook):
  """Logs loss and runtime."""
 
  def begin(self):
    self._step = -1
    self._start_time = time.time()
 
  def before_run(self, run_context):
    self._step += 1
    return tf.train.SessionRunArgs(loss)  # Asks for loss value.
 
  def after_run(self, run_context, run_values):
    if self._step % FLAGS.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time#duration持续的时间
      self._start_time = current_time
 
      loss_value = run_values.results
      examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
      sec_per_batch = float(duration / FLAGS.log_frequency)
 
      format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
      print (format_str % (datetime.now(), self._step, loss_value,

————————————————
版权声明:本文为CSDN博主「朔方_」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/mrr1ght/article/details/81011280
发布了71 篇原创文章 · 获赞 13 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/Lord_sh/article/details/104801166