保存模型
# 创建Saver()节点
saver = tf.train.Saver()
# 训练过程中保存节点
save_path = saver.save(sess, "./ckpt/my_model.ckpt", global_step=epoch)
# 保存最终节点
save_path = saver.save(sess, "./ckpt/my_model_final.ckpt")
读取模型
# 创建Saver()节点
saver = tf.train.Saver()
# 读取节点
ckpt = tf.train.get_checkpoint_state('./ckpt/')
# 读取模型
saver.restore(sess, ckpt.model_checkpoint_path)
使用模型
"""
前期需要将整个计算图构建出来
但不需要像训练时init参数
"""
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "./ckpt/my_model_final.ckpt")
# 测试
print(accuracy.eval({x: my_mnist.test.images, y_: my_mnist.test.labels}))
TODO:restore的时候,参数是如何对应到网络结构上的