//
// Created by 付聪 on 2017/6/21.
//
#include <efanna2e/index_nsg.h>
#include <efanna2e/util.h>
void load_data(char* filename, float*& data, unsigned& num,unsigned& dim){// load data with sift10K pattern
std::ifstream in(filename, std::ios::binary);
if(!in.is_open()){std::cout<<"open file error"<<std::endl;exit(-1);}
in.read((char*)&dim,4);
std::cout<<"data dimension: "<<dim<<std::endl;
in.seekg(0,std::ios::end);
std::ios::pos_type ss = in.tellg();
size_t fsize = (size_t)ss;
num = (unsigned)(fsize / (dim+1) / 4);
data = new float[num * dim * sizeof(float)];
in.seekg(0,std::ios::beg);
for(size_t i = 0; i < num; i++){
in.seekg(4,std::ios::cur);
in.read((char*)(data+i*dim),dim*4);
}
in.close();
}
void save_result(char* filename, std::vector<std::vector<unsigned> > &results){
std::ofstream out(filename, std::ios::binary | std::ios::out);
for (unsigned i = 0; i < results.size(); i++) {
unsigned GK = (unsigned) results[i].size();
out.write((char *) &GK, sizeof(unsigned));
out.write((char *) results[i].data(), GK * sizeof(unsigned));
}
out.close();
}
int main(int argc, char** argv){
if(argc!=7){std::cout<< argv[0] <<" data_file query_file nsg_path search_L search_K result_path"<<std::endl; exit(-1);}
float* data_load = NULL;
unsigned points_num, dim;
load_data(argv[1], data_load, points_num, dim);
float* query_load = NULL;
unsigned query_num, query_dim;
load_data(argv[2], query_load, query_num, query_dim);
assert(dim == query_dim);
unsigned L = (unsigned)atoi(argv[4]);
unsigned K = (unsigned)atoi(argv[5]);
if(L < K){std::cout<< "search_L cannot be smaller than search_K!"<<std::endl; exit(-1);}
//data_load = efanna2e::data_align(data_load, points_num, dim);//one must align the data before build
//query_load = efanna2e::data_align(query_load, query_num, query_dim);
efanna2e::IndexNSG index(dim, points_num, efanna2e::L2, nullptr);
index.Load(argv[3]);
efanna2e::Parameters paras;
paras.Set<unsigned>("L_search", L);
paras.Set<unsigned>("P_search", L);
auto s = std::chrono::high_resolution_clock::now();
std::vector<std::vector<unsigned> > res;
for(unsigned i=0; i<query_num; i++){
std::vector<unsigned> tmp(K);
index.Search(query_load + i * dim, data_load, K, paras, tmp.data());
res.push_back(tmp);
}
auto e = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> diff = e-s;
std::cout << "search time: " << diff.count() << "\n";
save_result(argv[6], res);
return 0;
}