I. Graph Save & Parse & Load
1. Function reference
(a) Graph Save
tf.train.write_graph(graph_or_graph_def, logdir, name, as_text=True)
Purpose:
- Saves a graph proto to a file.
Parameters:
- graph_or_graph_def: A Graph or a GraphDef protocol buffer
- logdir: Directory where to write the graph
- name: Filename for the graph
- as_text: If True, writes the graph as an ASCII proto (otherwise as a binary proto)
Returns:
- The path of the output proto file.
graph_def.SerializeToString()
Purpose:
- Serializes the protocol message to a binary string
Returns:
- A binary string representation of the message
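(b) Graph Parser
input_graph_def.ParseFromString(serialized)
Purpose:
- Parses serialized protocol buffer data into this message (for binary .pb files)
Parameters:
- serialized: Serialized (binary) protocol buffer data, e.g. f.read() on a file opened in "rb" mode
Returns:
- The number of bytes read; the parsed graph is stored in input_graph_def itself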
text_format.Merge(text=f.read(), message=input_graph_def)
Purpose:
- Parses a text representation of a protocol message into a message (for text-format .pbtxt files)
Parameters:
- text: Message text representation
- message: A protocol buffer message to merge into
Returns:
- The same message passed as argument
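Which of the two calls applies depends only on whether the file holds binary or text data. The loader in the code example below takes this as an explicit input_binary flag.

The sketch below infers that flag from the file name. It assumes the common, but not enforced, convention that binary graphs end in .pb and text graphs end in .pbtxt. The helper name infer_input_binary is hypothetical and is not a TensorFlow API.

```python
def infer_input_binary(path):
    """Guess the input_binary flag from a graph file name (hypothetical helper).

    Assumes the usual naming convention: '.pbtxt' / '.pb.txt' for text-format
    GraphDefs and '.pb' for binary ones. TensorFlow does not enforce this, so
    pass input_binary explicitly when the name is ambiguous.
    """
    lower = path.lower()
    if lower.endswith((".pbtxt", ".pb.txt")):
        return False  # text format -> text_format.Merge
    if lower.endswith(".pb"):
        return True  # binary format -> ParseFromString
    raise ValueError("cannot infer graph format from file name: " + path)


print(infer_input_binary("/tmp/my-model/train.pbtxt"))  # False
print(infer_input_binary("frozen_graph.pb"))  # True
```

Raising an error for unknown extensions is deliberate. Guessing wrong would only surface later as a confusing parse failure.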
(c) Graph Load
tf.import_graph_def(graph_def, input_map=None, return_elements=None, name=None)
Purpose:
- Imports the graph from graph_def into the current default Graph
Parameters:
- graph_def: A GraphDef proto containing operations to be imported into the default graph
- input_map: A dictionary mapping input names (as strings) in graph_def to Tensor objects. The values of the named input tensors in the imported graph will be re-mapped to the respective Tensor values.
- return_elements: A list of strings containing operation names in graph_def that will be returned as Operation objects, and/or tensor names in graph_def that will be returned as Tensor objects
- name: A prefix that will be prepended to the names in graph_def. Note that this does not apply to imported function names. Defaults to "import".
Returns:
- A list of Operation and/or Tensor objects from the imported graph, corresponding to the names in return_elements.
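Because of the name prefix, a tensor called output:0 in graph_def has to be fetched as import/output:0 after a default import. Passing name="" (as the load example below does) keeps the original names.

The sketch below models that rule for the simple case only. The helper imported_name is hypothetical. TensorFlow may additionally make the scope unique (for example import_1) when the prefix is already in use, and that is not modelled here.

```python
def imported_name(name, prefix=None):
    """Return the name a node or tensor gets after tf.import_graph_def(name=prefix).

    Hypothetical helper modelling only the simple case:
    - prefix=None uses the default prefix "import";
    - prefix="" keeps the original name unchanged.
    Scope uniquification by TensorFlow (e.g. "import_1") is not modelled.
    """
    if prefix is None:
        prefix = "import"
    return prefix + "/" + name if prefix else name


print(imported_name("output:0"))  # import/output:0
print(imported_name("output:0", prefix=""))  # output:0
```

The resulting string is what you would pass to graph.get_tensor_by_name() after the import.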
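2. Code examples
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import importer
from tensorflow.python.platform import gfile

# (a) graph save
v = tf.Variable(0, name='my_variable')
sess = tf.Session()
tf.train.write_graph(sess.graph, '/tmp/my-model', 'train.pbtxt', as_text=True)
# or tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')

# (b) graph parse
def _parse_input_graph_proto(input_graph, input_binary):
    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        return None
    # Create an empty GraphDef object
    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.FastGFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)
            # If the line above fails, try decoding as UTF-8 first:
            # text_format.Merge(f.read().decode("utf-8"), input_graph_def)
    return input_graph_def

# (c) graph load
input_graph_def = _parse_input_graph_proto('/tmp/my-model/train.pbtxt', input_binary=False)
if input_graph_def is not None:
    # Import into a fresh graph so the names don't clash with the graph built in (a)
    with tf.Graph().as_default():
        _ = importer.import_graph_def(input_graph_def, name="")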
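II. Graph Freeze
1. Function reference
tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None, variable_names_blacklist=None)
Purpose:
- Replaces all the variables in a graph with constants of the same values
Parameters:
- sess: Active TensorFlow session containing the variables
- input_graph_def: GraphDef object holding the network
- output_node_names: A list of the names of the nodes you want to extract the results of your graph from
- variable_names_whitelist: The set of variable names to convert (by default, all variables are converted)
- variable_names_blacklist: The set of variable names to omit converting to constants
Returns:
- GraphDef containing a simplified version of the original.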
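2. Code example
# Adapted from freeze_graph.py. Assumes input_graph_def has already been imported
# into the default graph (see the graph load step in section I).
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_lib

# Hypothetical example settings; replace with your own paths and node names
input_checkpoint = "/tmp/my-model/model.ckpt"
output_node_names = "decode/SparseToDense"  # comma-separated if there are several
output_graph = "/tmp/my-model/frozen_graph.pb"
variable_names_whitelist = None
variable_names_blacklist = None

with session.Session() as sess:
    # Skip checkpoint variables that do not exist in the graph
    var_list = {}
    reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        try:
            tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
            # This tensor doesn't exist in the graph (for example it's
            # 'global_step' or a similar housekeeping element) so skip it.
            continue
        var_list[key] = tensor
    saver = saver_lib.Saver(var_list=var_list)
    saver.restore(sess, input_checkpoint)
    # Freeze: replace the variables with constants holding their restored values
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(","),
        variable_names_whitelist=variable_names_whitelist,
        variable_names_blacklist=variable_names_blacklist)

# Write GraphDef to file if output path has been given.
if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())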
III. Graph Optimize
1. Install the build tool Bazel
- Install the required packages
sudo apt-get install pkg-config zip g++ zlib1g-dev unzip python
- Download Bazel: from the Bazel releases page on GitHub, download the file named like bazel-<version>-installer-linux-x86_64.sh
- Make the installer executable and run it
chmod +x bazel-<version>-installer-linux-x86_64.sh
./bazel-<version>-installer-linux-x86_64.sh --user
The --user flag installs Bazel to the $HOME/bin directory and sets the .bazelrc path to $HOME/.bazelrc.
- Add the executable's directory to PATH at the end of ~/.bashrc
export PATH="$PATH:$HOME/bin"
2. Optimize with transform_graph
"""
removes all of the nodes that aren't called during inference, shrinks expressions that are always
constant into single nodes, and optimizes away some multiply operations used during batch normalization
by pre-multiplying the weights for convolutions.
"""
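# Optimize the graph. transform_graph must be built first (run from the root of the
# TensorFlow source tree); it only needs to be built once.
# Change shape="1,48,160,3" in strip_unused_nodes to your own input size. Keep comments
# out of the quoted --transforms string: the shell passes everything inside it to the tool.
bazel build tensorflow/tools/graph_transforms:transform_graph && \
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
  --in_graph=frozen_graph.pb \
  --out_graph=optimized_graph.pb \
  --inputs='ExpandDims_1' \
  --outputs='decode/SparseToDense' \
  --transforms='
    strip_unused_nodes(type=float, shape="1,48,160,3")
    remove_nodes(op=Identity, op=CheckNumerics)
    fold_constants(ignore_errors=true)
    fold_batch_norms
    fold_old_batch_norms'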
2018-07-06 14:33:27.616294: I tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying strip_unused_nodes
2018-07-06 14:33:27.623069: I tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying remove_nodes
2018-07-06 14:33:27.654728: I tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying fold_constants
2018-07-06 14:33:27.678252: I tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying fold_batch_norms
2018-07-06 14:33:27.683107: I tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying fold_old_batch_norms
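The batch-norm folding described in the docstring above is plain arithmetic. At inference time, a batch norm after a conv/matmul computes gamma * (y - mean) / sqrt(var + eps) + beta. That is an affine function of y, so it can be absorbed into the preceding layer:

- scale = gamma / sqrt(var + eps)
- W' = W * scale, per output channel
- b' = (b - mean) * scale + beta

The sketch below uses plain Python lists, with a dense layer standing in for a convolution. The function names and eps=1e-3 are assumptions for illustration. The real fold_batch_norms / fold_old_batch_norms transforms rewrite graph nodes rather than lists.

```python
import math


def dense(x, w, b):
    """Stand-in for a conv/matmul layer: y[j] = sum_i x[i] * w[i][j] + b[j]."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) + b[j] for j in range(len(b))]


def batch_norm(y, mean, var, gamma, beta, eps=1e-3):
    """Inference-time batch norm with fixed (moving) statistics."""
    return [gamma[j] * (y[j] - mean[j]) / math.sqrt(var[j] + eps) + beta[j]
            for j in range(len(y))]


def fold_batch_norm(w, b, mean, var, gamma, beta, eps=1e-3):
    """Return (w', b') such that dense(x, w', b') equals batch_norm(dense(x, w, b), ...)."""
    scale = [gamma[j] / math.sqrt(var[j] + eps) for j in range(len(b))]
    # Pre-multiply each output column of the weights by its channel's scale.
    w_folded = [[w[i][j] * scale[j] for j in range(len(b))] for i in range(len(w))]
    b_folded = [(b[j] - mean[j]) * scale[j] + beta[j] for j in range(len(b))]
    return w_folded, b_folded


x = [1.0, 2.0]
w = [[0.5, -1.0], [2.0, 0.25]]
b = [0.1, -0.2]
bn = dict(mean=[0.3, 0.1], var=[4.0, 0.25], gamma=[1.5, 0.8], beta=[0.0, 0.5])

original = batch_norm(dense(x, w, b), **bn)
w2, b2 = fold_batch_norm(w, b, **bn)
folded = dense(x, w2, b2)
print(all(math.isclose(p, q) for p, q in zip(original, folded)))  # True
```

After folding, the per-channel multiply and add disappear from the inference graph. Only the (rescaled) layer itself remains.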
IV. References
1、Installing Bazel on Ubuntu
2、TensorFlow Graph Transform Tool
3、TensorFlow python tools freeze_graph.py