版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/m_buddy/article/details/84864741
1. 前言
在之前的文章中讲到了Caffe中是如何把图像数据转换成为LMDB与H5格式文件的。那么Caffe中是怎么实现由这些文件读取到网络中进行训练的呢?其实Caffe中是有专门的数据读取层的,用来读取不同的数据类型。下面是Caffe中主要数据读取类的关系图:
平时用得比较多的是DataLayer与ImageDataLayer(读取图像效率低也不怎么使用)。这里就主要从DataLayer进行分析。至于H5文件的读取是通过另外一个类实现的,这个在后面的讲解中说到。
2. LMDB读取
对于读取LMDB数据类型使用的DataLayer类,其中通过backend参数指定了数据的类型,并在构造函数里面就对其进行了初始化
// Constructor: instantiates the DB backend (LMDB or LevelDB) declared in the
// layer's data_param, opens it read-only, and positions a cursor at the
// first record so setup/prefetch can start reading immediately.
template <typename Dtype>
DataLayer<Dtype>::DataLayer(const LayerParameter& param)
: BasePrefetchingDataLayer<Dtype>(param),
offset_() {
db_.reset(db::GetDB(param.data_param().backend())); // backend type comes from the net prototxt
db_->Open(param.data_param().source(), db::READ); // open the DB read-only
cursor_.reset(db_->NewCursor()); // cursor walks the records sequentially
}
接下来在DataLayerSetUp函数中读取一个数据来初始化prefetch中数据存储单元与当前层的输出blob的维度,后续再在线程函数中去读取训练数据
// DataLayerSetUp: reads a single datum from the DB to infer the shapes of the
// top blob(s) and of every prefetch buffer; the actual batch data is loaded
// later by load_batch() on the prefetch thread.
template <typename Dtype>
void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
const int batch_size = this->layer_param_.data_param().batch_size(); // batch size configured for this net
// Read a data point, and use it to initialize the top blob.
Datum datum;
datum.ParseFromString(cursor_->value());
// Use data_transformer to infer the expected blob shape from datum.
// The top blob's C/H/W follow the datum; if a crop size is configured, the
// transformer's inferred shape reflects the crop instead.
vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
this->transformed_data_.Reshape(top_shape);
// Reshape top[0] and prefetch_data according to the batch_size.
top_shape[0] = batch_size; // axis 0 is the batch dimension
top[0]->Reshape(top_shape);
// Pre-shape every prefetch buffer so the background thread can fill them.
for (int i = 0; i < this->prefetch_.size(); ++i) {
this->prefetch_[i]->data_.Reshape(top_shape);
}
LOG_IF(INFO, Caffe::root_solver())
<< "output data size: " << top[0]->num() << ","
<< top[0]->channels() << "," << top[0]->height() << ","
<< top[0]->width();
// Label blobs (top[1] and each prefetch label buffer) are shaped as a
// vector of batch_size scalars.
if (this->output_labels_) {
vector<int> label_shape(1, batch_size);
top[1]->Reshape(label_shape);
for (int i = 0; i < this->prefetch_.size(); ++i) {
this->prefetch_[i]->label_.Reshape(label_shape);
}
}
}
线程读取函数,用于读取一个batch的数据
// Called by the prefetch-queue thread: loads one full batch of data (and
// labels, if present) from the DB into `batch`, applying the configured
// data transformations, and times the read vs. transform phases.
template<typename Dtype>
void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {
CPUTimer batch_timer;
batch_timer.Start();
double read_time = 0;
double trans_time = 0;
CPUTimer timer;
CHECK(batch->data_.count());
CHECK(this->transformed_data_.count());
const int batch_size = this->layer_param_.data_param().batch_size();
Datum datum;
for (int item_id = 0; item_id < batch_size; ++item_id) {
timer.Start();
while (Skip()) {
Next();
}
datum.ParseFromString(cursor_->value());
read_time += timer.MicroSeconds();
if (item_id == 0) { // all later items in the batch adopt the first datum's shape
// Reshape according to the first datum of each batch
// on single input batches allows for inputs of varying dimension.
// Use data_transformer to infer the expected blob shape from datum.
vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);
this->transformed_data_.Reshape(top_shape);
// Reshape batch according to the batch_size.
top_shape[0] = batch_size;
batch->data_.Reshape(top_shape);
}
// Apply data transformations (mirror, scale, crop...).
timer.Start();
int offset = batch->data_.offset(item_id);
Dtype* top_data = batch->data_.mutable_cpu_data();
// Alias transformed_data_ onto this item's slot inside the batch blob so
// the transformer writes its output in place (no extra copy).
this->transformed_data_.set_cpu_data(top_data + offset);
this->data_transformer_->Transform(datum, &(this->transformed_data_));
// Copy label.
if (this->output_labels_) {
Dtype* top_label = batch->label_.mutable_cpu_data();
top_label[item_id] = datum.label();
}
trans_time += timer.MicroSeconds();
Next(); // advance the DB cursor to the next record
}
timer.Stop();
batch_timer.Stop();
DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
DLOG(INFO) << " Read time: " << read_time / 1000 << " ms.";
DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";
}
3. H5读取
H5文件的读取与LMDB的读取方式类似,H5存储的形式类似于键值对的形式,并且在里面对数据进行了充分的shuffle操作,避免了生成数据阶段没有shuffle而造成的训练失败的情况。
顺带说一句,要是图像分类中训练的图片没有shuffle会存在什么问题呢?-_-||,那就是输出的分类概率就跟撞天婚一样,而且学不动…
对于H5文件的读取,首先读取其list文件,并且按照键值对得到输出数据的维度,并以此来设置当前层的输出维度。
// LayerSetUp: parses the list file named by hdf5_data_param.source (one HDF5
// file path per line), optionally shuffles the file visiting order, loads the
// first file fully into memory, and shapes every top blob from the loaded
// datasets with axis 0 replaced by the configured batch size.
template <typename Dtype>
void HDF5DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// Refuse transformation parameters since HDF5 is totally generic.
CHECK(!this->layer_param_.has_transform_param()) <<
this->type() << " does not transform data.";
// Read the source to parse the filenames.
const string& source = this->layer_param_.hdf5_data_param().source(); // path of the list file naming the HDF5 files
LOG(INFO) << "Loading list of HDF5 filenames from: " << source;
hdf_filenames_.clear(); // accumulated list of HDF5 file paths
// Read every HDF5 path in the list file into a vector<string>
// (paths are expected to be usable as-is, e.g. absolute).
std::ifstream source_file(source.c_str());
if (source_file.is_open()) {
std::string line;
while (source_file >> line) {
hdf_filenames_.push_back(line);
}
} else {
LOG(FATAL) << "Failed to open source file: " << source;
}
source_file.close();
num_files_ = hdf_filenames_.size(); // total number of HDF5 files
current_file_ = 0;
LOG(INFO) << "Number of HDF5 files: " << num_files_;
CHECK_GE(num_files_, 1) << "Must have at least 1 HDF5 filename listed in "
<< source;
file_permutation_.clear();
file_permutation_.resize(num_files_); // visiting order of the files; may be shuffled below
// Default to identity permutation.
for (int i = 0; i < num_files_; i++) {
file_permutation_[i] = i;
}
// Shuffle the file order if requested.
// NOTE(review): std::random_shuffle is deprecated in C++14 and removed in
// C++17; std::shuffle with an explicit engine is the modern replacement.
if (this->layer_param_.hdf5_data_param().shuffle()) {
std::random_shuffle(file_permutation_.begin(), file_permutation_.end());
}
// Load the first HDF5 file and initialize the line counter.
LoadHDF5FileData(hdf_filenames_[file_permutation_[current_file_]].c_str());
current_row_ = 0;
// Reshape blobs.
const int batch_size = this->layer_param_.hdf5_data_param().batch_size(); // batch size configured for this layer
const int top_size = this->layer_param_.top_size(); // number of tops; by convention top 0 is data, top 1 is label
vector<int> top_shape;
// Shape each top blob from the corresponding loaded dataset, replacing
// axis 0 with the batch size.
for (int i = 0; i < top_size; ++i) {
top_shape.resize(hdf_blobs_[i]->num_axes());
top_shape[0] = batch_size;
for (int j = 1; j < top_shape.size(); ++j) {
top_shape[j] = hdf_blobs_[i]->shape(j);
}
top[i]->Reshape(top_shape);
}
}
接下来就是读取H5文件了,每次都会把H5文件中的图片和label全部读取完的,全部放在内存中,一次取一个batch
// LoadHDF5FileData: reads every dataset named by the layer's top blobs out of
// one HDF5 file and keeps them fully in memory (hdf_blobs_); batches are
// later sliced out of these blobs. Also builds the per-row permutation used
// for intra-file shuffling.
// (The article originally pasted LayerSetUp here a second time — which would
// be a redefinition in a real translation unit; this is the function the
// surrounding text actually describes, per upstream Caffe.)
template <typename Dtype>
void HDF5DataLayer<Dtype>::LoadHDF5FileData(const char* filename) {
  DLOG(INFO) << "Loading HDF5 file: " << filename;
  hid_t file_id = H5Fopen(filename, H5F_ACC_RDONLY, H5P_DEFAULT);
  if (file_id < 0) {
    LOG(FATAL) << "Failed opening HDF5 file: " << filename;
  }

  int top_size = this->layer_param_.top_size();
  hdf_blobs_.resize(top_size);

  const int MIN_DATA_DIM = 1;
  const int MAX_DATA_DIM = INT_MAX;

  // One in-memory blob per top; the dataset name inside the HDF5 file must
  // match the top blob's name in the prototxt.
  for (int i = 0; i < top_size; ++i) {
    hdf_blobs_[i] = shared_ptr<Blob<Dtype> >(new Blob<Dtype>());
    hdf5_load_nd_dataset(file_id, this->layer_param_.top(i).c_str(),
        MIN_DATA_DIM, MAX_DATA_DIM, hdf_blobs_[i].get());
  }

  herr_t status = H5Fclose(file_id);
  CHECK_GE(status, 0) << "Failed to close HDF5 file: " << filename;

  // MinTopBlobs==1 guarantees at least one top blob.
  CHECK_GE(hdf_blobs_[0]->num_axes(), 1) << "Input must have at least 1 axis.";
  const int num = hdf_blobs_[0]->shape(0);
  // Every dataset must carry the same number of rows along axis 0.
  for (int i = 1; i < top_size; ++i) {
    CHECK_EQ(hdf_blobs_[i]->shape(0), num);
  }

  // Default to identity permutation.
  data_permutation_.clear();
  data_permutation_.resize(hdf_blobs_[0]->shape(0));
  for (int i = 0; i < hdf_blobs_[0]->shape(0); i++)
    data_permutation_[i] = i;

  // Shuffle if needed: rows inside the file are permuted, on top of the
  // file-level shuffle already done in LayerSetUp.
  if (this->layer_param_.hdf5_data_param().shuffle()) {
    std::random_shuffle(data_permutation_.begin(), data_permutation_.end());
    DLOG(INFO) << "Successfully loaded " << hdf_blobs_[0]->shape(0)
               << " rows (shuffled)";
  } else {
    DLOG(INFO) << "Successfully loaded " << hdf_blobs_[0]->shape(0) << " rows";
  }
}
这是H5文件的读入,Caffe中还实现了H5文件的输出,有兴趣的可以参考HDF5OutputLayer类的实现。