官网
Tensorflow源码分析
A、基本概念
1. Graph
2. Tensor
3. Session
B、Tools
1. Checkpoint
2. Pb
3. TensorBoard
B.1
1. 模型的保存
import tensorflow as tf def store_model_ckpt(ckpt_file_path): x = tf.placeholder(tf.int32, name='x') y = tf.placeholder(tf.int32, name='y') #模型的保存必须有变量 c = tf.Variable(1, name='c') a = tf.add(x, y, name='op') result = tf.add(a, c) with tf.Session() as sess: init = tf.global_variables_initializer() sess.run(init) saver = tf.train.Saver() #如果只保存其中一部分变量,则使用下面代码,用列表或者字典都可以 #saver = tf.train.Saver([x, y]) #这里面有参数global_step=50,当训练50步便保存模型 saver.save(sess, ckpt_file_path) # test feed_dict = {x: 2, y: 3} print(sess.run(result, feed_dict)) def main(): ckpt_file_path = "./ckpt/model.ckpt" store_model_ckpt(ckpt_file_path) if __name__ == '__main__': main()
结果:6
程序生成并保存四个文件
- checkpoint 文本文件,记录了模型文件的路径信息列表
- model.ckpt.data-00000-of-00001 网络权重信息
- model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
- model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf
2. 模型恢复加载
针对上面的模型保存例子,还原模型的过程如下:
import tensorflow as tf def restore_model_ckpt(): with tf.Session() as sess: #step1:加载模型结构 saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') #step2:只需要指定目录就可以恢复所有变量信息 saver.restore(sess,tf.train.latest_checkpoint('./ckpt')) #直接获取保存的变量 print(sess.run('c:0')) #获取placeholder变量,通过get_tensor_by_name x = sess.graph.get_tensor_by_name('x:0') y = sess.graph.get_tensor_by_name('y:0') #获取需要进行计算的op算子,此op为加法 op = sess.graph.get_tensor_by_name('op:0') #加入新的op操作,新的op为乘法 new_op = tf.multiply(op, 2) #test feed_dict = {x:2, y:3} result = sess.run(new_op,feed_dict) print(result) def main(): restore_model_ckpt() if __name__ == '__main__': main()
结果:10
1. 首先还原模型结构
2. 然后还原变量(参数)信息
3. 最后我们就可以获得已训练的模型中的各种信息了(保存的变量、placeholder变量、operator等),同时可以对获取的变量添加各种新的操作(见以上代码注释)。
并且,我们也可以加载部分模型,在此基础上加入其它操作,具体可以参考官方文档和demo。
针对ckpt模型文件的保存与还原,stackoverflow上有一个回答解释比较清晰,可以参考。
同时cv-tricks.com上面的TensorFlow模型保存与恢复的教程也非常好,可以参考。