最近公共祖先(LCA) 学习笔记

最近公共祖先(LCA)

顾名思义,就是两个结点的最近公共父结点;

这里运用了倍增的思路,当你想要一个个的往上递增时,复杂度很高,可以把要递增的数量换成二进制表示(因为任意一个数都可以由二进制数相加得到);

比如11,可以由 2^3=8,2 ^1=2, 2 ^0=1,相加而成;所以11这个数可以先加8,在加2,在加1得到;

这里有两个预处理,一个是得到每个结点所在的层数,一个是得到每个结点往上走2的 j 次方次得到的结点;

这两个操作可以放在一个dfs函数里面,应该非常好理解:

fa[sn][i]=fa[fa[sn][i-1]][i-1];可以手动模拟,大致意思为2^n=2 ^n-1 * 2 ^ n-1;

void dfs(int sn,int ft){
	dep[sn]=dep[ft]+1,fa[sn][0]=ft;
	for(int i=1;i<=lg[dep[sn]];i++) fa[sn][i]=fa[fa[sn][i-1]][i-1];
	for(int i=head[sn];~i;i=edge[i].nex){
		if(edge[i].to!=ft) dfs(edge[i].to,sn);
	}
}

这里还学到一种预处理log_2(i)+1的方法:

for(int i=1;i<=n;i++) lg[i]=lg[i-1]+(1<<lg[i-1]==i);//预处理(log_2(i))+1的值 

这个东西就是求一个数要用二进制数组成的最大次数:

比如11,它的lg[11]=4,我们在调用lg[11]的时候还要减去1,也就是说要组成11,最多只要3次;

然后核心的 lca 代码主要的思路就是:先使 x 和 y 在同一层,然后在一起往上走,找到最顶部的两个结点 x 和 y ,结点 x 和 y 的父节点相等,那么 x 的父节点就是答案;

int lca(int x,int y){
	if(dep[x]<dep[y]) swap(x,y);
	while(dep[x]>dep[y]) x=fa[x][lg[dep[x]-dep[y]]-1];
	if(x==y) return x;
	for(int k=lg[dep[x]]-1;k>=0;k--){
		if(fa[x][k]!=fa[y][k]){
			x=fa[x][k],y=fa[y][k];
		}
	}
	return fa[x][0];
}

全部题目来做这道题:
【模板】最近公共祖先(LCA)

全部代码:

#include<bits/stdc++.h>
#define LL long long
#define pa pair<int,int>
#define ls k<<1
#define rs k<<1|1
#define inf 0x3f3f3f3f
using namespace std;
const int N=500100;
const int M=1000100;
const LL mod=100000000;
int n,m,s,head[N],cnt,lg[N],dep[N],fa[N][40];
struct Node{
	int to,nex;
}edge[M];
void add(int p,int q){
	edge[cnt].to=q;
	edge[cnt].nex=head[p];
	head[p]=cnt++;
}
void dfs(int sn,int ft){
	dep[sn]=dep[ft]+1,fa[sn][0]=ft;
	for(int i=1;i<=lg[dep[sn]];i++) fa[sn][i]=fa[fa[sn][i-1]][i-1];
	for(int i=head[sn];~i;i=edge[i].nex){
		if(edge[i].to!=ft) dfs(edge[i].to,sn);
	}
}
int lca(int x,int y){
	if(dep[x]<dep[y]) swap(x,y);
	while(dep[x]>dep[y]) x=fa[x][lg[dep[x]-dep[y]]-1];
	if(x==y) return x;
	for(int k=lg[dep[x]]-1;k>=0;k--){
		if(fa[x][k]!=fa[y][k]){
			x=fa[x][k],y=fa[y][k];
		}
	}
	return fa[x][0];
}
int main(){
//    ios::sync_with_stdio(false);
    memset(head,-1,sizeof(head));
    cin>>n>>m>>s;
	for(int i=1;i<n;i++){
		int p,q;
		scanf("%d%d",&p,&q);
		add(p,q),add(q,p);
	} 
	for(int i=1;i<=n;i++) lg[i]=lg[i-1]+(1<<lg[i-1]==i);//预处理(log_2(i))+1的值 
	dfs(s,0);
	while(m--){
		int a,b;
		scanf("%d%d",&a,&b);
		printf("%d\n",lca(a,b));
	}
    return 0;
}
发布了264 篇原创文章 · 获赞 46 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/qq_44291254/article/details/104886620