加载 已训练模型 张量的 几种方法

1、saver.restore()

用 saver.restore()加载 模型 之前,首先要 定义 模型 的 计算图,具体操作如下:

#重新定义 计算图
def define_graph(input):
	graph_define
	return computed_tensor #返回需要计算的张量

with tf.Session() as sess:
	#input = tf.placeholder(tf.float,shape,name="input")
	computed_tensor = define_graph(input)
	#加载模型
	saver = tf.train.Saver()
	saver.restore(sess, model.ckpt)

	#计算 张量
	sess.run(computed_tensor, feed_dict={
    
    input:input_data})
2、saver.restore() + tf.train.import_meta_graph()
#利用 tf.train.import_meta_graph()载入计算图
saver = tf.train.import_meta_graph('model.ckpt.meta')
with tf.Session() as sess:
	#载入模型
	saver.restore(sess,'model.ckpt')
	#载入 要计算的张量名
	input, computed_tensor = tf.get_default_graph().get_tensor_by_name(['input', 'computed_tensor:0'])
	# 计算张量
	sess.run(computed_tensor, feed_dict = {
    
    input:input_data})
3、gfile.GFile()
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util

#保存模型,及要计算的 tensor
graph_def = tf.get_default_graph().as_graph_def()
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['input', 'computed_tensor'])                                                                
with gfile.GFile('model.pb', "wb") as f:
	f.write(output_graph_def.SerializeToString())
	
#加载模型,以及要计算的 tensor
    with tf.Session() as sess:
        with gfile.FastGFile('model.pb', "rb") as f: 
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        	input, computed_tensor = tf.import_graph_def(graph_def,return_elements=["input", "computed_tensor:0"])
        	sess.run(computed_tensor, feed_dict = {
    
    input:input_data})

猜你喜欢

转载自blog.csdn.net/u014765410/article/details/100539837