lca最近公共祖先,联通分量

https://vjudge.net/contest/295298#problem/A

lca 的题目

求任意两点的距离。

A题是在线算法,用st表rmq来实现。

https://blog.csdn.net/nameofcsdn/article/details/52230548

相当于先把整个树dfs一遍,记录整个dfs过程中的点(可重复,相当于dfs序,按顺序排好所有的点),并且记录每个点第一次被遍历到的得dfs序,

然后两个点的最近公共祖先就是第一次被遍历到的下标之间点深度最小的那个点。

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn = 4e4 + 5;
int t;
int n, m;
struct Node {
    int v, dis;
    Node(int v = 0, int dis = 0) : v(v), dis(dis) {};
};
vector<Node> g[maxn];
int rmq[maxn << 1];//记录深度。整个dfs过程中遍历点得到深度。
int dfsn[maxn << 1];//记录整个dfs过程中经过的点,总的大小为2 * n - 1;
int fir[maxn];//每个点第一次被dfs到在dfsn数组中的位置。
int dis[maxn];//与根节点之间的的距离
int cnt = 0;

void dfs(int u, int pre, int dep) {
    dfsn[++cnt] = u;
    rmq[cnt] = dep;
    fir[u] = cnt;
    for(int i = 0; i < g[u].size(); i++) {
        int v = g[u][i].v;
        if(v == pre) continue;
        dis[v] = dis[u] + g[u][i].dis;
        dfs(v, u, dep + 1);
        dfsn[++cnt] = u;
        rmq[cnt] = dep;
    }
}

int lg[maxn << 1];
int dp[maxn << 1][20];

void RMQ(int val) {
    lg[0] = -1;
    for(int i = 1; i <= val; i++) {
        lg[i] = ((i & (i - 1)) == 0) ? lg[i - 1] + 1 : lg[i - 1];
        dp[i][0] = i;
    }//记录lg值。
    for(int j = 1; j <= lg[val]; j++) {
        for(int i = 1; i + (1 << j) - 1 <= val; i++) {
            dp[i][j] = rmq[dp[i][j - 1] ] < rmq[dp[i + (1 << (j - 1))][j - 1] ] ? dp[i][j - 1] : dp[i + (1 << (j - 1)) ][j - 1];

        }
    }//rmq得到最小深度的那个点的下标。
}

int query(int a, int b) {
    if(a > b) swap(a, b);
    int k = lg[b - a + 1];
    return rmq[dp[a][k] ] < rmq[dp[b - (1 << k) + 1][k] ] ? dp[a][k] : dp[b - (1 << k) + 1 ][k];
}

int lca_query(int u, int v) {
    return dfsn[query(fir[u], fir[v])];//得到深度最小的那个点对应的节点编号。
}

int main() {
scanf("%d", &t);
while(t--) {
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++) {
        g[i].clear();
        dis[i] = 0;
    }
    int u, v, w;
    for(int i = 1; i < n; i++) {
        scanf("%d%d%d", &u, &v, &w);
        g[u].push_back(Node(v, w));
        g[v].push_back(Node(u, w));
    }
    cnt = 0;
    dfs(1, 0, 0);
    RMQ(2 * n - 1);
    while(m--) {
        scanf("%d%d", &u, &v);
        int val = lca_query(u, v);
        printf("%d\n", dis[u] + dis[v] - 2 * dis[val]);
    }
}

return 0;
}

https://vjudge.net/contest/295298#problem/B 离线lca算法

先把每个点的父节点记为自己,然后dfs遍历,直到回溯的时候才更新该节点的父节点。

对于查询的两点,互相记录对方的编号以及这是第一次查询的编号,然后当其中某个点被遍历到,但另外一个点没有被遍历到,那个继续dfs,如果遍历一个点,他对应的那个点已经遍历过了,那个他们的

最近公共祖先就是之前那个被遍历过的点的父节点,(画个图想想,那个父节点就是他们的最近公共祖先)

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
const int maxn = 1e5 + 5;
typedef pair<int, int> pii;
vector<pii> g[maxn], que[maxn];
int t;
int n, m, k;
int ans[maxn], vis[maxn], pa[maxn], res[maxn], dis[maxn];
//ans数组:每个点的祖先节点编号,pa数组:每个点的父节点的编号。
int fid(int x) {
    if(x == pa[x]) return x;
    pa[x] = fid(pa[x]);
    return pa[x];
}

void unio(int x, int y) {
    int u = fid(x);
    int v = fid(y);
    if(u != v) pa[u] = v;
}

void lca(int u, int fa) {
    ans[u] = u;
    for(int i = 0; i < g[u].size(); i++) {
        int v = g[u][i].first;
        if(v == fa) continue;
        dis[v] = dis[u] + g[u][i].second;
        lca(v, u);
        unio(u, v);//连通起来
        ans[fid(u) ] = u;//记录父节点的祖先,儿子们得到祖先的时候也要先fid自己父节点一下
    }
    vis[u] = 1;
    for(int i = 0; i < que[u].size(); i++) {
        int v = que[u][i].first;
        int w = que[u][i].second;
        if(vis[v]) res[w] = dis[u] + dis[v] - 2 * dis[ans[fid(v)] ];
    }
}

int main() {
while(~scanf("%d%d", &n, &m)) {
    int u, v, w;
    char ch[3];
    for(int i = 1; i <= n; i++) {
        pa[i] = i;
        ans[i] = 0;
        vis[i] = 0;
        res[i] = 0;
        g[i].clear();
        que[i].clear();
    }
    for(int i = 1; i <= m; i++) {
        scanf("%d%d%d%s", &u, &v, &w, ch);
        g[u].push_back(pii(v, w));
        g[v].push_back(pii(u, w));
    }
    scanf("%d", &k);
    for(int i = 1; i <= k; i++) {
        scanf("%d%d", &u, &v);
        que[u].push_back(pii(v, i));
        que[v].push_back(pii(u, i));
    }
    dis[1] = 0;
    lca(1, 0);
    for(int i = 1; i <= k; i++) printf("%d\n", res[i]);
}
return 0;
}

猜你喜欢

转载自www.cnblogs.com/downrainsun/p/10712297.html