版权声明:如需转载请评论告知并标明出处 https://blog.csdn.net/ShuqiaoS/article/details/83374063
DispNet中Caffe自定义层解读(一)—— CustomData
这一系列博文记录了博主在学习DispNet过程中遇到的自定义Caffe层的笔记。这一部分是CustomData层,其主要功能是:读取数据库中的LMDB类型数据,并将其随机排布后存入top。更新于2018.10.25。
文章目录
调用方式
layer {
name: "CustomData1"
type: "CustomData"
top: "blob0"
top: "blob1"
top: "blob2"
include {
phase: TRAIN
}
data_param {
source: "path/to/your/data/lmdb"
batch_size: 32
backend: LMDB
rand_permute: true
rand_permute_seed: 77
slice_point: 3
slice_point: 6
encoding: UINT8
encoding: UINT8
encoding: UINT16FLOW
verbose: true
}
}
custom_data_layer.hpp
定义了LMDB数据类型的几个变量:
// LMDB
MDB_env* mdb_env_;
MDB_dbi mdb_dbi_;
MDB_txn* mdb_txn_;
MDB_cursor* mdb_cursor_;
MDB_val mdb_key_, mdb_value_;
小知识:
- 什么是句柄? 简单来说,如果从一个数可以“拎出”很多东西,那么这个数就是句柄。
- MDB_env*中的*号代表什么: 表示指向前面那个类型的指针,可以理解为mdb_env这个变量里放的地址,谁赋值给这个变量,就放的谁的地址。(注:解释来自C++大牛@langzai310)
- MDB_env: 为数据库环境(database environment)定义的一个不透明结构体。官网解释:Opaque structure for a database environmen。更多官网解释看这里。
- MDB_dbi: DB环境下的个人数据库(individual database)的句柄。官网解释:A handle for an individual database in the DB environment。
- MDB_txn: 为一个事务句柄(transaction handle)定义一个不透明结构体(Opaque structure)。官网解释:Opaque structure for a transaction handle。更多官网解释看这里。
- MDB_cursor: 为巡航一个数据库定义的不透明结构体。官网解释:Opaque structure for navigating through a database。更多官网解释看这里。
- MDB_val: 用于将keys和数据传入、传出数据库的语类结构(generic structure)。
定义变量用于声明线程ID:
pthread_t thread_;
定义智能指针:
shared_ptr<Blob<Dtype> > prefetch_label_;
vector<shared_ptr<Blob<Dtype> > > prefetch_data_blobs_;
还有其他Caffe中通用的函数及变量定义,此处不作赘述。
custom_data_layer.cpp
Forward_cpu
用函数JoinPrefetchThread();
将线程joint在一起,并检查是否成功。如果失败,返回"Pthread joining failed."
。
注:整合用到函数pthread_join
,用来等待一个线程的结束,线程间同步的操作。头文件 : #include <pthread.h>。具体描述看这里。
将cpu中的数据复制到top中:
for (int i = 0; i <= slice_point_.size(); ++i) {
// Copy the data
caffe_copy(prefetch_data_blobs_[i]->count(), prefetch_data_blobs_[i]->cpu_data(), top[i]->mutable_cpu_data());
}
其中,prefetch_data_blobs_[i]->cpu_data()
为被复制的源数据,top[i]->mutable_cpu_data()
为复制到的目标数据。
如果output_labels(头文件中定义的bool类型变量)为真,地址移动到整个slice的下一个地址(label_topblob = slice_point_.size() + 1;
),将prefetch_label_
中的内容复制给top[label_topblob]
。
随机排布输入的图像。
iter_++;
if (this->layer_param_.data_param().rand_permute() && this->layer_param_.data_param().permute_every_iter()) { //如果层参数中设置了rand_permute(true)和permute_every_iter
if (iter_ % this->layer_param_.data_param().permute_every_iter() == 0) { //如果permute_every_iter为0
generateRandomPermutation(-1, this->layer_param_.data_param().block_size()); //调用generateRandomPermutation函数,参数为-1和层参数中的block_size
if (this->layer_param_.data_param().verbose()) { //如果需要,将新的排布顺序显示出来。
printf("Re-permuting at iteration %d. Permutation:\n", iter_);
for(int j = 0; j < permutation_vector_.size(); j++) {
printf("%d ",permutation_vector_.at(j));
}
printf("\n");
}
}
}
生成一个新的线程。
// Start a new prefetch thread
CreatePrefetchThread();
generateRandomPermutation
在Forward_cpu中,这个函数的输入为-1和层参数block_size的值。此时函数的功能是:将所有输入的图像重新随机排布。
template <typename Dtype>
void CustomDataLayer<Dtype>::generateRandomPermutation(int seed, int block_size) {
if (seed > 0) //如果seed大于0,根据seed初始化随机数发生器。(srand函数是随机数发生器的初始化函数)
std::srand (unsigned(seed));
if (block_size > 0) { //如果block_size大于0,
int num_blocks = (permutation_vector_.size() + block_size - 1) / block_size; // equal to ceil(size / block_size)
for (int b=0; b < num_blocks; ++b) {
int n1 = b * block_size;
int n2 = std::min((b+1)*block_size, static_cast<int>(permutation_vector_.size()));
std::random_shuffle(permutation_vector_.begin() + n1, permutation_vector_.begin() + n2 -1);
}
} else { //否则,将permutation_vector_中的数随机排列。变量定义在hpp中:std::vector<int> permutation_vector_;
std::random_shuffle(permutation_vector_.begin(), permutation_vector_.end());
}
}
CreatePrefetchThread
template <typename Dtype>
void CustomDataLayer<Dtype>::CreatePrefetchThread() {
const bool prefetch_needs_rand = (this->phase_ == TRAIN) &&
(this->layer_param_.data_param().mirror() ||
this->layer_param_.data_param().crop_size());
if (prefetch_needs_rand) {
const unsigned int prefetch_rng_seed = caffe_rng_rand();
prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));
} else {
prefetch_rng_.reset();
}
// Create the thread.
CHECK(!pthread_create(&thread_, NULL, CustomDataLayerPrefetch<Dtype>,
static_cast<void*>(this))) << "Pthread execution failed.";
}
DecodeData
template <typename Dtype>
void DecodeData(Dtype*& ptr,Datum& datum,const vector<int>& slice_points,const vector<int>& encoding)
{
int width=datum.width();
int height=datum.height();
int channels=datum.channels();
int count=width*height*channels;
ptr=new Dtype[count];
if(datum.float_data_size())
{
CHECK_EQ(encoding.size(),0) << "Encoded layers must be stored as uint8 in LMDB.";
for(int i=0; i<count; i++)
ptr[i]=datum.float_data(i);
return;
}
const unsigned char* srcptr=(const unsigned char*)datum.data().c_str();
Dtype* destptr=ptr;
int channel_start = -1; //inclusive
int channel_end = 0; //non-inclusive (end will become start in next slice)
for(int slice = 0; slice <= slice_points.size(); slice++)
{
channel_start = channel_end;
if(slice == slice_points.size())
channel_end = channels;
else
channel_end = slice_points[slice];
int channel_count=channel_end-channel_start;
int format;
if(encoding.size()<=slice)
format=DataParameter_CHANNELENCODING_UINT8;
else
format=encoding[slice];
// LOG(INFO) << "Slice " << slice << "(" << channel_start << "," << channel_end << ") has format " << ((int)format);
switch(format)
{
case DataParameter_CHANNELENCODING_UINT8:
for(int c=0; c<channel_count; c++)
for(int y=0; y<height; y++)
for(int x=0; x<width; x++)
*(destptr++)=static_cast<Dtype>(*(srcptr++));
break;
case DataParameter_CHANNELENCODING_UINT16FLOW:
for(int c=0; c<channel_count; c++)
for(int y=0; y<height; y++)
for(int x=0; x<width; x++)
{
short v;
*((unsigned char*)&v)=*(srcptr++);
*((unsigned char*)&v+1)=*(srcptr++);
Dtype value;
if(v==std::numeric_limits<short>::max()) {
value = std::numeric_limits<Dtype>::signaling_NaN();
} else {
value = ((Dtype)v)/32.0;
}
*(destptr++)=value;
}
break;
case DataParameter_CHANNELENCODING_BOOL1:
{
int j=0;
for(int i=0; i<(width*height-1)/8+1; i++)
{
unsigned char data=*(srcptr++);
for(int k=0; k<8; k++)
{
float value=(data&(1<<k))==(1<<k);
if(j<width*height)
*(destptr++)=value?1.0:0;
j++;
}
}
}
break;
default:
LOG(FATAL) << "Invalid format for slice " << slice;
break;
}
}
// LOG(INFO) << destptr << " " << ptr;
assert(destptr==ptr+count);
}