[洛谷3806] 点分治1

题目背景

感谢hzwer的点分治互测。

题目描述

给定一棵有n个点的树

询问树上距离为k的点对是否存在。

输入输出格式

输入格式:

n,m 接下来n-1条边a,b,c描述a到b有一条长度为c的路径

接下来m行每行询问一个K

输出格式:

对于每个K每行输出一个答案,存在输出“AYE”,否则输出”NAY”(不包含引号)

输入输出样例

输入样例#1:  复制
2 1
1 2 2
2
输出样例#1:  复制
AYE

说明

对于30%的数据n<=100

对于60%的数据n<=1000,m<=50

对于100%的数据n<=10000,m<=100,c<=1000,K<=10000000

跟BZOJ1316是一样的,但是辣鸡bzoj他卡我常。。

这道题有两种思路,一种是把所有的路径长度全用点分治刷出来O(1)回答询问,但好像我怎么想都是O(n^2)的,还不如直接枚举两点求LCA,另一种则是对每个询问求,时间复杂度是O(q*n*logn),然而n和q都比较小,就跑过去了。

代码:

  1 #include<iostream>
  2 #include<cstdio>
  3 #include<cstring>
  4 #include<algorithm>
  5 #define M 10010
  6 #define inf 1e8
  7 using namespace std;
  8 struct point{
  9     int next,to,dis;
 10 }e[M<<1];
 11 int n,m,num,S,maxn,root,K,ans,q;
 12 int head[M],dis[M],maxsize[M],size[M];
 13 bool vis[M];
 14 void add(int from,int to,int dis)
 15 {
 16     e[++num].next=head[from];
 17     e[num].to=to;
 18     e[num].dis=dis;
 19     head[from]=num;
 20 }
 21 void getroot(int x,int fa)
 22 {
 23     size[x]=1; maxsize[x]=0;
 24     for(int i=head[x];i;i=e[i].next)
 25     {
 26         int to=e[i].to;
 27         if(vis[to]||to==fa) continue;
 28         getroot(to,x);
 29         maxsize[x]=max(maxsize[x],size[to]);
 30         size[x]+=size[to];
 31     }
 32     maxsize[x]=max(maxsize[x],S-size[x]);
 33     if(maxsize[x]<maxn) maxn=maxsize[x],root=x;
 34 }
 35 void getdis(int x,int fa,int len)
 36 {
 37     dis[++num]=len;
 38     for(int i=head[x];i;i=e[i].next)
 39     {
 40         int to=e[i].to;
 41         if(vis[to]||to==fa) continue;
 42         getdis(to,x,len+e[i].dis);
 43     }
 44 }
 45 int cal(int x,int len)
 46 {
 47     num=0;
 48     getdis(x,0,len);
 49     sort(dis+1,dis+1+num);
 50     int l=1,r=num,suml=0,sumr=0,tot=0;
 51     while(l<r)
 52     {
 53         if(dis[l]+dis[r]<K) l++;
 54         else
 55         {
 56             if(dis[l]+dis[r]==K)
 57             {
 58                 suml=1,sumr=1;
 59                 while(dis[l+1]==dis[l]&&l+1<r) l++,suml++;
 60                 while(dis[r-1]==dis[r]&&r-1>l) r--,sumr++;
 61                 tot+=suml*sumr;
 62                 l++;
 63             }
 64             r--;
 65         }
 66     }
 67     return tot;
 68 }
 69 void solve(int x)
 70 {
 71     ans+=cal(x,0);
 72     vis[x]=true;
 73     for(int i=head[x];i;i=e[i].next)
 74     {
 75         int to=e[i].to;
 76         if(vis[to]) continue;
 77         ans-=cal(to,e[i].dis);
 78         S=size[to]; root=0; maxn=inf;
 79         getroot(to,0); solve(root);
 80     }
 81 }
 82 int main()
 83 {
 84     scanf("%d%d",&n,&q);
 85     for(int i=1;i<n;i++)
 86     {
 87         int x,y,z; scanf("%d%d%d",&x,&y,&z);
 88         add(x,y,z); add(y,x,z);
 89     }
 90     for(int i=1;i<=q;i++)
 91     {
 92         scanf("%d",&K);
 93         S=n; maxn=inf; getroot(1,0); ans=0;
 94         memset(vis,false,sizeof(vis));
 95         solve(root);
 96         if(ans) printf("AYE\n");
 97         else printf("NAY\n");
 98     }
 99     return 0;
100 }

猜你喜欢

转载自www.cnblogs.com/Slrslr/p/9385929.html