https://www.luogu.com.cn/problem/SP10707
欧拉序就是进去的时候标记一次出来的时候标记一次,然后我们把这个序列拿出来,变成一个2*n的新序列,每个点有前后两次访问ll[i]和rr[i],那么一段树上的路径就可以用区间表示
假设询问是u,v的路径,ll[u]<ll[v],如果u是v的祖先,那么直接对应到ll[u]-ll[v],否则对应到rr[u]-ll[v]再加上ll[lca(u,v)],因为u和v的最近公共祖先出现一定在ll[u]左边,第二次出现在rr[v]右边,所以是不包含在当前区间里的,然后注意如果这段欧拉序序列中某个位置出现了两次,那么他也是不在u,v路径上的
#include<bits/stdc++.h>
using namespace std;
const int maxl=1e5+10;
int n,m,tot,len;
int a[maxl],b[maxl],idx[maxl],bel[maxl],ans[maxl],dep[maxl];
int ll[maxl],rr[maxl],num[maxl],vis[maxl];
int f[21][maxl];
vector<int> e[maxl];
struct qry
{
int l,r,lca,id;
}q[maxl];
inline void dfs(int u,int fa)
{
idx[++n]=u;ll[u]=n;
for(int v:e[u])
if(v!=fa)
{
f[0][v]=u;dep[v]=dep[u]+1;
dfs(v,u);
}
idx[++n]=u;rr[u]=n;
}
inline int getlca(int u,int v)
{
if(dep[u]<dep[v])
swap(u,v);
for(int i=20;i>=0;i--)
if((dep[u]-dep[v])>>i&1)
u=f[i][u];
if(u==v)
return u;
for(int i=20;i>=0;i--)
if(f[i][u]!=f[i][v])
u=f[i][u],v=f[i][v];
return f[0][u];
}
inline bool cmp(const qry&a,const qry&b)
{
return (bel[a.l]^bel[b.l])?bel[a.l]<bel[b.l]:a.r<b.r;
}
inline void prework()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]),b[i]=a[i];
sort(b+1,b+1+n);
tot=unique(b+1,b+1+n)-b-1;
for(int i=1;i<=n;i++)
a[i]=lower_bound(b+1,b+1+tot,a[i])-b;
int u,v;
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&u,&v);
e[u].push_back(v);
e[v].push_back(u);
}
n=0;
dfs(1,0);
len=sqrt(n);
for(int i=1;i<=n;i++)
bel[i]=(i-1)/len+1;
for(int k=1;k<=20;k++)
for(int i=1;i<=n;i++)
f[k][i]=f[k-1][f[k-1][i]];
for(int i=1;i<=m;i++)
{
scanf("%d%d",&q[i].l,&q[i].r),q[i].id=i;
if(ll[q[i].l]>ll[q[i].r])
swap(q[i].l,q[i].r);
q[i].lca=getlca(q[i].l,q[i].r);
if(q[i].lca==q[i].l)
{
q[i].l=ll[q[i].l],q[i].r=ll[q[i].r];
q[i].lca=0;
}
else
{
q[i].l=rr[q[i].l],q[i].r=ll[q[i].r];
q[i].lca=ll[q[i].lca];
}
}
sort(q+1,q+1+m,cmp);
}
inline int solv(int i)
{
vis[idx[i]]^=1;
if(vis[idx[i]])
return !num[a[idx[i]]]++;
else
return -(!--num[a[idx[i]]]);
}
inline void mainwork()
{
int l=1,r=0,now=0;
for(int i=1;i<=m;i++)
{
while(r<q[i].r)
++r,now+=solv(r);
while(l>q[i].l)
--l,now+=solv(l);
while(r>q[i].r)
now+=solv(r),r--;
while(l<q[i].l)
now+=solv(l),l++;
if(q[i].lca)
now+=solv(q[i].lca);
ans[q[i].id]=now;
if(q[i].lca)
now+=solv(q[i].lca);
}
}
inline void print()
{
for(int i=1;i<=m;i++)
printf("%d\n",ans[i]);
}
int main()
{
prework();
mainwork();
print();
return 0;
}