TensorFlow 图变换:FoldBatchNorm

TensorFlow 的计算是按图(Graph)组织的,构建好的图有时需要根据需要做一些变换(例如将训练好的模型部署到生产环境时,去除无用的节点),在保证计算结果不变(或近似不变)的情况下优化计算速度或降低内存占用。Graph Transform Tool【1】是 TensorFlow 提供的一组可以修改 TensorFlow Graph 的工具,使用方便,易于扩展。

使用 Graph Transform Tool 时,它的操作对象为 GraphDef 对象,通常保存为二进制文件,后缀为 .pb。前面文章《TensorFlow 到底有几种模型格式?》介绍过这种文件的生成方式。

该工具调用方式如下:

bazel build tensorflow/tools/graph_transforms:transform_graph bazel-bin/tensorflow/tools/graph_transforms/transform_graph \ --in_graph=tensorflow_inception_graph.pb \ --out_graph=optimized_inception_graph.pb \ --inputs='Mul:0' \ --outputs='softmax:0' \ --transforms=' strip_unused_nodes(type=float, shape="1,299,299,3") remove_nodes(op=Identity, op=CheckNumerics) fold_constants fold_batch_norms'

注意需要在 TensorFlow 源码根目录下运行。

其中参数 --in_graph 指定输入 GraphDef 文件名,--out_graph 制定输出 GraphDef 文件名,--inputs 指定输入 Node,--outputs 指定输出 Node,--transforms 指定变换类型。变换类型使用一串命令构成,每条命令都对应一种变换。

Batch Normalization(后面简称 BN)【2】是一种加速深度模型训练的技术,通过训练时对每个 mini-batch 内的 activations 做归一化降低 internal covariate shift,进而加速模型收敛。目前主流深度学习模型(ResNet,Inception,DenseNet,……)几乎都使用了 BN 技术。训练完毕,Batch Normalization 的参数(均值 E[x] 和方差 Val[x])不再更新,在后续推理计算时,可将这些常数参数通过 constant folding 来简化模型。BN 计算公式如下:

一般 BN 位置都在 Convolution 之后(DenseNet 例外),以 TensorFlow 实现的 Inception V3 模型【3】 为例,Conv-BN-Relu 表示为计算图如下:

其中 Conv2D 节点实现 Convolution 计算,Rsqrt 实现先求平方根再取倒数运算,Mul, Add, Sub 分别实现乘法、加法、减法计算,Const 为常数,代表计算参数。由于推理计算时 BN 输入参数均为常数,那么经过 constant folding, BN 可在算数上简化为:

y = x * a + b

进一步,当 x 为卷积输出,对卷积权值直接乘上 a,就可以在前向计算时直接得到 x * a 的结果,这一步称为 BN folding。经过两步简化后的计算图为:

此时节点数目也有大量缩减。在推理计算时,能降低运行时间和存储开销。

与上面 BN folding 优化对应的 Graph Transform 代码位于 tensorflow/tools/graph_transforms/fold_batch_norms.cc,其中使用了一个非常有用的函数:

Status ReplaceMatchingOpTypes(

    const GraphDef& input_graph_def, 

    const OpTypePattern& pattern,

    const std::function<Status(const NodeMatch&, const std::set<string>&,

    const std::set<string>&, std::vector<NodeDef>*)>&  node_generator,

    const ReplaceMatchingOpTypesOptions& options, 

    GraphDef* output_graph_def);

该函数将 input_graph_def 中所有与 pattern 匹配的子图替换为 node_generator 产生的新 op,然后保存到 output_graph_def 中。

pattern 定义为:

      {"Mul",                // mul_node

        {

          {"Conv2D|MatMul",  // conv_node

            {

              {"*"},         // input_node

              {"Const"},     // weights_node

            }

          },

          {"Const"},         // mul_values_node

        }

      },  // clang-format on

上述 pattern 能匹配原计算图中 Conv2D -> Mul 子图。node_generator 代码中将匹配后的子图直接替换为新的 Conv2D(权值常量更新为原权值与乘数因子 a 的乘积)。代码如下:

        // 从匹配模式中得到 Mul、Conv、Input、weight、Mul value 节点

        const NodeDef& mul_node = match.node;

        const NodeDef& conv_node = match.inputs[0].node;

        const NodeDef& input_node = match.inputs[0].inputs[0].node;

        const NodeDef& weights_node = match.inputs[0].inputs[1].node;

        const NodeDef& mul_values_node = match.inputs[1].node;

        // 获取卷积权值、乘数因子数值

        Tensor weights = GetNodeTensorAttr(weights_node, "value");

        Tensor mul_values = GetNodeTensorAttr(mul_values_node, "value");

        // 原始卷积权值乘上乘数因子

        auto weights_matrix = weights.flat_inner_dims<float>();

        Tensor scaled_weights(DT_FLOAT, weights.shape());

        auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();

        for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {

          for (int64 col = 0; col < weights_cols; ++col) {

            scaled_weights_matrix(row, col) =

                weights_matrix(row, col) * mul_values.flat<float>()(col);

          }

        }

        // 构造新的卷积权值节点,填入更新后的权值

        NodeDef scaled_weights_node;

        scaled_weights_node.set_op("Const");

        scaled_weights_node.set_name(weights_node.name());

        SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);

        SetNodeTensorAttr<float>("value", scaled_weights, &scaled_weights_node);

        new_nodes->push_back(scaled_weights_node);

        new_nodes->push_back(input_node);

        // 构造新的卷积节点,复制旧卷积节点参数,改个名

        NodeDef new_conv_node;

        new_conv_node = conv_node;

        new_conv_node.set_name(mul_node.name());

        new_nodes->push_back(new_conv_node);

        return Status::OK();

Graph Transform 工具为离线优化工具,优化后的 GraphDef 文件可以像原先模型一样部署,无需修改生产环境代码。

本文介绍的 BN folding 优化方法可适用于 CPU、GPU、移动端、嵌入式等各种需要推理加速的场景。

【1】 https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms

【2】 Batch Normalization : Accelerating Deep Network Training by Reducing Internal Covariate Shift, arXiv:1502.03167

【3】 http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz

参考:https://mp.weixin.qq.com/s/dLj7jA4rPg2UYzoe8SkgOQ

猜你喜欢

转载自blog.csdn.net/weixin_41770169/article/details/90075673