先思考一个问题:
- 在K维空间里面有许多的点,对于某些给定的点,我们需要找到和它最近的m个点。
- 这里的距离指的是欧几里得距离:
- D(p,q)=D(q,p)=sqrt((q1-p1)^2+(q2-p2)^2+(q3-p3)^2+...+ (qn-pn)^2),请你帮忙解决一下。
输入:
- 点数n(1≤n≤50000)和维度数k(1≤k≤5)。
- 接下来的n行,每行k个整数,代表一个点的坐标。
- 接下来一个正整数:给定的询问数量t(1≤t≤10000)
- 下面2*t行:
- 第一行k个整数,表示要查询的点的坐标
- 第二行一个整数m,表示查询最近的m个点(1≤m≤10)
- 所有坐标的绝对值不超过10000。
- 有多组数据!
输出:
- 对于每个询问,输出m+1行:
- 第一行:"the closest m points are:" m为查询中的m
- 接下来m行每行代表一个点,按照从近到远排序。
- 保证方案唯一,下面这种情况不会出现:
- 2 2
- 1 1
- 3 3
- 1
- 2 2
- 1
我们知道在二维的情况下我们可以用树状数组来解决(乱搞)。但此时题目中给出了一个会变化的维度,再用树状数组就会提高大量的思维难度(反正我是想象不出5维空间的),
此时我们就需要一种对应多维度的数据结构——KD树。
- KD树的定义:
Kd-树是K-dimension tree的缩写,是对数据点在k维空间(如二维(x,y),三维(x,y,z),k维(x1,y,z..))中划分的一种数据结构,主要应用于多维空间关键数据的搜索(如:范围搜索和最近邻搜索)。本质上说,Kd-树就是一种平衡二叉树。
首先必须搞清楚的是,k-d树是一种空间划分树,说白了,就是把整个空间划分为特定的几个部分,然后在特定空间的部分内进行相关搜索操作。想像一个三维(多维有点为难我的想象力了)空间,kd树按照一定的划分规则把这个三维空间划分了多个空间,如下图:
更加易懂的说法是KD树实际上就是多关键字搜索(我蒟蒻只需要知道这个就够了)。
2. KD树的构建
KD树与线段树的构建相似,需要动态递归建立
void build(int &k,int l,int r,int dir)
{
int mid=(l+r)>>1;
k=mid;D=dir;
nth_element(a+l,a+mid,a+r+1,cmp);
for(int i=0;i<K;i++)
a[k].mi[i]=a[k].mx[i]=a[k].d[i];
if(l<mid)build(a[k].l,l,mid-1,(dir+1)%K);
if(r>mid)build(a[k].r,mid+1,r,(dir+1)%K);
pushup(k);
}
3.KD树的插入
虽然此题不用插入但我们还是要学的啊
void insert(int k,int dir)
{
if (q[dir]<a[k].d[dir])
{
if (a[k].l) insert(a[k].l,(dir+1)%d);
else
{
a[k].l=++n;
for(int i=0;i<K;i++)
a[n].mi[i]=a[n].mx[i]=a[n].d[i]=q[i];
}
}
else
{
if (a[k].r) insert(a[k].r,(dir+1)%d);
else
{
a[k].r=++n;
for(int i=0;i<K;i++)
a[n].mi[i]=a[n].mx[i]=a[n].d[i]=q[i];
}
}
pushup(k);//同时向上维护
}
虽然一棵刚建好的KD树深度是O(log)的。但随便乱插会对时间有巨大的负担很容易TLE。所以我们可以用替罪羊树优化……因为博主太弱还不会(QAQ)所以请同学们自己去学习吧(学会了记得回来给我讲讲啊)!
4.KD树的查询
KD树的关键。还记得我们维护的mi[]和mx[],现在我们要用它来做估计了。我们都知道估计可以省下大量的计算,所以这也是KD树独特的地方。但我们的答案不能是估计啊!所以精确的也不能少(QAQ)
long long Guess(int k) //估算与k点的距离值
{
long long i,s=0;
for(i=0;i<K;i++)
{
if(q[i]<a[k].mi[i])s+=(long long)(q[i]-a[k].mi[i])*(q[i]-a[k].mi[i]);
if(q[i]>a[k].mx[i])s+=(long long)(q[i]-a[k].mx[i])*(q[i]-a[k].mx[i]);
}
return s;
}
long long Dis(int k) //求查询点与k点的距离值
{
long long i,ans=0;
for(i=0;i<K;i++)ans+=(long long)(q[i]-a[k].d[i])*(q[i]-a[k].d[i]);
return ans;
}
void Query(int x)
{
if(!x)return;
long long dis=Dis(x),dl=Guess(a[x].l),dr=Guess(a[x].r);
if(dis<Q.top().first)//为本题需要而建的大根堆
{
Q.pop();
Q.push(make_pair(dis,x));
}
if(dl<dr)
{
if(dl<Q.top().first)Query(a[x].l);
if(dr<Q.top().first)Query(a[x].r);
}
else
{
if(dr<Q.top().first)Query(a[x].r);
if(dl<Q.top().first)Query(a[x].l);
}
}
原题代码:
#include<bits/stdc++.h>
#define INF 0x3f3f3f3
using namespace std;
typedef pair<long long,int>pii;
priority_queue<pii>Q;
struct data{int d[6],mx[6],mi[6],l,r;}a[100005<<1];
int q[6],i,j,k,m,n,rt,D,K,t;
bool cmp(data x,data y){return x.d[D]<y.d[D];}
void pushup(int x)
{
int i,ls=a[x].l,rs=a[x].r;
for(i=0;i<K;i++)
{
if(ls)
{
a[x].mx[i]=max(a[x].mx[i],a[ls].mx[i]);
a[x].mi[i]=min(a[x].mi[i],a[ls].mi[i]);
}
if(rs)
{
a[x].mx[i]=max(a[x].mx[i],a[rs].mx[i]);
a[x].mi[i]=min(a[x].mi[i],a[rs].mi[i]);
}
}
}
void build(int &k,int l,int r,int dir)
{
int mid=(l+r)>>1;
k=mid;D=dir;
nth_element(a+l,a+mid,a+r+1,cmp);
for(int i=0;i<K;i++)
a[k].mi[i]=a[k].mx[i]=a[k].d[i];
if(l<mid)build(a[k].l,l,mid-1,(dir+1)%K);
if(r>mid)build(a[k].r,mid+1,r,(dir+1)%K);
pushup(k);
}
long long Guess(int x) //估算与x点的距离值
{
long long i,s=0;
for(i=0;i<K;i++)
{
if(q[i]<a[x].mi[i])s+=(long long)(q[i]-a[x].mi[i])*(q[i]-a[x].mi[i]);
if(q[i]>a[x].mx[i])s+=(long long)(q[i]-a[x].mx[i])*(q[i]-a[x].mx[i]);
}
return s;
}
long long Dis(int x) //求查询点与x点的距离值
{
long long i,ans=0;
for(i=0;i<K;i++)ans+=(long long)(q[i]-a[x].d[i])*(q[i]-a[x].d[i]);
return ans;
}
void Query(int x)
{
if(!x)return;
long long dis=Dis(x),dl=Guess(a[x].l),dr=Guess(a[x].r);
if(dis<Q.top().first)
{
Q.pop();
Q.push(make_pair(dis,x));
}
if(dl<dr)
{
if(dl<Q.top().first)Query(a[x].l);
if(dr<Q.top().first)Query(a[x].r);
}
else
{
if(dr<Q.top().first)Query(a[x].r);
if(dl<Q.top().first)Query(a[x].l);
}
}
void print() //从小到大输出m个点
{
int i,x;
while(!Q.empty())
{
x=Q.top().second;Q.pop();
print();
for(i=0;i<K;i++)printf("%d ",a[x].d[i]);
printf("\n");
}
}
int main()
{
while(scanf("%d%d",&n,&K)!=EOF)
{
memset(a,0,sizeof(a));
while(!Q.empty())Q.pop();//清空堆
for(i=1;i<=n;i++) //读入n个点的坐标
for(j=0;j<K;j++)
scanf("%d",&a[i].d[j]);
build(rt,1,n,0);
scanf("%d",&t); //建立KD树 scanf("%d",&t);
for(i=1;i<=t;i++) //t组询问
{
for(j=0;j<K;j++)scanf("%d",&q[j]);//读入查询点的坐标
scanf("%d",&m);
for(j=1;j<=m;j++)Q.push(make_pair(INF,0));//把k个INF加入大根堆
Query(rt);
printf("the closest %d points are:\n",m);
print();
}
}
return 0;
}