get_tensor_by_name

1、tf.get_default_graph().get_tensor_by_name("<name>:0")

import tensorflow as tf

# Two constant 2x2 matrices multiplied by an op that is explicitly named 'example'.
c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d, name='example')

with tf.Session() as sess:
    # A tensor name has the form '<op name>:<output index>'. A string like
    # 'conv1' is a node (op) name, whereas 'conv1:0' is a tensor name that
    # refers to the node's first output tensor.
    print(e.name)
    # -> example:0

    # Fetch the very same tensor back from the graph using only its name.
    default_graph = tf.get_default_graph()
    test = default_graph.get_tensor_by_name("example:0")
    print(test)
    # -> Tensor("example:0", shape=(2, 2), dtype=float32)

2、[node.name for node in tf.get_default_graph().as_graph_def().node]（得到的是节点名，如 'conv1'；对应的张量名为 '节点名:0'）

import tensorflow as tf
import os

model_dir = 'Inception_v3'
model_name = 'output_graph.pb'


def create_graph(directory=None, filename=None):
    """Read a frozen GraphDef (by default the Inception-v3 model) into the default graph.

    The file is parsed as a serialized ``tf.GraphDef`` protobuf and imported
    with ``name=''`` so its nodes keep their original names (e.g. ``'conv1'``)
    instead of receiving the default ``'import/'`` prefix.

    Args:
        directory: Folder containing the graph file. Defaults to the
            module-level ``model_dir``, looked up at call time.
        filename: Name of the ``.pb`` file inside ``directory``. Defaults to
            the module-level ``model_name``, looked up at call time.

    Returns:
        The parsed ``tf.GraphDef`` that was imported into the default graph.
    """
    # Resolve defaults at call time so reassigning the module-level
    # constants still takes effect, exactly as in the original version.
    if directory is None:
        directory = model_dir
    if filename is None:
        filename = model_name
    path = os.path.join(directory, filename)

    # GFile supersedes the deprecated FastGFile; 'rb' because the file is a
    # binary serialized protobuf. Read everything, then close the file early.
    with tf.gfile.GFile(path, 'rb') as f:
        serialized = f.read()

    # Start from an empty GraphDef and populate it from the serialized bytes.
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)
    # Import the graph from graph_def into the current default Graph.
    tf.import_graph_def(graph_def, name='')
    return graph_def

# Import the frozen model into the default graph.
create_graph()

# Every NodeDef in the serialized graph carries a node (op) name such as
# 'conv1'. These are node names, not tensor names: the tensor produced by a
# node's first output is addressed as '<node name>:0' in get_tensor_by_name.
graph_def = tf.get_default_graph().as_graph_def()
tensor_name_list = [node.name for node in graph_def.node]

# Passing '\n' as a second argument leaves an empty line after each name.
for name in tensor_name_list:
    print(name, '\n')

3、FaceNet中的应用:

def load_model(model):
    """Load a FaceNet model into the current default graph.

    ``model`` (``~`` is expanded) may be either:

    * a frozen-graph protobuf file, whose GraphDef is imported into the
      current default graph without a name prefix; or
    * a model directory holding a TensorFlow checkpoint (a ``checkpoint``
      state file plus the saver's ``.meta``/``.index``/``.data`` files). Its
      meta graph is imported and its variables are restored into the
      current default session, so call this inside ``with tf.Session():``.

    Args:
        model: Path to a frozen ``.pb`` file or to a checkpoint directory.

    Raises:
        ValueError: If ``model`` is not a file and no checkpoint is found there.
        RuntimeError: If a checkpoint must be restored but no default
            session is active.
    """
    model_exp = os.path.expanduser(model)
    if os.path.isfile(model_exp):
        print('Model filename: %s' % model_exp)
        with tf.gfile.GFile(model_exp, 'rb') as f:
            serialized = f.read()
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(serialized)
        # name='' keeps the original op names, so tensors can later be
        # fetched as e.g. "input:0" rather than "import/input:0".
        tf.import_graph_def(graph_def, name='')
        return

    # Otherwise treat the path as a checkpoint directory.
    ckpt = tf.train.get_checkpoint_state(model_exp)
    if not (ckpt and ckpt.model_checkpoint_path):
        # Previously this case was silently ignored, which only surfaced
        # later as a confusing "tensor not found" error.
        raise ValueError('No checkpoint found in model directory: %s' % model_exp)

    sess = tf.get_default_session()
    if sess is None:
        raise RuntimeError('load_model must be called inside an active tf.Session '
                           'to restore a checkpoint')

    ckpt_path = ckpt.model_checkpoint_path
    # The original referenced an undefined `saver`; build it from the meta
    # graph that tf.train.Saver writes alongside the checkpoint
    # (<ckpt_path>.meta). This also recreates the graph's ops.
    saver = tf.train.import_meta_graph(ckpt_path + '.meta')
    saver.restore(sess, ckpt_path)

# Compute FaceNet embeddings for a batch of images.
# NOTE(review): `model_file` (path to a frozen .pb or a checkpoint directory)
# and `images` (presumably a batch of preprocessed face images matching the
# model's "input" tensor — confirm) must be defined before this snippet runs.
with tf.Graph().as_default():
    with tf.Session() as sess:
        # Load the model into this fresh graph (and session, for checkpoints).
        load_model(model_file)

        # Get input and output tensors by their '<op name>:0' tensor names.
        graph = tf.get_default_graph()
        images_placeholder = graph.get_tensor_by_name("input:0")
        embeddings = graph.get_tensor_by_name("embeddings:0")
        phase_train_placeholder = graph.get_tensor_by_name("phase_train:0")

        # Run forward pass to calculate embeddings. phase_train is fed False,
        # presumably to select inference behaviour — confirm against the model.
        feed_dict = {images_placeholder: images, phase_train_placeholder: False}
        emb = sess.run(embeddings, feed_dict=feed_dict)

转载自：https://www.jianshu.com/p/3cee7ca5ebd8

转载自：https://blog.csdn.net/baidu_27643275/article/details/82982889