Tensorflow源码分析

官网

Tensorflow源码分析

A、基本概念

  1. Graph

  2. Tensor 

  3. Session

B、Tools

  1. Checkpoint

  2. Pb

  3. TensorBoard

 B.1

1. 模型的保存

import tensorflow as tf

def store_model_ckpt(ckpt_file_path):
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    #模型的保存必须有变量
    c = tf.Variable(1, name='c')
    a = tf.add(x, y, name='op')
    result = tf.add(a, c)

    with tf.Session() as sess:
        init = tf.global_variables_initializer()
        sess.run(init)
    
        saver = tf.train.Saver()
    
        #如果只保存其中一部分变量,则使用下面代码,用列表或者字典都可以
        #saver = tf.train.Saver([x, y])
    
        #这里面有参数global_step=50,当训练50步便保存模型
        saver.save(sess, ckpt_file_path)
        # test
        feed_dict = {x: 2, y: 3}
        print(sess.run(result, feed_dict))

def main():
    ckpt_file_path = "./ckpt/model.ckpt"
    store_model_ckpt(ckpt_file_path)

if __name__ == '__main__':
    main()

结果:6

程序生成并保存四个文件

  1. checkpoint 文本文件,记录了模型文件的路径信息列表
  2. model.ckpt.data-00000-of-00001 网络权重信息
  3. model.ckpt.index .data和.index这两个文件是二进制文件,保存了模型中的变量参数(权重)信息
  4. model.ckpt.meta 二进制文件,保存了模型的计算图结构信息(模型的网络结构)protobuf

2. 模型恢复加载

针对上面的模型保存例子,还原模型的过程如下:

import tensorflow as tf

def restore_model_ckpt():
    with tf.Session() as sess:
        #step1:加载模型结构
        saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta')
        #step2:只需要指定目录就可以恢复所有变量信息
        saver.restore(sess,tf.train.latest_checkpoint('./ckpt'))
        
        #直接获取保存的变量
        print(sess.run('c:0'))
        
        #获取placeholder变量,通过get_tensor_by_name
        x = sess.graph.get_tensor_by_name('x:0')
        y = sess.graph.get_tensor_by_name('y:0')
        
        #获取需要进行计算的op算子,此op为加法
        op = sess.graph.get_tensor_by_name('op:0')
        
        #加入新的op操作,新的op为乘法
        new_op = tf.multiply(op, 2)
        
        #test
        feed_dict = {x:2, y:3}
        
        result = sess.run(new_op,feed_dict)
        print(result)

def main():
    restore_model_ckpt()
    
if __name__ == '__main__':
    main()

结果:10

  1. 首先还原模型结构

  2. 然后还原变量(参数)信息

  3. 最后我们就可以获得已训练的模型中的各种信息了(保存的变量、placeholder变量、operator等),同时可以对获取的变量添加各种新的操作(见以上代码注释)。
  并且,我们也可以加载部分模型,在此基础上加入其它操作,具体可以参考官方文档和demo。

  针对ckpt模型文件的保存与还原,stackoverflow上有一个回答解释比较清晰,可以参考。

  同时cv-tricks.com上面的TensorFlow模型保存与恢复的教程也非常好,可以参考。

猜你喜欢

转载自www.cnblogs.com/missidiot/p/10143482.html