题解
对每一个点建线段树 我更喜欢叫主席树
操作1:合并x和它的子树,直链用主席树(线段树)来表示位置之间的关系
操作2:
- 如果 x,y 是均不处在直链中,
- 如果 x,y 处于同一条直链内,ans = 主席树内x,y之间相隔了多少个点
- 如果 x,,y 在两条不同的直链上,
#include <bits/stdc++.h>
using namespace std;
#define mid ((l+r)>>1)
const int N=2e5+10;
vector<int>e[N];
int n,m,k;
int flag[N];//flag[u] 表示以u为根节点的树是否变成了直链
int fa[N][20],fb[N];//lca 并查集
int tot=0;
int rt[N*40],ls[N*40],rs[N*40],sum[N*40];//rt[u]表示以u为根节点的树 ls左子树 rs右子树 sum 树的大小
int newNode(){
++tot;
ls[tot]=rs[tot]=sum[tot]=0;
return tot;
}
int build(int& p,int l,int r,int x){//建立主席树(线段树) rt l r pos
p=newNode();
if(l==r){
sum[p]=1;
return p;
}
if(x<=mid){
build(ls[p],l,mid,x);
}else{
build(rs[p],mid+1,r,x);
}
sum[p]=1;//统计子树的个数
}
int depth[N];
void dfs(int p,int f,int dep){
flag[p]=0;//判断有没有被拉成直链
build(rt[p],1,n,p);
fa[p][0]=f;
depth[p]=dep;
for (int i = 1; i <=17; ++i) {//lca核心步骤
fa[u][i]=fa[fa[u][i-1]][i-1];
}
for (int i = 0; i <e[p].size(); ++i) {
int v=e[p][i];
if(v!=f)dfs(v,p,dep+1);
}
}
int Union(int u,int v,int l,int r){//合并主席树
if(u==0||v==0)return u+v;//如果有一颗树是空的 直接返回另外一颗
int p=newNode();//合并后的新树 放弃了原来的两颗
if(l==r){
sum[p]=sum[u]+sum[v];//合并点
return p;//返回新建的根节点
}
ls[p]=Union(ls[u],ls[v],l,mid);//主席树左边合并
rs[p]=Union(rs[u],rs[v],mid+1,r);//主席树右边合并
sum[p]=sum[rs[p]]+sum[ls[p]];
return p;
}
void dfs1(int p,int fg){//变成直链
if(flag[p]){//如果以p为根节点的子树已经变为直链
fb[p]=fg;//直接更改并查集
return;
}
fb[p]=fg;//更新祖先 祖先是链头
flag[p]=1;//被拉成直链的记号
for (int i = 0; i <e[p].size(); ++i) {
int v=e[p][i];
if(v!=fa[p][0]){//确保不会返回
dfs1(v,fg);
rt[p]=Union(rt[p],rt[v],1,n);//两颗主席树合并
}
}
}
int find(int u){
return fb[u]==u?u:fb[u]=find(fb[u]);
}
int query(int p,int l,int r,int x,int y){//x<y 查询x-y之间的距离
if(x==l && y==r){
return sum[p];
}
int res=0;
if(x<=mid)res+=query(ls[o],l,mid,x,y);
if(y>mid)res+=query(rs[o],mid+1,r,x,y);
return res;
}
int getlca(int x,int y){
if(depth[x]<depth[y])swap(x,y);
for (int i = 17; i >=0; --i) {
if((1<<i)<=depth[x]-depth[y]){//跳跃的距离大于两点间的深度之差 就跳 等于if(depth[f[x][i]]>=depth[y])
x=fa[x][i];
}
}
if(x==y)return x;
for (int i = 17; i>=0 ; --i) {
if(fa[x][i]!=fa[y][i]){
x=fa[x][i];
y=fa[y][i];
}
}
return fa[x][0];
}
int main(){
ios::sync_with_stdio(0);
int T;
cin>>T;
for (int cs = 1,u,v; cs <= T; ++cs) {
cin>>n;
//init
tot=0;
for (int i = 1; i <= n; ++i) {
e[i].clear();
fb[i]=i;//并查集
}
memset(fa, 0, sizeof(fa));
for (int i = 1; i < n; ++i) {
cin>>u>>v;
e[u].push_back(v);
e[v].push_back(u);
}
dfs(1,0,0);
int f;
cin>>m;
for (int i = 1,ans; i <= m; ++i) {
cin>>f;
if(f==1){
cin>>u;
if(!flag[u]) //如果没有调整过
dfs1(u,u);
}else{
cin>>u>>v;
int x=find(u);
int y=find(v);
if(x==y){//在同一根链上
if(u<v)swap(u,v);//假设u比v大
ans=query(rt[x],1,n,v,u)-1;//从rt[x]这颗以x为节点的主席树上查询u,v之间的距离
}else{
int lca=getlca(x,y);//找到共同的祖先
ans=depth[x]+depth[y]-depth[lca]*2;
ans+=sum[rt[x]]-query(rt[x],1,n,1,u)+sum[rt[y]]-query(rt[y],1,n,1,v);
}
cout<<ans<<endl;
}
}
}
return 0;
}