[BZOJ4539][Hnoi2016]树(主席树+lca 大模拟系列)

Description

小 A 想做一棵很大的树,但是他手上的材料有限,只好用点小技巧了。开始,小 A 只有一棵结点数为 N 的树,结点的编号为 1 , 2 , , N ,其中结点 1 为根;我们称这颗树为模板树。小 A 决定通过这棵模板树来构建一颗大树。构建过程如下:
(1)将模板树复制为初始的大树。
(2)以下 (2.1)(2.2)(2.3) 步循环执行 M
(2.1)选择两个数字 a , b ,其中 1 a N 1 b 当前大树的结点数 。
(2.2)将模板树中以结点 a 为根的子树复制一遍,挂到大树中结点 b 的下方(也就是说,模板树中的结点 a 为根的子树复制到大树中后,将成为大树中结点 b 的子树)。
(2.3)将新加入大树的结点按照在模板树中编号的顺序重新编号。例如,假设在进行(2.2)步之前大树有 L 个结点,模板树中以 a 为根的子树共有 C 个结点,那么新加入模板树的 C 个结点在大树中的编号将是 L + 1 L + 2 L + C ;大树中这 C 个结点编号的大小顺序和模板树中对应的 C 个结点的大小顺序是一致的。下面给出一个实例。假设模板树如下图:
这里写图片描述
根据第(1)步,初始的大树与模板树是相同的。在(2.1)步,假设选择了 a = 4 b = 3 。运行(2.2)和(2.3)后,得到新的大树如下图所示:这里写图片描述
现在他想问你,树中一些结点对的距离是多少。

Input

第一行三个整数: N , M , Q ,以空格隔开, N 表示模板树结点数, M 表示第(2)中的循环操作的次数, Q 表示询问数量。接下来 N 1 行,每行两个整数 f r , t o ,表示模板树中的一条树边。再接下来 M 行,每行两个整数 x , t o ,表示将模板树中 x 为根的子树复制到大树中成为结点 t o 的子树的一次操作。再接下来 Q 行,每行两个整数 f r , t o ,表示询问大树中结点 f r t o 之间的距离是多少。 N , M , Q 100000

Output

输出 Q 行,每行一个整数,第 i 行是第 i 个询问的答案。

Sample Input

5 2 3
1 4
1 3
4 2
4 5
4 3
3 2
6 9
1 8
5 3

Sample Output

6
3
3

HINT

经过两次操作后,大树变成了下图所示的形状:
这里写图片描述
结点 6 9 之间经过了 6 条边,所以距离为 6 ;类似地,结点 1 8 之间经过了 3 条边;结点 5 3 之间也经过了 3 条边。

Solution

一道特别麻烦、不好实现、毒瘤的大模拟

Part 1 - 将大树的规模压缩到线性

由题意,每次插入大树中的子树一定是模板树中一个节点的子树。
因此,可以把大树分成 M + 1 个部分,第 1 部分为初始的大树(模板树),第 i 2 i M + 1 )个部分表示第 i 1 次插入的子树。
这样大树中的一个节点可以用一个二元组 ( x , y ) 表示,表示这个节点在第 x 个部分,对应模板树中的第 y 个节点。

Part 2 - 将大树中的节点编号转化为二元组表示

对于这个问题,找到一个点 u 位于哪个部分比较容易,记下每个部分包含节点数的前缀和后二分查找即可得到 x
而现在要求 y ,就相当于在 x 在模板树中对应的子树内查询编号第 k 大的节点。
求出模板树的 dfs 序列之后,就是区间第 k 小问题,用主席树即可解决。
这样,就能在 log 的复杂度内将大树中的节点编号 u 转化为二元组表示 ( x , y )

Part 3 - 查询距离

查询两点 ( ( u , a ) ( v , b ) ) (均为二元组表示)之间的距离,可以预处理出每个部分的根到大树根节点的距离,这样就能求得 节点 ( u , a ) ( v , b ) 分别到根的距离,分别记作 d 1 , d 2 ,然后在求出 ( u , a ) ( v , b ) 的 lca 为 ( w , c ) ,那么同样地求出 ( w , c ) 到根的距离 d 3 ,那么询问答案就是:

d 1 + d 2 2 × d 3

而现在的关键问题就是求得 lca 。
w 比较显然。所有 M + 1 个部分形成了一个树的结构,在这棵树上求得 u 部分和 v 部分的 lca 就是 w
c 则要分别求出:
a x :点 ( u , a ) 不断往上跳,第一次跳进 w 部分时到达的节点即为 ( w , a x )
b x :点 ( v , b ) 不断往上跳,第一次跳进 w 部分时到达的节点即为 ( w , b x )
这时候, c 就是 模板树 a x b x 的 lca 。

Code

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define Rof(i, a, b) for (i = a; i >= b; i--)
#define Tree(u) for (int e = adj[u], v = go[e]; e; e = nxt[e], v = go[e])
#define Hnoi2016(u) for (int e = adj2[u], v = go2[e]; e; e = nxt2[e], v = go2[e])
using namespace std;
typedef long long ll;
inline int read() {
    int res = 0; bool bo = 0; char c;
    while (((c = getchar()) < '0' || c > '9') && c != '-');
    if (c == '-') bo = 1; else res = c - 48;
    while ((c = getchar()) >= '0' && c <= '9')
        res = (res << 3) + (res << 1) + (c - 48);
    return bo ? ~res + 1 : res;
}
inline ll readll() {
    ll res = 0; bool bo = 0; char c;
    while (((c = getchar()) < '0' || c > '9') && c != '-');
    if (c == '-') bo = 1; else res = c - 48;
    while ((c = getchar()) >= '0' && c <= '9')
        res = (res << 3) + (res << 1) + (c - 48);
    return bo ? ~res + 1 : res;
}
const int N = 1e5 + 5, M = N << 1, Z = 20, L = 4e6 + 5;
int n, m, q, ecnt, nxt[M], adj[N], go[M], dep[N], fa[N][Z], ecnt2, nxt2[N],
adj2[N], go2[N], dep2[N], fa2[N][Z], dfn[N], sze[N], T, QAQ, rt[N], le[N],
ri[N], ro[N], ft[N];
ll fk[N], dis[N];
struct cyx {
    int lc, rc, sum;
} Tr[L];
void ins(int y, int &x, int l, int r, int k) {
    Tr[x = ++QAQ] = Tr[y]; Tr[x].sum++; if (l == r) return;
    int mid = l + r >> 1; if (k <= mid) ins(Tr[y].lc, Tr[x].lc, l, mid, k);
    else ins(Tr[y].rc, Tr[x].rc, mid + 1, r, k);
}
int kth(int L, int R, int l, int r, int k) {
    if (l == r) return l; int mid = l + r >> 1, delta;
    delta = Tr[Tr[R].lc].sum - Tr[Tr[L].lc].sum;
    if (k <= delta) return kth(Tr[L].lc, Tr[R].lc, l, mid, k);
    else return kth(Tr[L].rc, Tr[R].rc, mid + 1, r, k - delta);
}
void add_edge(int u, int v) {
    nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
    nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
}
void add_edge2(int u, int v) {
    nxt2[++ecnt2] = adj2[u]; adj2[u] = ecnt2; go2[ecnt2] = v;
}
void dfs(int u, int fu) {
    int i; dep[u] = dep[fa[u][0] = fu] + (sze[u] = 1); dfn[u] = ++T;
    ins(rt[T - 1], rt[T], 1, n, u);
    For (i, 0, 16) fa[u][i + 1] = fa[fa[u][i]][i];
    Tree(u) if (v != fu) dfs(v, u), sze[u] += sze[v];
}
int LCA(int u, int v) {
    int i; if (dep[u] < dep[v]) swap(u, v); Rof (i, 17, 0) {
        if (dep[fa[u][i]] >= dep[v]) u = fa[u][i];
        if (u == v) return u;
    }
    Rof (i, 17, 0) if (fa[u][i] != fa[v][i])
        u = fa[u][i], v = fa[v][i]; return fa[u][0];
}
int dist(int u, int v) {return dep[u] + dep[v] - (dep[LCA(u, v)] << 1);}
void dfsHN(int u, int fu) {
    int i; dep2[u] = dep2[fa2[u][0] = fu] + 1;
    For (i, 0, 16) fa2[u][i + 1] = fa2[fa2[u][i]][i]; Hnoi2016(u)
        dis[v] = 1 + dist(ft[v], ro[u]) + dis[u], dfsHN(v, u);
}
int lcaHN(int u, int v) {
    int i; if (dep2[u] < dep2[v]) swap(u, v); Rof (i, 17, 0) {
        if (dep2[fa2[u][i]] >= dep2[v]) u = fa2[u][i];
        if (u == v) return u;
    }
    Rof (i, 17, 0) if (fa2[u][i] != fa2[v][i])
        u = fa2[u][i], v = fa2[v][i]; return fa2[u][0];
}
int whichTr(ll x, int c) {return lower_bound(fk + 1, fk + c + 1, x) - fk;}
ll distHN(int x, int y) {return dist(y, ro[x]) + dis[x];}
int jumpHN(int x, int k) {
    int i; Rof (i, 17, 0) if ((k >> i) & 1) x = fa2[x][i]; return x;
}
int main() {
    int i, x, y; ll fr, to; n = read(); m = read() + 1; q = read();
    For (i, 1, n - 1) x = read(), y = read(), add_edge(x, y);
    dfs(1, 0); ri[ro[1] = 1] = fk[le[1] = 1] = n; For (i, 2, m) {
        x = read(); to = readll(); int p = whichTr(to, i - 1);
        ft[i] = kth(rt[le[p] - 1], rt[ri[p]], 1, n, to - fk[p - 1]);
        le[i] = dfn[ro[i] = x]; ri[i] = dfn[x] + sze[x] - 1;
        fk[i] = fk[i - 1] + ri[i] - le[i] + 1; add_edge2(p, i);
    }
    dfsHN(1, 0); while (q--) {
        fr = readll(); to = readll(); int u = whichTr(fr, m), v = whichTr(to, m);
        int a = kth(rt[le[u] - 1], rt[ri[u]], 1, n, fr - fk[u - 1]),
            b = kth(rt[le[v] - 1], rt[ri[v]], 1, n, to - fk[v - 1]);
        ll sp = distHN(u, a) + distHN(v, b); int w = lcaHN(u, v);
        if (u != w) a = ft[jumpHN(u, dep2[u] - dep2[w] - 1)];
        if (v != w) b = ft[jumpHN(v, dep2[v] - dep2[w] - 1)];
        printf("%lld\n", sp - (distHN(w, LCA(a, b)) << 1));
    } 
    return 0;
}

猜你喜欢

转载自blog.csdn.net/xyz32768/article/details/80378458