题意:给定一棵点权树,每次询问给出结点 x、y 和一个数 z,求 z 与 x 到 y 路径上(含两端点)某个结点的权值异或所得结果的最大值。
解法:用树链剖分求 LCA;按从根到叶的顺序建立可持久化 01 字典树(每个结点在父结点的版本上插入自身权值),用两个版本的计数之差得到一段竖直路径上的权值集合,再从高位到低位贪心求最大异或值。
代码如下:
#include <bits/stdc++.h>
using namespace std;

// For each query (x, y, z) on a vertex-weighted tree, print the maximum of
// z XOR w over all vertex weights w on the path x..y (both endpoints included).
//
// Approach:
//   * Heavy-light decomposition gives the LCA of two vertices in O(log n).
//   * A persistent binary trie is built in root-to-leaf order: version ro[u]
//     is version ro[fa[u]] plus the weight v[u], so ro[u] holds exactly the
//     weights on the path root..u. Subtracting the per-node counters of two
//     versions yields the weights on a vertical path segment, which is then
//     searched greedily from the highest bit down.
//
// Every traversal is iterative: a path-shaped tree with ~1e5 vertices would
// otherwise recurse ~1e5 frames deep and can overflow the call stack.

const int N = 100015;                     // max vertex count (+ slack); vertices are 1-based
const int BITS = 16;                      // NOTE(review): assumes weights and z are < 2^16 — confirm against the problem limits
const int MAXNODE = N * (BITS + 1) + 5;   // each insert allocates 1 root + BITS nodes

struct Edge {
    int to, next;
} edge[N << 1];
int head[N];     // head[u] = index of the first edge out of u, -1 if none
int edgeCnt;     // number of directed edges stored so far

int fa[N], son[N], siz[N], dep[N], topfa[N];
int ord[N];      // vertices ordered so that every parent precedes its children

int v[N];        // vertex weights
int ro[N];       // ro[u] = trie root for the path root..u; ro[0] = 0 is the empty trie

struct Tr {
    int c[2], tot;   // children (0 = null node) and number of stored values passing through
} tr[MAXNODE];
int trCnt;           // last allocated trie node; node 0 is the shared, always-empty node

// Append the directed edge a -> b to the adjacency list.
static void addEdge(int a, int b) {
    edge[edgeCnt].to = b;
    edge[edgeCnt].next = head[a];
    head[a] = edgeCnt++;
}

// Traverse the tree from `root` with an explicit stack, filling fa[], dep[]
// and ord[] (parents before children). Returns the number of vertices reached.
static int buildOrder(int root) {
    static int stk[N];
    int top = 0, len = 0;
    fa[root] = 0;
    dep[root] = 1;
    stk[top++] = root;
    while (top > 0) {
        int u = stk[--top];
        ord[len++] = u;
        for (int i = head[u]; i != -1; i = edge[i].next) {
            int w = edge[i].to;
            if (w == fa[u]) continue;
            fa[w] = u;
            dep[w] = dep[u] + 1;
            stk[top++] = w;
        }
    }
    return len;
}

// Heavy-light decomposition over ord[0..len): subtree sizes, heavy child
// (son[u] = 0 when u has no children) and the top vertex of u's heavy chain.
static void decompose(int len) {
    for (int k = 0; k < len; ++k) {
        siz[ord[k]] = 1;
        son[ord[k]] = 0;
    }
    // Reverse order visits children before parents, so siz[u] is final when u
    // is folded into its parent. k > 0 skips the root (ord[0]).
    for (int k = len - 1; k > 0; --k) {
        int u = ord[k], p = fa[u];
        siz[p] += siz[u];
        if (son[p] == 0 || siz[son[p]] < siz[u]) son[p] = u;
    }
    // Forward order: a parent's chain top is known before its children's.
    for (int k = 0; k < len; ++k) {
        int u = ord[k], p = fa[u];
        topfa[u] = (p != 0 && son[p] == u) ? topfa[p] : u;
    }
}

// Lowest common ancestor of a and b: lift whichever endpoint has the deeper
// chain top until both share a heavy chain, then take the shallower vertex.
static int lca(int a, int b) {
    while (topfa[a] != topfa[b]) {
        if (dep[topfa[a]] <= dep[topfa[b]]) b = fa[topfa[b]];
        else a = fa[topfa[a]];
    }
    return dep[a] <= dep[b] ? a : b;
}

// Allocate a fresh, empty trie node and return its index.
static int newtr() {
    ++trCnt;
    tr[trCnt].c[0] = tr[trCnt].c[1] = 0;
    tr[trCnt].tot = 0;
    return trCnt;
}

// Turn the fresh root `ro` into version `pre` plus `num`: the path for `num`
// is copied with tot+1, and every off-path subtree is shared with `pre`.
static void insert(int pre, int ro, int num) {
    for (int i = BITS - 1; i >= 0; --i) {
        int b = (num >> i) & 1;
        int nxt = newtr();
        tr[ro].c[b] = nxt;
        tr[ro].c[b ^ 1] = tr[pre].c[b ^ 1];
        pre = tr[pre].c[b];
        tr[nxt].tot = tr[pre].tot + 1;   // node 0 has tot 0, so a missing child counts as empty
        ro = nxt;
    }
}

// Maximum of num XOR w over the multiset (version rou) minus (version rofa),
// where rofa is an ancestor version of rou. At each bit, take the opposite
// bit of num when the difference still contains it. Returns 0 if empty.
static int query(int num, int rou, int rofa) {
    int ans = 0;
    for (int i = BITS - 1; i >= 0; --i) {
        int want = ((num >> i) & 1) ^ 1;
        if (tr[tr[rou].c[want]].tot - tr[tr[rofa].c[want]].tot > 0) {
            ans |= 1 << i;
        } else {
            want ^= 1;
        }
        rou = tr[rou].c[want];
        rofa = tr[rofa].c[want];
    }
    return ans;
}

int main() {
    int n, m;
    // Stop on EOF *or* a malformed header; `~scanf` would loop on a partial read.
    while (scanf("%d%d", &n, &m) == 2) {
        memset(head, -1, sizeof(head));
        edgeCnt = 0;
        for (int i = 1; i <= n; ++i) scanf("%d", &v[i]);
        for (int i = 1; i < n; ++i) {
            int a, b;
            scanf("%d%d", &a, &b);
            addEdge(a, b);
            addEdge(b, a);
        }

        int len = buildOrder(1);
        decompose(len);

        // Parents precede children in ord[], so ro[fa[u]] is always built first.
        trCnt = 0;
        for (int k = 0; k < len; ++k) {
            int u = ord[k];
            ro[u] = newtr();
            insert(ro[fa[u]], ro[u], v[u]);
        }

        for (int i = 0; i < m; ++i) {
            int x, y, z;
            scanf("%d%d%d", &x, &y, &z);
            int f = lca(x, y);
            // y..f inclusive is ro[y] - ro[fa[f]]; x..(just below f) is ro[x] - ro[f].
            // The second set may be empty (x == f); its 0 never beats the first.
            int ans = max(query(z, ro[y], ro[fa[f]]), query(z, ro[x], ro[f]));
            printf("%d\n", ans);
        }
    }
    return 0;
}