题目背景
感谢hzwer的点分治互测。
题目描述
给定一棵有n个点的树
询问树上距离为k的点对是否存在。
输入输出格式
输入格式:
n,m 接下来n-1条边a,b,c描述a到b有一条长度为c的路径
接下来m行每行询问一个K
输出格式:
对于每个K每行输出一个答案,存在输出“AYE”,否则输出”NAY”(不包含引号)
输入输出样例
说明
对于30%的数据n<=100
对于60%的数据n<=1000,m<=50
对于100%的数据n<=10000,m<=100,c<=1000,K<=10000000
跟BZOJ1316是一样的,但是辣鸡bzoj他卡我常。。
这道题有两种思路,一种是把所有的路径长度全用点分治刷出来O(1)回答询问,但好像我怎么想都是O(n^2)的,还不如直接枚举两点求LCA,另一种则是对每个询问求,时间复杂度是O(q*n*logn),然而n和q都比较小,就跑过去了。
代码:
1 #include<iostream> 2 #include<cstdio> 3 #include<cstring> 4 #include<algorithm> 5 #define M 10010 6 #define inf 1e8 7 using namespace std; 8 struct point{ 9 int next,to,dis; 10 }e[M<<1]; 11 int n,m,num,S,maxn,root,K,ans,q; 12 int head[M],dis[M],maxsize[M],size[M]; 13 bool vis[M]; 14 void add(int from,int to,int dis) 15 { 16 e[++num].next=head[from]; 17 e[num].to=to; 18 e[num].dis=dis; 19 head[from]=num; 20 } 21 void getroot(int x,int fa) 22 { 23 size[x]=1; maxsize[x]=0; 24 for(int i=head[x];i;i=e[i].next) 25 { 26 int to=e[i].to; 27 if(vis[to]||to==fa) continue; 28 getroot(to,x); 29 maxsize[x]=max(maxsize[x],size[to]); 30 size[x]+=size[to]; 31 } 32 maxsize[x]=max(maxsize[x],S-size[x]); 33 if(maxsize[x]<maxn) maxn=maxsize[x],root=x; 34 } 35 void getdis(int x,int fa,int len) 36 { 37 dis[++num]=len; 38 for(int i=head[x];i;i=e[i].next) 39 { 40 int to=e[i].to; 41 if(vis[to]||to==fa) continue; 42 getdis(to,x,len+e[i].dis); 43 } 44 } 45 int cal(int x,int len) 46 { 47 num=0; 48 getdis(x,0,len); 49 sort(dis+1,dis+1+num); 50 int l=1,r=num,suml=0,sumr=0,tot=0; 51 while(l<r) 52 { 53 if(dis[l]+dis[r]<K) l++; 54 else 55 { 56 if(dis[l]+dis[r]==K) 57 { 58 suml=1,sumr=1; 59 while(dis[l+1]==dis[l]&&l+1<r) l++,suml++; 60 while(dis[r-1]==dis[r]&&r-1>l) r--,sumr++; 61 tot+=suml*sumr; 62 l++; 63 } 64 r--; 65 } 66 } 67 return tot; 68 } 69 void solve(int x) 70 { 71 ans+=cal(x,0); 72 vis[x]=true; 73 for(int i=head[x];i;i=e[i].next) 74 { 75 int to=e[i].to; 76 if(vis[to]) continue; 77 ans-=cal(to,e[i].dis); 78 S=size[to]; root=0; maxn=inf; 79 getroot(to,0); solve(root); 80 } 81 } 82 int main() 83 { 84 scanf("%d%d",&n,&q); 85 for(int i=1;i<n;i++) 86 { 87 int x,y,z; scanf("%d%d%d",&x,&y,&z); 88 add(x,y,z); add(y,x,z); 89 } 90 for(int i=1;i<=q;i++) 91 { 92 scanf("%d",&K); 93 S=n; maxn=inf; getroot(1,0); ans=0; 94 memset(vis,false,sizeof(vis)); 95 solve(root); 96 if(ans) printf("AYE\n"); 97 else printf("NAY\n"); 98 } 99 return 0; 100 }