今天开始研究持久化存储,对于一个模型,我们为了方便,不用把模型的源代码都拿过来,可以只要一个记录图和里面参数的文件就好。迁移学习就是这么做的,我们最后只是把输出个数修改一下就完成了
保存
首先是存储,就是搭好框架以后存进来
import tensorflow as tf
from tensorflow.python.framework import graph_util
#定义变量
v1= tf.Variable(tf.constant([12.0],shape=[1]),name="v1")
v2= tf.Variable(tf.constant([1.0],shape=[1]),name="v2")
#定义placeholder
p1=tf.placeholder(tf.float32,[None,2],name='p1')
p2=tf.placeholder(tf.float32,[None,2],name='p2')
#定义图
result=tf.multiply(v1,v2,name="result") #节点名称 result 可以从tensorboard看到
print(result.name) #张量的名称result:0
result2=v1*v2*2 #节点名称 mul_1 可以从tensorboard看到
print(result2.name) #张量的名称mul_1:0
result3=v1+v2 #节点名称 add 可以从tensorboard看到
print(result3.name) #张量的名称add:0
#计算placeholder
result4=tf.subtract(p1,p2,name='result_substract')
print(result4.name)
#存为json格式
server=tf.train.Saver()
server.export_meta_graph('json.ckpt.meda.json',as_text=True)
#写到TensorBoard可视化
writer=tf.summary.FileWriter('log',tf.get_default_graph())
writer.close()
with tf.Session() as sess:
#申明一个graph对像
tf.global_variables_initializer().run()
graph_def=tf.get_default_graph().as_graph_def()
#把变量转为常量 ,注意 !!!后面的['result','mul_1'],我们是取他的节点名称
#但是到时候我们读取的时候,是读取张量的名称,就是后面加 ':0'
output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['result','mul_1','result_substract','p1','p2'] )
#写成pb文件
with tf.gfile.GFile('model.pb','wb') as f:
f.write(output_graph_def.SerializeToString())
这个可以程序中生成了graph可视化的log文件夹,在cmd中输入:
tensorboard --logdir= log
然后在chrome浏览器中输入网址 http://localhost:6006 就可以从上图就可以看到图像结构
读取
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
with gfile.FastGFile('model.pb','rb') as f:
graph_def =tf.GraphDef()
graph_def.ParseFromString(f.read())
#读取普通的变量
result=tf.import_graph_def(graph_def,return_elements= ['result:0','mul_1:0'])
print(sess.run(result))
#对于placeholder,要从模型中读进来,再给赋值后计算
result4,p1,p2 = tf.import_graph_def(graph_def, return_elements=['result_substract:0','p1:0','p2:0'])
print(sess.run(result4,{p1:[[3,3],[4,4]],p2:[[1,1],[2,2]]}))
#输出:
[array([ 12.], dtype=float32), array([ 24.], dtype=float32)]
[[ 2. 2.]
[ 2. 2.]]