TensorFlow二进制模型加载方法

TensorFlow二进制模型加载方法

这种加载方法一般是对应网上各大公司已经训练好的网络模型进行修改的工作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 新建空白图
self .graph  =  tf.Graph()
# 空白图列为默认图
with  self .graph.as_default():
     # 二进制读取模型文件
     with tf.gfile.FastGFile(os.path.join(model_dir,model_name), 'rb' ) as f:
         # 新建GraphDef文件,用于临时载入模型中的图
         graph_def  =  tf.GraphDef()
         # GraphDef加载模型中的图
         graph_def.ParseFromString(f.read())
         # 在空白图中加载GraphDef中的图
         tf.import_graph_def(graph_def,name = '')
         # 在图中获取张量需要使用graph.get_tensor_by_name加张量名
         # 这里的张量可以直接用于session的run方法求值了
         # 补充一个基础知识,形如'conv1'是节点名称,而'conv1:0'是张量名称,表示节点的第一个输出张量
         self .input_tensor  =  self .graph.get_tensor_by_name( self .input_tensor_name)
         self .layer_tensors  =  [ self .graph.get_tensor_by_name(name  +  ':0' for  name    in  self .layer_operation_names]

猜你喜欢

转载自blog.csdn.net/qq_30868235/article/details/80502215