caffe (6): main function analysis

tools/caffe.cpp

The registry

First, a registry is defined:

typedef int (*BrewFunction)();
typedef std::map<caffe::string, BrewFunction> BrewMap;
BrewMap g_brew_map;

#define RegisterBrewFunction(func) \
namespace { \
class __Registerer_##func { \
 public: /* NOLINT */ \
  __Registerer_##func() { \
    g_brew_map[#func] = &func; \
  } \
}; \
__Registerer_##func g_registerer_##func; \
}

RegisterBrewFunction is then used to register train, test, device_query, and time. Note that all of this registration has already finished before main() starts: each macro invocation defines a registrar class in an anonymous namespace plus a global instance of it, and that instance's constructor, which inserts the function pointer into g_brew_map, runs during static initialization.
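To make that concrete, RegisterBrewFunction(train) expands to roughly the following, which is just a mechanical expansion of the macro above (assuming a function int train() has already been declared):

namespace {
class __Registerer_train {
 public:
  __Registerer_train() {
    g_brew_map["train"] = &train;  // insert the function pointer under its name
  }
};
// A global object: its constructor runs during static initialization,
// i.e. before main() is entered.
__Registerer_train g_registerer_train;
}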

The main function

The main function itself is shown below:

int main(int argc, char** argv) {
  // Print output to stderr (while still logging).
  FLAGS_alsologtostderr = 1;
  // Set version
  gflags::SetVersionString(AS_STRING(CAFFE_VERSION));
  // Usage message.
  gflags::SetUsageMessage("command line brew\n"
      "usage: caffe <command> <args>\n\n"
      "commands:\n"
      "  train           train or finetune a model\n"
      "  test            score a model\n"
      "  device_query    show GPU diagnostic information\n"
      "  time            benchmark model execution time");
  // Run tool or show usage.
  caffe::GlobalInit(&argc, &argv);
  if (argc == 2) {
#ifdef WITH_PYTHON_LAYER
    try {
#endif
      return GetBrewFunction(caffe::string(argv[1]))();  // main dispatch point of the tool
#ifdef WITH_PYTHON_LAYER
    } catch (bp::error_already_set) {
      PyErr_Print();
      return 1;
    }
#endif
  } else {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
  }
}

GetBrewFunction

GetBrewFunction is then called:

static BrewFunction GetBrewFunction(const caffe::string& name) {
  if (g_brew_map.count(name)) {
    return g_brew_map[name];
  } else {
    LOG(ERROR) << "Available caffe actions:";
    for (BrewMap::iterator it = g_brew_map.begin();
         it != g_brew_map.end(); ++it) {
      LOG(ERROR) << "\t" << it->first;
    }
    LOG(FATAL) << "Unknown action: " << name;
    return NULL;  // not reachable, just to suppress old compiler warnings.
  }
}

For example, if caffe is run as caffe train --solver=examples/mnist/lenet_solver.prototxt, then argv[1] is "train" and argv[2] is "--solver=examples/mnist/lenet_solver.prototxt". main() passes "train" to GetBrewFunction, which looks it up in g_brew_map and returns a pointer to train(), so the train function is what actually runs.
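For the command line above, the dispatch in main() therefore reduces to this sketch:

  BrewFunction f = g_brew_map["train"];  // present because RegisterBrewFunction(train) ran at startup
  return f();                            // invokes train(), which then reads --solver and the other flags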

train

The core of the train function is really just three lines:

  // Create a solver from solver_param via the factory in solver_factory.
  // Constructing the solver initializes the net, which prints the network
  // definition and sets up all of its layers. CreateSolver does not build
  // a new registry at this point; it only looks up and returns the solver
  // type that matches solver_param (see solver_factory for why).
  shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

  // Register the signal-handler callback: it tells the solver whether to
  // stop, snapshot, or do nothing when SIGINT/SIGHUP is received.
  solver->SetActionFunction(signal_handler.GetActionFunction());

  solver->Solve();  // start training: the iteration loop lives inside Solve()
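For context, these three lines sit inside train() roughly as follows; this is abridged and paraphrased from tools/caffe.cpp, so treat it as a sketch rather than the exact code of your Caffe version:

  // Read the --solver prototxt into a SolverParameter message.
  caffe::SolverParameter solver_param;
  caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);

  // Map the --sigint_effect / --sighup_effect flags to solver actions.
  caffe::SignalHandler signal_handler(
      GetRequestedAction(FLAGS_sigint_effect),
      GetRequestedAction(FLAGS_sighup_effect));

  shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
  solver->SetActionFunction(signal_handler.GetActionFunction());

  // Resuming from the middle of training is handled here, via --snapshot:
  if (FLAGS_snapshot.size()) {
    solver->Restore(FLAGS_snapshot.c_str());  // load a .solverstate file
  }

  solver->Solve();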

test

The test function only runs the model forward; there is no backward pass. Annotated code below:

// Instantiate the caffe net.
  Net<float> caffe_net(FLAGS_model, caffe::TEST, FLAGS_level, &stages);
  caffe_net.CopyTrainedLayersFrom(FLAGS_weights);
  LOG(INFO) << "Running for " << FLAGS_iterations << " iterations.";

  vector<int> test_score_output_id; // for each entry of test_score, the index (j) of the output blob it came from
  vector<float> test_score; // one accumulated value per output element, across all output blobs
  float loss = 0;
  for (int i = 0; i < FLAGS_iterations; ++i) {
    float iter_loss;
    const vector<Blob<float>*>& result =
        caffe_net.Forward(&iter_loss); // result holds the network's output blobs, not the outputs of every layer
    loss += iter_loss;
    int idx = 0;
    for (int j = 0; j < result.size(); ++j) {
      const float* result_vec = result[j]->cpu_data();
      for (int k = 0; k < result[j]->count(); ++k, ++idx) {
        const float score = result_vec[k];
        if (i == 0) {
          test_score.push_back(score); // first iteration: initialize test_score
          test_score_output_id.push_back(j); // and record which output blob this entry belongs to
        } else {
          test_score[idx] += score; // later iterations: accumulate over FLAGS_iterations batches
        }
        const std::string& output_name = caffe_net.blob_names()[
            caffe_net.output_blob_indices()[j]];
        LOG(INFO) << "Batch " << i << ", " << output_name << " = " << score;
      }
    }
  }
  loss /= FLAGS_iterations;
  LOG(INFO) << "Loss: " << loss;
  for (int i = 0; i < test_score.size(); ++i) {
    const std::string& output_name = caffe_net.blob_names()[
        caffe_net.output_blob_indices()[test_score_output_id[i]]];
    const float loss_weight = caffe_net.blob_loss_weights()[
        caffe_net.output_blob_indices()[test_score_output_id[i]]];
    std::ostringstream loss_msg_stream;
    const float mean_score = test_score[i] / FLAGS_iterations; 
    if (loss_weight) {
      loss_msg_stream << " (* " << loss_weight
                      << " = " << loss_weight * mean_score << " loss)";
    }
    LOG(INFO) << output_name << " = " << mean_score << loss_msg_stream.str();
  }


Reposted from blog.csdn.net/weixin_37688445/article/details/79255973