tools/caffe.cpp
注册表
首先定义一个注册表,
// A "brew" is one of the tool's sub-commands (train/test/...). Each is a
// plain function taking no arguments and returning an int exit code.
typedef int (*BrewFunction)();
// Registry mapping a command name (e.g. "train") to its function.
typedef std::map<caffe::string, BrewFunction> BrewMap;
// Global registry, filled in before main() runs (see macro below).
BrewMap g_brew_map;
// RegisterBrewFunction(func) defines a file-local class whose constructor
// inserts `func` into g_brew_map, plus one static instance of that class.
// Static initialization runs the constructor before main(), so by the time
// main() dispatches, every registered command is already in the map.
// NOTE: no new comments inside the macro body — a `//` comment before the
// trailing `\` would splice the next macro line into the comment.
#define RegisterBrewFunction(func) \
namespace { \
class __Registerer_##func { \
 public: /* NOLINT */ \
  __Registerer_##func() { \
    g_brew_map[#func] = &func; \
  } \
}; \
__Registerer_##func g_registerer_##func; \
}
之后用RegisterBrewFunction注册了train, test, device_query和time, 这里说明一下, 由于注册是通过静态对象的构造函数完成的, 这些都是在main函数之前就已经运行完了的,
Main函数
之后main函数如下,
// Entry point of the caffe command-line tool.
// Initializes logging/flags, then dispatches to the brew function named by
// argv[1] (registered in g_brew_map before main() ran). Returns that
// function's exit code, 1 on a Python-layer exception, or 0 after printing
// usage for malformed invocations.
int main(int argc, char** argv) {
  // Print output to stderr (while still logging).
  FLAGS_alsologtostderr = 1;
  // Set version
  gflags::SetVersionString(AS_STRING(CAFFE_VERSION));
  // Usage message.
  gflags::SetUsageMessage("command line brew\n"
      "usage: caffe <command> <args>\n\n"
      "commands:\n"
      " train train or finetune a model\n"
      " test score a model\n"
      " device_query show GPU diagnostic information\n"
      " time benchmark model execution time");
  // Run tool or show usage.
  caffe::GlobalInit(&argc, &argv);
  if (argc == 2) {
#ifdef WITH_PYTHON_LAYER
    try {
#endif
      return GetBrewFunction(caffe::string(argv[1]))();  // main dispatch point
#ifdef WITH_PYTHON_LAYER
    // Catch by const reference, not by value: catching by value copies the
    // exception object and can slice derived exception types.
    } catch (const bp::error_already_set&) {
      PyErr_Print();
      return 1;
    }
#endif
  } else {
    gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");
  }
  return 0;  // explicit success exit on the usage path (was implicit before)
}
GetBrewFunction
开始调用GetBrewFunction函数,
// Looks up the brew function registered under `name` in g_brew_map.
// On success returns the function pointer. On an unknown name, logs the
// list of registered actions and aborts via LOG(FATAL); the trailing
// return is never reached and only silences old compilers.
static BrewFunction GetBrewFunction(const caffe::string& name) {
  // Single lookup with find(); the original count() + operator[] pair
  // searched the map twice for the same key.
  BrewMap::const_iterator it = g_brew_map.find(name);
  if (it != g_brew_map.end()) {
    return it->second;
  }
  LOG(ERROR) << "Available caffe actions:";
  for (BrewMap::const_iterator action = g_brew_map.begin();
       action != g_brew_map.end(); ++action) {
    LOG(ERROR) << "\t" << action->first;
  }
  LOG(FATAL) << "Unknown action: " << name;
  return NULL;  // not reachable, just to suppress old compiler warnings.
}
比如caffe的参数为 train --solver=examples/mnist/lenet_solver.prototxt,那么argv[1]为train,argv[2]为--solver=examples/mnist/lenet_solver.prototxt, 在main函数中传入GetBrewFunction的形参为train,因而根据GetBrewFunction的功能,会执行train函数,
train
Train函数最重要的其实就只有三行,
shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param)); // Create the solver matching solver_param via the solver_factory factory function. Solver initialization initializes the net, which prints the network info and sets the net up. CreateSolver does not create a new registry here — it only returns the solver class matching solver_param; see solver_factory for why.
solver->SetActionFunction(signal_handler.GetActionFunction()); // Install the signal-action callback (e.g. decide whether to train from scratch or resume mid-training)
solver->Solve(); // Start training: the solver iterates until done
test
test函数是把模型正向跑一遍,不去跑反向,下面是注解
// Instantiate the caffe net.
Net<float> caffe_net(FLAGS_model, caffe::TEST, FLAGS_level, &stages);
caffe_net.CopyTrainedLayersFrom(FLAGS_weights);
LOG(INFO) << "Running for " << FLAGS_iterations << " iterations.";
vector<int> test_score_output_id; // parallel to test_score: for each accumulated value, the index j of the output blob it came from
vector<float> test_score; // one running sum per scalar in the output blobs, accumulated across all iterations
float loss = 0;
for (int i = 0; i < FLAGS_iterations; ++i) {
float iter_loss;
const vector<Blob<float>*>& result =
caffe_net.Forward(&iter_loss); // result holds the network's output blobs, not every layer's result
loss += iter_loss;
int idx = 0;
for (int j = 0; j < result.size(); ++j) {
const float* result_vec = result[j]->cpu_data();
for (int k = 0; k < result[j]->count(); ++k, ++idx) {
const float score = result_vec[k];
if (i == 0) {
test_score.push_back(score); // first iteration: initialize the accumulator entry
test_score_output_id.push_back(j); // remember which output blob this entry belongs to
} else {
test_score[idx] += score; // later iterations: accumulate (summed over FLAGS_iterations batches)
}
const std::string& output_name = caffe_net.blob_names()[
caffe_net.output_blob_indices()[j]];
LOG(INFO) << "Batch " << i << ", " << output_name << " = " << score;
}
}
}
// Report the mean loss and, per accumulated output value, its mean over
// all iterations (weighted into the loss when its loss_weight is nonzero).
loss /= FLAGS_iterations;
LOG(INFO) << "Loss: " << loss;
for (int i = 0; i < test_score.size(); ++i) {
const std::string& output_name = caffe_net.blob_names()[
caffe_net.output_blob_indices()[test_score_output_id[i]]];
const float loss_weight = caffe_net.blob_loss_weights()[
caffe_net.output_blob_indices()[test_score_output_id[i]]];
std::ostringstream loss_msg_stream;
const float mean_score = test_score[i] / FLAGS_iterations;
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * mean_score << " loss)";
}
LOG(INFO) << output_name << " = " << mean_score << loss_msg_stream.str();
}