1、tf.get_default_graph().get_tensor_by_name("<name>:0")
import tensorflow as tf
# Demo: an op created with name='example' produces a tensor named 'example:0',
# which can later be looked up by that name from the default graph.
c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d, name='example')
# NOTE: the session is not needed to read names; it is kept as in the
# original walkthrough.
with tf.Session() as sess:
    print(e.name)
    # example:0
    # Tensor names have the form '<op_name>:<output_index>'; the suffix selects
    # which output of the op you mean, and ':0' is the first one.
    # A name like 'conv1' is a node (op) name, while 'conv1:0' is a tensor
    # name, meaning the first output tensor of that node.
    test = tf.get_default_graph().get_tensor_by_name("example:0")
    print(test)
    # Tensor("example:0", shape=(2, 2), dtype=float32)
2、[node.name for node in tf.get_default_graph().as_graph_def().node] (node names; append ":0" for the first output tensor)
import tensorflow as tf
import os
# Location of the frozen Inception_v3 graph: <model_dir>/<model_name>.
model_dir, model_name = 'Inception_v3', 'output_graph.pb'
# Read Google's pre-trained Inception_v3 model from disk and load it into a graph.
def create_graph(graph_path=None):
    """Import a serialized GraphDef (.pb file) into the current default graph.

    Args:
        graph_path: Path to the serialized GraphDef. When omitted, the path is
            ``os.path.join(model_dir, model_name)`` built from the module-level
            constants, which is what the original no-argument call loaded.
    """
    if graph_path is None:
        graph_path = os.path.join(model_dir, model_name)
    # Binary mode: the file holds a serialized protobuf.
    # tf.gfile.GFile replaces the deprecated tf.gfile.FastGFile.
    with tf.gfile.GFile(graph_path, 'rb') as f:
        # Start from an empty GraphDef and fill it from the file contents.
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import into the current default graph. name='' keeps the original node
    # names instead of prefixing them with the default import scope.
    tf.import_graph_def(graph_def, name='')
# Build the graph, then print the name of every node it contains.
create_graph()
# NOTE: these are NodeDef (op) names such as 'conv1', not tensor names;
# append ':0' to get the name of a node's first output tensor.
tensor_name_list = [node.name for node in tf.get_default_graph().as_graph_def().node]
for tensor_name in tensor_name_list:
    print(tensor_name, '\n')
3、FaceNet 中的应用:
def load_model(model):
    """Load a model into the current default graph.

    Args:
        model: Either the path to a protobuf file containing a frozen graph,
            or a model directory containing a metagraph and a checkpoint.
            A leading ``~`` is expanded.

    Raises:
        ValueError: If ``model`` is not a file and no checkpoint is found in it.

    For the checkpoint case, variables are restored into the current default
    session, so call this inside ``with tf.Session() as sess:``.
    """
    model_exp = os.path.expanduser(model)
    if os.path.isfile(model_exp):
        # Frozen graph: import the GraphDef directly; no checkpoint restore.
        print('Model filename: %s' % model_exp)
        with tf.gfile.GFile(model_exp, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # name='' keeps original node names so callers can use e.g. "input:0".
        tf.import_graph_def(graph_def, name='')
    else:
        ckpt = tf.train.get_checkpoint_state(model_exp)
        if not (ckpt and ckpt.model_checkpoint_path):
            # Fail loudly here rather than with an obscure lookup error later.
            raise ValueError('No checkpoint found in model directory: %s' % model_exp)
        checkpoint_path = ckpt.model_checkpoint_path
        # Rebuild the graph from the metagraph to obtain a Saver, then restore
        # the variable values into the caller's default session.
        # NOTE(review): assumes the metagraph was saved next to the checkpoint
        # as '<checkpoint prefix>.meta' (tf.train.Saver.save default) — confirm.
        saver = tf.train.import_meta_graph(checkpoint_path + '.meta')
        saver.restore(tf.get_default_session(), checkpoint_path)
# Example usage: load the model, look up its input/output tensors by name,
# and run a forward pass to compute embeddings.
# NOTE(review): `model_file` and `images` are not defined in this snippet;
# the surrounding code must provide them.
with tf.Graph().as_default():
    with tf.Session() as sess:
        # Load the model into this graph (and this session, for checkpoints).
        load_model(model_file)
        graph = tf.get_default_graph()
        # Get input and output tensors ('<node name>:<output index>').
        images_placeholder = graph.get_tensor_by_name("input:0")
        embeddings = graph.get_tensor_by_name("embeddings:0")
        phase_train_placeholder = graph.get_tensor_by_name("phase_train:0")
        # Run forward pass to calculate embeddings; phase_train is fed False
        # (presumably selects inference behavior — confirm against the model).
        feed_dict = {images_placeholder: images, phase_train_placeholder: False}
        emb = sess.run(embeddings, feed_dict=feed_dict)