tf.train.write_graph用法

我不是知识的生产者,我只是一个渺小的搬运工,我们都站在巨人的肩膀上


探索了一下午这玩意的用法,终于会用了,在此附上实例子

首先要明白他保存图的原理,这个里面讲的很详细,请细品

https://zhuanlan.zhihu.com/p/31308381

tf.train.write_graph这个函数可以保留节点,op,constant,但不保存variable,如果你想要保存variable,那么就要转为constant

import tensorflow as tf
import numpy as np
from tensorflow.python.platform import gfile

#生成图
input1= tf.placeholder(tf.int32,name="input")
b = tf.constant([3])
output1= tf.add(input1, b, name="output")

#保存图
with tf.Session() as sess:
    tf.train.write_graph(sess.graph_def, "./", "test.pb", False)
    print(sess.run(output1,feed_dict={input1:1}))

#读取图
with tf.Session() as sess:
    with gfile.FastGFile("./test.pb",'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

#查看图中信息,填充运行图
with tf.Session() as sess:
    input_x1 = sess.graph.get_tensor_by_name("input:0")  
    print (input_x1)   #可以看到这个placeholder的属性
    output = sess.graph.get_tensor_by_name("output:0")
    print (output)
    data1 = int(3)
    print(sess.run(output,feed_dict={input_x1:data1}))  #填充placeholder,然后运行图

#或者也可以直接读入图,运行
data1 = int(3)
with tf.Session() as sess:
    with gfile.FastGFile("./test.pb",'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        output = tf.import_graph_def(graph_def, input_map={'input:0':data1},             
        return_elements=['output:0'], name='a') 
        print(sess.run(output))  

解释几个问题:

1.sess.graph.get_tensor_by_name("input:0")是干什么的,为什么是input:0?

   答:是帮你获取张量的,input是节点名称,input:0是表述节点的输出的第一个张量

2.如果图中有变量,也想要保存,怎么办?

   答:保存图的时候转化成常量保存,graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])

最后:有问题欢迎指正联系

 

发布了15 篇原创文章 · 获赞 30 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/qq_40778406/article/details/104554897