Caffe source code analysis: layer filtering with Net<Dtype>::FilterNet and the filtering rule StateMeetsRule

Filtering rules

The Net<Dtype>::StateMeetsRule function

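Purpose: StateMeetsRule() checks whether the net's state (NetState) satisfies a NetStateRule.
It compares the phase/level/stage passed in when the net is constructed with the include/exclude rules that each layer carries in the prototxt, and so decides whether the layer becomes part of the net. A rule is checked in five steps and is met only if none of them fails:
1. Phase: TRAIN or TEST; for example, a rule for the TRAIN phase does not apply in TEST.
2. min_level: the state's level must be no less than min_level.
3. max_level: the state's level must be no greater than max_level.
4. stage: every stage listed in the rule must be among the state's stages.
5. not_stage: none of the stages listed in the rule's not_stage may be among the state's stages.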
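To make the five checks concrete, here is a minimal C++ sketch that does not depend on Caffe. The structs State and Rule and the function MeetsRule are hypothetical stand-ins for the protobuf messages NetState and NetStateRule and for Net<Dtype>::StateMeetsRule; the has_* flags imitate protobuf's has_phase()/has_min_level()/has_max_level() presence checks.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical stand-ins for Caffe's Phase enum and the NetState /
// NetStateRule protobuf messages (not Caffe's real generated classes).
enum Phase { TRAIN, TEST };

struct State {
  Phase phase = TRAIN;
  int level = 0;
  std::vector<std::string> stages;
};

struct Rule {
  bool has_phase = false;
  Phase phase = TRAIN;
  bool has_min_level = false;
  int min_level = 0;
  bool has_max_level = false;
  int max_level = 0;
  std::vector<std::string> stages;      // every one must be in State::stages
  std::vector<std::string> not_stages;  // none may be in State::stages
};

static bool Contains(const std::vector<std::string>& v, const std::string& s) {
  return std::find(v.begin(), v.end(), s) != v.end();
}

// Applies the five checks in the same order as StateMeetsRule; the first
// failed check makes the rule "not met".
bool MeetsRule(const State& state, const Rule& rule) {
  if (rule.has_phase && rule.phase != state.phase) return false;         // 1
  if (rule.has_min_level && state.level < rule.min_level) return false;  // 2
  if (rule.has_max_level && state.level > rule.max_level) return false;  // 3
  for (const std::string& s : rule.stages) {                             // 4
    if (!Contains(state.stages, s)) return false;
  }
  for (const std::string& s : rule.not_stages) {                         // 5
    if (Contains(state.stages, s)) return false;
  }
  return true;
}
```

Note that a Rule with no fields set passes every check, so an empty rule is met by any state.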

Explanation

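In Caffe, all parameter structures are defined in caffe.proto; protobuf's compiler protoc (protoc.exe on Windows) generates caffe.pb.cc and caffe.pb.h from it, and these generated files are used to manage the data structures. In practice, the network structure is usually defined in a file named <project_name>.prototxt. In that prototxt file, a layer often carries include/exclude parameters that determine whether the layer is part of a particular configuration of the net or is left out. As the names suggest, include means that the layer is added to the net if the include condition is met when the net is constructed; exclude means that the layer is left out of the net if the exclude condition is met when the net is constructed.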

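Once read, the include and exclude parameters are managed by the NetStateRule message in caffe.proto, which has five fields: phase, min_level, max_level, stage and not_stage. These are the filtering rules. The state they are checked against (phase, level and stages) is usually supplied when the network is constructed, i.e. as arguments to the net constructor, for example: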

Net<Dtype>::Net(const string& param_file, Phase phase, const int level, const vector<string>* stages, const Net* root_net)

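For a layer with include rules: the layer is kept if the NetState meets at least one of them. A rule is met only when all of its specified conditions hold at once: the phase matches, min_level <= level <= max_level (both bounds are inclusive), and every stage listed in the rule appears among the stages passed to the constructor.

For a layer with exclude rules: the layer is removed if the NetState meets at least one of them, using the same per-rule test. A layer with no include rules is kept by default, and a layer may not specify both include and exclude rules.

not_stage works the other way round: if any of the stages passed to the constructor appears in a rule's not_stage list, that rule is not met. See the example below: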
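The include/exclude decision itself only needs to know which rules are met. The sketch below is a hypothetical helper (KeepLayer is not a Caffe function) that takes, for one layer, one boolean per include rule and per exclude rule, each meaning "the NetState meets this rule", and reproduces the decision that FilterNet makes.

```cpp
#include <algorithm>
#include <stdexcept>
#include <vector>

// include_met[i] / exclude_met[j]: whether the NetState meets the layer's
// i-th include rule / j-th exclude rule (e.g. as computed by StateMeetsRule).
// Hypothetical helper, not part of Caffe.
bool KeepLayer(const std::vector<bool>& include_met,
               const std::vector<bool>& exclude_met) {
  if (!include_met.empty() && !exclude_met.empty()) {
    // FilterNet rejects this case with a CHECK; this sketch throws instead.
    throw std::invalid_argument("specify either include or exclude rules");
  }
  auto is_true = [](bool b) { return b; };
  if (!include_met.empty()) {
    // Kept if at least one include rule is met.
    return std::any_of(include_met.begin(), include_met.end(), is_true);
  }
  // No include rules: kept unless at least one exclude rule is met.
  return std::none_of(exclude_met.begin(), exclude_met.end(), is_true);
}
```

A layer with no rules at all falls through to the last line with an empty exclude list and is therefore always kept.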

layer {
  name: "mnist"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TEST
    not_stage: "predict" # filter this layer out in predict
  }
  transform_param {
    scale: 0.00390625
  }
  data_param {
    source: "examples/mnist/mnist_test_lmdb"
    batch_size: 100
    backend: LMDB
  }
}

# Input layer added for deployment
layer {
  name: "data"
  type: "Input"
  top: "data"
  input_param { shape: { dim: 1 dim: 1 dim: 28 dim: 28 } }
  include {
    phase: TEST
    stage: "predict" # add this layer in predict
  }
}
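As a quick check of the first layer in the example, the following sketch evaluates its rule include { phase: TEST not_stage: "predict" } against three constructor states. The types and function names are hypothetical, not Caffe's API, and only the phase, stage and not_stage checks are modelled.

```cpp
#include <algorithm>
#include <string>
#include <vector>

enum Phase { TRAIN, TEST };

// Simplified, hypothetical rule: phase plus stage/not_stage lists only.
struct SimpleRule {
  Phase phase;
  std::vector<std::string> stages;
  std::vector<std::string> not_stages;
};

static bool Has(const std::vector<std::string>& v, const std::string& s) {
  return std::find(v.begin(), v.end(), s) != v.end();
}

bool Meets(Phase phase, const std::vector<std::string>& state_stages,
           const SimpleRule& rule) {
  if (rule.phase != phase) return false;
  for (const auto& s : rule.stages)
    if (!Has(state_stages, s)) return false;  // all listed stages required
  for (const auto& s : rule.not_stages)
    if (Has(state_stages, s)) return false;   // any not_stage disqualifies
  return true;
}

// The "mnist" Data layer has a single include rule, so it is kept exactly
// when that rule is met.
bool MnistDataLayerKept(Phase phase, const std::vector<std::string>& stages) {
  const SimpleRule include_rule{TEST, {}, {"predict"}};
  return Meets(phase, stages, include_rule);
}
```

In words: the layer is dropped in TRAIN because the phase differs, kept in an ordinary TEST run, and dropped when the net is built with phase TEST and stage "predict".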
If you want to see what filtering layers by these parameters is actually useful for, I recommend the article "Caffe neural network configuration - All in one network":

https://yangwenbo.com/articles/caffe-net-config-all-in-one.html?utm_source=tuicool&utm_medium=referral

Annotated source code

template <typename Dtype>
bool Net<Dtype>::StateMeetsRule(const NetState& state,
    const NetStateRule& rule, const string& layer_name) {
  // Check whether the rule is broken due to phase.
  if (rule.has_phase()) {
    if (rule.phase() != state.phase()) {
      LOG_IF(INFO, Caffe::root_solver())
          << "The NetState phase (" << state.phase()
          << ") differed from the phase (" << rule.phase()
          << ") specified by a rule in layer " << layer_name;
      return false;
    }
  }
  // Check whether the rule is broken due to min level.
  if (rule.has_min_level()) {
    if (state.level() < rule.min_level()) {
      LOG_IF(INFO, Caffe::root_solver())
          << "The NetState level (" << state.level()
          << ") is below the min_level (" << rule.min_level()
          << ") specified by a rule in layer " << layer_name;
      return false;
    }
  }
  // Check whether the rule is broken due to max level.
  if (rule.has_max_level()) {
    if (state.level() > rule.max_level()) {
      LOG_IF(INFO, Caffe::root_solver())
          << "The NetState level (" << state.level()
          << ") is above the max_level (" << rule.max_level()
          << ") specified by a rule in layer " << layer_name;
      return false;
    }
  }
  // Check whether the rule is broken due to stage. The NetState must
  // contain ALL of the rule's stages to meet it.
  for (int i = 0; i < rule.stage_size(); ++i) {
    // If any stage listed in the rule is missing from the stages passed to
    // the constructor, the rule is not met.
    // Check that the NetState contains the rule's ith stage.
    bool has_stage = false;
    for (int j = 0; !has_stage && j < state.stage_size(); ++j) {
      if (rule.stage(i) == state.stage(j)) { has_stage = true; }
    }
    if (!has_stage) {
      LOG_IF(INFO, Caffe::root_solver())
          << "The NetState did not contain stage '" << rule.stage(i)
          << "' specified by a rule in layer " << layer_name;
      return false;
    }
  }
  // Check whether the rule is broken due to not_stage. The NetState must
  // contain NONE of the rule's not_stages to meet it.
  for (int i = 0; i < rule.not_stage_size(); ++i) {
    // If any stage passed to the constructor appears in the rule's
    // not_stage list, the rule is not met.
    // Check that the NetState contains the rule's ith not_stage.
    bool has_stage = false;
    for (int j = 0; !has_stage && j < state.stage_size(); ++j) {
      if (rule.not_stage(i) == state.stage(j)) { has_stage = true; }
    }
    if (has_stage) {
      LOG_IF(INFO, Caffe::root_solver())
          << "The NetState contained a not_stage '" << rule.not_stage(i)
          << "' specified by a rule in layer " << layer_name;
      return false;
    }
  }
  return true;
}

Filtering network layers

Net<Dtype>::FilterNet

Purpose: remove from the model definition the layers that do not match the rules for the current state.

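With the rules covered, FilterNet() is straightforward: given the current phase/level/stage, it removes the layers whose rules do not match.
These rules are usually introduced in the prototxt file; for example, a layer configured as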
layer {
  name: "accuracy"
  type: "Accuracy"
  bottom: "ip2"
  bottom: "label"
  top: "accuracy"
  include {
    phase: TEST
  }
}
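is added to the net only in the TEST phase.
Likewise, since the TEST phase only runs the network's forward pass, layers set to phase: TRAIN have to be removed.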
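The filtering loop can be modelled outside Caffe to see which layers survive each phase. The sketch below is a simplified, hypothetical model: its rules carry only a phase (level and stage checks are omitted), the layer names in the usage are made up, and FilterLayers is not a Caffe function. It follows the same control flow as FilterNet.

```cpp
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

enum Phase { TRAIN, TEST };

// Hypothetical, heavily simplified stand-ins: a rule only carries a phase,
// and a layer only a name plus its include/exclude rules.
struct PhaseRule {
  Phase phase;
};

struct Layer {
  std::string name;
  std::vector<PhaseRule> include;
  std::vector<PhaseRule> exclude;
};

// Same control flow as Net<Dtype>::FilterNet: keep, in their original order,
// the layers whose rules match the given phase.
std::vector<Layer> FilterLayers(const std::vector<Layer>& layers, Phase phase) {
  std::vector<Layer> kept;
  for (const Layer& layer : layers) {
    if (!layer.include.empty() && !layer.exclude.empty())
      throw std::invalid_argument(
          "Specify either include rules or exclude rules; not both.");
    // No include rules: the layer is kept by default.
    bool included = layer.include.empty();
    for (std::size_t j = 0; included && j < layer.exclude.size(); ++j)
      if (layer.exclude[j].phase == phase) included = false;  // exclude met
    for (std::size_t j = 0; !included && j < layer.include.size(); ++j)
      if (layer.include[j].phase == phase) included = true;   // include met
    if (included) kept.push_back(layer);
  }
  return kept;
}
```

For instance, with layers {"data_train" include TRAIN, "ip2" no rules, "accuracy" include TEST, "loss" exclude TEST}, the TRAIN result is data_train, ip2, loss and the TEST result is ip2, accuracy.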

Source code

template <typename Dtype>
void Net<Dtype>::FilterNet(const NetParameter& param,
    NetParameter* param_filtered) {
  NetState net_state(param.state());
  param_filtered->CopyFrom(param);
  // Clear all layers first, then add back the ones that pass the rules.
  param_filtered->clear_layer();
  for (int i = 0; i < param.layer_size(); ++i) {
    const LayerParameter& layer_param = param.layer(i);
    const string& layer_name = layer_param.name();
    // A layer may not have both include and exclude rules.
    CHECK(layer_param.include_size() == 0 || layer_param.exclude_size() == 0)
          << "Specify either include rules or exclude rules; not both.";
    // If include_size() is 0 the layer is included by default, so
    // layer_included starts out as true.
    // If no include rules are specified, the layer is included by default and
    // only excluded if it meets one of the exclude rules.
    bool layer_included = (layer_param.include_size() == 0);
    for (int j = 0; layer_included && j < layer_param.exclude_size(); ++j) {
      // net_state is built from the constructor arguments (phase/level/stage);
      // see Solver<Dtype>::InitTrainNet() and Net<Dtype>::StateMeetsRule.
      // layer_param.exclude(j) is one exclude rule set on this layer in the
      // prototxt (phase/min_level/max_level/stage/not_stage).
      // If the condition holds, this layer is excluded.
      if (StateMeetsRule(net_state, layer_param.exclude(j), layer_name)) {
        // With no include rules, meeting any single exclude rule is enough.
        layer_included = false;
      }
    }
    for (int j = 0; !layer_included && j < layer_param.include_size(); ++j) {
      // If the condition holds, this layer is included.
      if (StateMeetsRule(net_state, layer_param.include(j), layer_name)) {
        // With include rules, meeting any single include rule is enough.
        layer_included = true;
      }
    }
    if (layer_included) {
      param_filtered->add_layer()->CopyFrom(layer_param);
    }
  }
}


Reposted from blog.csdn.net/tanmx219/article/details/82918582