TensorFlow 持久化

今天开始研究持久化存储,对于一个模型,我们为了方便,不用把模型的源代码都拿过来,可以只要一个记录图和里面参数的文件就好。迁移学习就是这么做的,我们最后只是把输出个数修改一下就完成了

保存

首先是存储,就是搭好框架以后存进来


import tensorflow as tf
from tensorflow.python.framework import graph_util

#定义变量
v1=  tf.Variable(tf.constant([12.0],shape=[1]),name="v1")
v2=  tf.Variable(tf.constant([1.0],shape=[1]),name="v2")


#定义placeholder
p1=tf.placeholder(tf.float32,[None,2],name='p1')
p2=tf.placeholder(tf.float32,[None,2],name='p2')


#定义图
result=tf.multiply(v1,v2,name="result") #节点名称 result 可以从tensorboard看到
print(result.name) #张量的名称result:0
result2=v1*v2*2    #节点名称 mul_1   可以从tensorboard看到
print(result2.name)   #张量的名称mul_1:0
result3=v1+v2       #节点名称   add  可以从tensorboard看到
print(result3.name)  #张量的名称add:0

#计算placeholder
result4=tf.subtract(p1,p2,name='result_substract')
print(result4.name)


#存为json格式
server=tf.train.Saver()
server.export_meta_graph('json.ckpt.meda.json',as_text=True)

#写到TensorBoard可视化
writer=tf.summary.FileWriter('log',tf.get_default_graph())
writer.close()


with tf.Session() as sess:
    #申明一个graph对像
    tf.global_variables_initializer().run()
    graph_def=tf.get_default_graph().as_graph_def()

    #把变量转为常量 ,注意 !!!后面的['result','mul_1'],我们是取他的节点名称
    #但是到时候我们读取的时候,是读取张量的名称,就是后面加  ':0'
     output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['result','mul_1','result_substract','p1','p2'] )

    #写成pb文件
    with tf.gfile.GFile('model.pb','wb') as f:
        f.write(output_graph_def.SerializeToString())

这个可以程序中生成了graph可视化的log文件夹,在cmd中输入:

tensorboard --logdir= log

然后在chrome浏览器中输入网址 http://localhost:6006 就可以从上图就可以看到图像结构
TensorBoard

读取


import tensorflow as tf
from tensorflow.python.platform import gfile


with tf.Session() as sess:
    with gfile.FastGFile('model.pb','rb') as f:
        graph_def =tf.GraphDef()
        graph_def.ParseFromString(f.read())

    #读取普通的变量
    result=tf.import_graph_def(graph_def,return_elements= ['result:0','mul_1:0'])
    print(sess.run(result))

    #对于placeholder,要从模型中读进来,再给赋值后计算
    result4,p1,p2 = tf.import_graph_def(graph_def, return_elements=['result_substract:0','p1:0','p2:0'])
    print(sess.run(result4,{p1:[[3,3],[4,4]],p2:[[1,1],[2,2]]}))

#输出:

[array([ 12.], dtype=float32), array([ 24.], dtype=float32)]
[[ 2.  2.]
 [ 2.  2.]]

猜你喜欢

转载自blog.csdn.net/loovelj/article/details/80004372