题意: 给你一棵n个节点的树,每个节点的颜色可能不同,现在要给你两种颜色,问你两种颜色的最大距离。如果有一种颜色不存在那么直接输出-1 即可。
思路:
先预处理出树的lca,那么求树的直径就是一个o(n)+查询lca的复杂度了。
代码:
#include<bits/stdc++.h>
using namespace std;
const int N =1e5+5; /// 节点的个数
int rmq[N*2]; /// 就是欧拉序列对应的深度序列
struct ST
{
int mm[N*2];
int anc[2*N][20];
void init(int n)
{
mm[0]=-1;
for(int i=1;i<=n;i++){
mm[i]=((i&(i-1))==0)?mm[i-1]+1:mm[i-1];
anc[i][0]=i;
}
for(int j=1;j<=mm[n];j++){
for(int i=1;i+(1<<j)-1<=n;i++){
anc[i][j]=rmq[anc[i][j-1]] < rmq[anc[i+(1<<(j-1))][j-1]]?anc[i][j-1]:anc[i+(1<<(j-1))][j-1];
}
}
}
int query(int a,int b)
{
if(a>b) swap(a,b);
int k=mm[b-a+1];
return rmq[anc[a][k]] <= rmq[anc[b-(1<<k)+1][k]]?anc[a][k]:anc[b-(1<<k)+1][k];
}
};
struct node
{
int v,next;
}edge[N*2];
int dep[N];
int tot,head[N];
int dfns[N*2]; /// 欧拉序列,也就是dfs序, 长度为2*n-1, 下标从1 开始.
int pp[N]; /// pp[i] 表示在dfns 中第一次出现的位置
//int L[N],R[N]; /// 表示当前节点dfs序中管辖的区间看情况使用
int cnt;
int clo; /// 时钟标记 用于L,R
ST st;
void add(int u,int v)
{
edge[++tot].v=v; edge[tot].next=head[u]; head[u]=tot;
}
void dfs(int u,int fa,int deep)
{
dep[u]=deep;
dfns[++cnt]=u; rmq[cnt]=deep; pp[u]=cnt;
//L[u]=++clo;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(v==fa) continue;
dfs(v,u,deep+1);
dfns[++cnt]=u;
rmq[cnt]=deep;
}
//R[u]=clo;
}
void LCA_init(int rt,int node_num)
{
cnt=0;
dfs(rt,rt,0);
st.init(2*node_num-1);
}
int LCA(int u,int v)
{
return dfns[st.query(pp[u],pp[v])];
}
map<string ,int >mp;
int id[N];
int duan[N][2];
int n,m;
string s;
string s1;
void init()
{
tot=clo=cnt=0;
memset(head,-1,sizeof(head));
for(int i=1;i<=n;i++){
duan[i][0]=duan[i][1]=0;
id[i]=0;
}
mp.clear();
}
int main()
{
int u,v;
while(scanf("%d %d",&n,&m)!=EOF)
{
init();
int tot=0;
for(int i=1;i<=n;i++){
cin>>s;
if(mp[s]==0){
id[i]=mp[s]=++tot;
}
else{
id[i]=mp[s];
}
}
for(int i=1;i<n;i++){
scanf("%d %d",&u,&v);
add(u,v);
add(v,u);
}
LCA_init(1,n);
for(int ii=1;ii<=n;ii++){
int i=id[ii];
if(duan[i][0]==0&&duan[i][1]==0){
duan[i][0]=duan[i][1]=ii;
}
else{
int x=duan[i][0];
int y=duan[i][1];
int z=ii;
int len1,len2,len3;
int lca=LCA(x,y);
len1=dep[x]+dep[y]-2*dep[lca]+1;
lca=LCA(x,z);
len2=dep[x]+dep[z]-2*dep[lca]+1;
lca=LCA(y,z);
len3=dep[y]+dep[z]-2*dep[lca]+1;
if(len2>len1){
len1=len2;
duan[i][0]=x; duan[i][1]=z;
}
if(len3>len1){
duan[i][0]=y; duan[i][1]=z;
}
}
}
while(m--){
cin>>s>>s1;
if(mp[s]==0||mp[s1]==0){
printf("-1\n");
continue;
}
int i,j;
i=mp[s]; j=mp[s1];
//cout<<"i "<<i<<" j "<<j<<endl;
int a,b,c,d;
a=duan[i][0]; b=duan[i][1];
c=duan[j][0]; d=duan[j][1];
//cout<<"a "<<a<<" b "<<b<<" c "<<c<<" d "<<d<<endl;
int len1,len2,len3,len4;
len1=dep[a]+dep[c]-2*dep[LCA(a,c)]+1;
len2=dep[a]+dep[d]-2*dep[LCA(a,d)]+1;
len3=dep[b]+dep[c]-2*dep[LCA(b,c)]+1;
len4=dep[b]+dep[d]-2*dep[LCA(b,d)]+1;
//cout<<"len1 "<<len1<<" len2 "<<len2<<" len3 "<<len3<<" len4 "<<len4<<endl;
int ans=max(max(len1,len2),max(len3,len4));
printf("%d\n",ans);
}
}
return 0;
}