tensorflow中的一些函数,中文资料很少,需要翻墙查找外文资料,因此在这里记录下来
1.
tf.train.get_checkpoint_state
该函数被用来获取已保存的checkpoint文件状态
get_checkpoint_state(
checkpoint_dir,
latest_filename=None
)
从checkpoint文件返回CheckpointState原型。
checkpoint_dir
: checkpoints的路径latest_filename
: 可选,checkpoint文件的名称. 默认为'checkpoint'
如果状态可用,则返回CheckpointState,否则返回None。
CheckpointState包含model_checkpoint_path和all_model_checkpoint_paths
用例:
扫描二维码关注公众号,回复:
463084 查看本文章
if os.path.exists(checkpoint_dir): not_restore = ['softmax_w:0', 'softmax_b:0'] restore_var = [v for v in tf.global_variables() if v.name.split('/')[-1] not in not_restore] restored_saver = tf.train.Saver(restore_var, max_to_keep=FLAGS.check_point_every) ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: restored_saver.restore(session, ckpt.model_checkpoint_path) else: pass