任重而道远
题目描述
给定 n 个集合（初始时第 i 个集合只包含元素 i），共进行 m 次操作。
操作共三种：
- `1 a b`：合并 a、b 所在的集合；
- `2 k`：回到第 k 次操作之后的状态（询问也算作一次操作；k = 0 表示初始状态）；
- `3 a b`：询问 a、b 是否属于同一集合，是则输出 1，否则输出 0。
输入输出格式
输入格式：
第一行两个整数 n、m；接下来 m 行，每行描述一次操作，格式如上。
输出格式：
对于每个 3 操作，输出一行 1 或 0。
输入输出样例
输入样例 #1：
5 6
1 1 2
3 1 2
2 0
3 1 2
2 1
3 1 2
输出样例 #1：
1
0
1
说明
$1 \le n \le 10^5$，$1 \le m \le 2 \times 10^5$
By zky 出题人大神犇
AC代码:
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
using namespace std;
// Upper bound on the number of versions: root[0] (initial state) + m operations.
const int N = 2e5 + 5;

// Node of a persistent segment tree indexed by element id 1..n.
// Leaves hold the union-find state of one element:
//   fa  - parent element id (fa == own id means the element is a root);
//   dep - rank used for union by rank (only consulted at roots).
// Internal nodes only use ls / rs.
struct Node {
    int fa, dep;
    Node *ls, *rs;
};

// Node pool; allocation is `++tail`, so pool[0] is never handed out. Because the
// pool has static storage it is zero-initialised, which leaves ls/rs of leaves null.
// Sizing: build uses < 2n nodes; every modification path-copies one root-to-leaf
// path (ceil(log2 n) + 1 <= 18 nodes for n <= 1e5) and a type-1 operation does at
// most two of them (merge + rank update), so the worst case is ~2n + 36 * m ≈ 7.4e6.
// root[i] is the tree describing the state after the i-th operation.
Node pool[N * 40], *tail = pool, *root[N];
int n, m;

// Builds the version-0 tree over [l, r]: every element is its own root with rank 0.
Node *build (int l, int r) {
    Node *nd = ++tail;
    if (l == r) {
        nd -> fa = l;
        nd -> dep = 0;
        return nd;
    }
    int mid = (l + r) >> 1;
    nd -> ls = build (l, mid);
    nd -> rs = build (mid + 1, r);
    return nd;
}

// Returns the leaf for element o in the tree rt covering [l, r]. Read-only.
Node *query (Node *rt, int o, int l, int r) {
    if (l == r) return rt;
    int mid = (l + r) >> 1;
    if (o <= mid) return query (rt -> ls, o, l, mid);
    else return query (rt -> rs, o, mid + 1, r);
}

// Returns the leaf of the root of x's set in version nd by following fa links.
// No path compression: compressing would have to allocate a new version.
Node *find (int x, Node *nd) {
    Node *f = query (nd, x, 1, n);
    while (f -> fa != x) {
        x = f -> fa;
        f = query (nd, x, 1, n);
    }
    return f;
}

// Returns a new version of rt in which element x's parent is y (x's rank is kept).
// Only the root-to-leaf path of x is copied; everything else is shared with rt.
Node *merge (Node *rt, int x, int y, int l, int r) {
    Node *nd = ++tail;
    if (l == r) {
        nd -> dep = rt -> dep;
        nd -> fa = y;
        return nd;
    }
    int mid = (l + r) >> 1;
    if (x <= mid) {
        nd -> ls = merge (rt -> ls, x, y, l, mid);
        nd -> rs = rt -> rs;
    } else {
        nd -> ls = rt -> ls;
        nd -> rs = merge (rt -> rs, x, y, mid + 1, r);
    }
    return nd;
}

// Increments the rank of element o in the version held by rt, path-copying so that
// nodes shared with older versions are never modified. rt is re-pointed to the new
// root, and each child link on the path is re-pointed to its copy through the
// reference parameter. (Mutating the shared leaf in place would raise the rank in
// every earlier version too, defeating union by rank after a rollback.)
void update (Node *&rt, int o, int l, int r) {
    Node *nd = ++tail;
    *nd = *rt;
    rt = nd;
    if (l == r) {
        nd -> dep++;
        return;
    }
    int mid = (l + r) >> 1;
    if (o <= mid)
        update (nd -> ls, o, l, mid);
    else
        update (nd -> rs, o, mid + 1, r);
}
int main () {
scanf ("%d%d", &n, &m);
root[0] = build (1, n);
for (int i = 1; i <= m; i++) {
int opt;
scanf ("%d", &opt);
if (opt == 1) {
int a, b;
scanf ("%d%d", &a, &b);
Node *fa = find (a, root[i - 1]);
Node *fb = find (b, root[i - 1]);
if (fa -> fa == fb -> fa) {
root[i] = root[i - 1];
goto Nxt;
}
if (fa -> dep > fb -> dep) swap (fa, fb);
root[i] = merge (root[i - 1], fa -> fa, fb -> fa, 1, n);
if (fa -> dep == fb -> dep)
update (root[i], fb -> fa, 1, n);
} else if (opt == 2) {
int k;
scanf ("%d", &k);
root[i] = root[k];
} else {
int a, b;
scanf ("%d%d", &a, &b);
root[i] = root[i - 1];
Node *fa = find (a, root[i]);
Node *fb = find (b, root[i]);
if (fa -> fa == fb -> fa) puts ("1");
else puts ("0");
}
Nxt:;
}
return 0;
}