Splay解决区间问题
Splay是什么?
和 Treap 一样, Splay 也是一种平衡树,不同的是, Splay 没有另外记录随机生成的权值,而只是记录了结点原本的值。
Treap 的旋转是为了保持堆的性质,而 Splay 的旋转是为了让树不会退化以至于复杂度过高。一般而言,我们会在对某个结点进行操作后把它旋转到根结点。
至于为什么这样进行旋转能够让树不退化,证明过于复杂,在此不再赘述。
Splay的双旋
考虑如何把一个结点旋转到根,一个最简单的想法就是每次把它向它的父亲结点旋转,直到它成为根结点。这就是单旋。
但是单旋有一个严重的问题,比如下图中我们想把结点 X 旋转到根。
注意到,在原 Splay 中,有 Z -> Y -> X -> b… 这样一条链,而旋转过后,仍有 Z -> Y -> X -> b… 这样的链,长度和原来相同,这就意味着 Splay 可能会退化。
对于这个问题,我们有一个解决的方法——双旋。当我们要旋转的结点 X 和它的父亲是同一种儿子(都为左儿子或都为右儿子)时,先旋转它的父亲结点,再旋转结点 X ,如下图所示。
原本结点 X 的两个儿子 a 和 b ,旋转后分别是 X 和 Y 的儿子,深度至少减少了 1 。
Splay 解决区间问题
考虑如下的问题,有一个序列,我们可以在某个数之后插入一个数,或者询问某个区间的数之和。
由于需要执行插入数的操作,所以我们需要使用平衡树来完成这一题。
假设我们需要查询区间 [l, r] 之间的数之和。为了方便叙述,先假设l != 1, r != n。
我们先查询第 l - 1 个数的编号 x 和第 r + 1 个数的编号 y 。然后将结点 x 旋转到根结点,将结点 y 旋转到结点 x 的右儿子处,那么 y 的左子树就是所有区间 [l, r] 之间的数。
我们只需要对于每个结点维护子树权值和即可。
Splay的实现
#include <cstdio>
int n, ty, pos, l, r, val;
long long ans;
struct node {
node* son[2];
node* fa;
int val, size;
long long sum;
} * root;
void update(node* x) {
x->sum = x->val, x->size = 1;
if (x->son[0]) x->sum += x->son[0]->sum, x->size += x->son[0]->size;
if (x->son[1]) x->sum += x->son[1]->sum, x->size += x->son[1]->size;
}
void rotate(node* x, int dir) {
node* y = x->fa;
x->fa = y->fa;
if (y->fa != NULL) {
if (y == y->fa->son[0])
y->fa->son[0] = x;
else
y->fa->son[1] = x;
}
if (x->son[1 - dir] != NULL) x->son[1 - dir]->fa = y;
y->son[dir] = x->son[1 - dir];
x->son[1 - dir] = y;
y->fa = x;
update(y), update(x);
}
void splay(node* x, node* fa) {
node* y;
while (x->fa != fa) {
y = x->fa;
if (y->fa == fa) {
if (x == y->son[0])
rotate(x, 0);
else
rotate(x, 1);
} else {
if (x == y->son[0])
if (y == y->fa->son[0])
rotate(y, 0), rotate(x, 0);
else
rotate(x, 0), rotate(x, 1);
else if (y == y->fa->son[0])
rotate(x, 1), rotate(x, 0);
else
rotate(y, 1), rotate(x, 1);
}
}
}
node* kth(node* x, int k) {
if (x->son[0] != NULL) {
if (k <= x->son[0]->size)
return kth(x->son[0], k);
else
k -= x->son[0]->size;
}
if (k == 1)
return x;
else
k--;
return kth(x->son[1], k);
}
node* insert(node* x, int pos, int val) {
if (x->son[0] != NULL) {
if (pos <= x->son[0]->size) {
node* y = insert(x->son[0], pos, val);
update(x);
return y;
} else
pos -= x->son[0]->size;
}
if (pos == 0) {
node* y = new node;
y->sum = y->val = val;
y->size = 1;
y->fa = x;
x->son[0] = y;
update(x);
return y;
}
if (x->son[1] != NULL) {
node* y = insert(x->son[1], pos - 1, val);
update(x);
return y;
} else {
node* y = new node;
y->sum = y->val = val;
y->size = 1;
y->fa = x;
x->son[1] = y;
update(x);
return y;
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &ty);
if (ty == 0) {
scanf("%d%d", &pos, &val);
if (root == NULL) {
root = new node;
root->fa = root->son[0] = root->son[1] = NULL;
root->sum = root->val = val;
root->size = 1;
} else {
node* x = insert(root, pos, val);
splay(x, NULL);
root = x;
}
} else {
scanf("%d%d", &l, &r);
if (l == 1) {
if (r == root->size)
ans = root->sum;
else {
node* y = kth(root, r + 1);
splay(y, NULL);
root = y;
ans = y->son[0]->sum;
}
} else {
node* x = kth(root, l - 1);
splay(x, NULL);
root = x;
if (r == root->size)
ans = x->son[1]->sum;
else {
node* y = kth(root, r + 1);
splay(y, x);
ans = y->son[0]->sum;
}
}
printf("%lld\n", ans);
}
}
return 0;
}
Treap 和 Splay 的延迟标记
我们把之前的问题再增加难度,变成如下问题。
给定一个初始为空的序列,有三种操作
- 在第 pos 个数之后插入一个数 val。
- 查询区间 [l, r] 之间的数的和。
- 把区间 [l, r] 上的数都增加 。
我们仍然考虑用平衡树来解决这个问题。与之前的问题区别在于增加了“区间加”这一操作。若果我们对每个点单独进行修改,那么时间复杂度到达了O(n log n),这和建树的复杂度一样,显然不是我们想要的结果。
延迟标记
区间修改和区间查询的思想类似。
区间 [l, r] 可以划分为平衡树上的某些结点,显然这些结点在平衡树的中序遍历上是连续的。
如果在修改时对于每个结点都进行修改,实际上区间修改的复杂度仍然和建树的复杂度一样。所以我们利用一种特殊的方法来解决——延迟标记。
延迟标记又称为 lazy-tag ,即对于一个子树,我们在根结点打上一个标记,比如”这个区间的所有结点权值都要加上 val,而不是立即递归修改所有结点。同时,因为这棵子树中每个数都加上 val,所以这个子树中所有数的和增加了子树中结点个数个val,这样我们可以保证所有子树根结点的信息是正确的。
同样,在子树向上传递信息的过程中,也可以保证它祖先结点信息的正确性。不过,这棵子树的非根结点的信息并不能保证正确性。
对于 Treap ,我们需要递归修改,如果某个子树都要修改,那么直接在根结点打上延迟标记。而对于 Splay ,我们可以用旋转把区间放在一个子树内,这样只打一个延迟标记即可。
向下传递标记
虽然子树非根结点的信息不能保证正确性,但是我们暂时不需要用到这些结点。
我们只需要保证用到这些结点时它们时正确的。
我们在平衡树中递归时,如果遇到一个打过延迟标记的结点,我们需要将这个标记更新到它的子结点中,并且清空这个标记,这样就能保证经过的结点信息全部正确。注意,结点的延迟标记是可以累加的。
延迟标记的思想就是,某个结点需要用到的时候再去维护其正确性。向上更新(update)是利用子结点的信息更新父结点,我们需要保证子结点信息的正确性,而向下更新(pushdown)是利用父结点的信息去更新子结点的信息,所以需要保证父结点的信息信息的正确性。
这样,和区间查询一样,一次区间修改的时间复杂度也是 O(log n) 的。
注意,在旋转时,需要先把相关的结点的标记向下传递完。
延迟标记的实现
#include <cstdio>
int n, ty, pos, l, r, val;
long long ans;
struct node {
node* son[2];
node* fa;
int val, size, tag;
long long sum;
} * root;
void update(node* x) {
x->sum = x->val, x->size = 1;
if (x->son[0]) x->sum += x->son[0]->sum, x->size += x->son[0]->size;
if (x->son[1]) x->sum += x->son[1]->sum, x->size += x->son[1]->size;
}
void push_down(node* x) {
if (x->tag != 0) {
if (x->son[0]) {
x->son[0]->sum += 1LL * x->tag * x->son[0]->size;
x->son[0]->val += x->tag;
x->son[0]->tag += x->tag;
}
if (x->son[1]) {
x->son[1]->sum += 1LL * x->tag * x->son[1]->size;
x->son[1]->val += x->tag;
x->son[1]->tag += x->tag;
}
x->tag = 0;
}
}
void rotate(node* x, int dir) {
node* y = x->fa;
x->fa = y->fa;
if (y->fa != NULL) {
if (y == y->fa->son[0])
y->fa->son[0] = x;
else
y->fa->son[1] = x;
}
if (x->son[1 - dir] != NULL) x->son[1 - dir]->fa = y;
y->son[dir] = x->son[1 - dir];
x->son[1 - dir] = y;
y->fa = x;
update(y), update(x);
}
void splay(node* x, node* fa) {
static node* a[1000005];
static int cnt;
cnt = 0;
node* y = x;
while (y != NULL) a[++cnt] = y, y = y->fa;
while (cnt) push_down(a[cnt]), cnt--;
while (x->fa != fa) {
y = x->fa;
if (y->fa == fa) {
if (x == y->son[0])
rotate(x, 0);
else
rotate(x, 1);
} else {
if (x == y->son[0])
if (y == y->fa->son[0])
rotate(y, 0), rotate(x, 0);
else
rotate(x, 0), rotate(x, 1);
else if (y == y->fa->son[0])
rotate(x, 1), rotate(x, 0);
else
rotate(y, 1), rotate(x, 1);
}
}
}
node* kth(node* x, int k) {
if (x->son[0] != NULL) {
if (k <= x->son[0]->size)
return kth(x->son[0], k);
else
k -= x->son[0]->size;
}
if (k == 1)
return x;
else
k--;
return kth(x->son[1], k);
}
node* insert(node* x, int pos, int val) {
if (x->son[0] != NULL) {
if (pos <= x->son[0]->size) {
node* y = insert(x->son[0], pos, val);
update(x);
return y;
} else
pos -= x->son[0]->size;
}
if (pos == 0) {
node* y = new node;
y->sum = y->val = val;
y->size = 1;
y->fa = x;
x->son[0] = y;
update(x);
return y;
}
if (x->son[1] != NULL) {
node* y = insert(x->son[1], pos - 1, val);
update(x);
return y;
} else {
node* y = new node;
y->sum = y->val = val;
y->size = 1;
y->fa = x;
x->son[1] = y;
update(x);
return y;
}
}
node* getinterval(int l, int r) {
if (l == 1) {
if (r == root->size)
return root;
else {
node* y = kth(root, r + 1);
splay(y, NULL);
root = y;
return y->son[0];
}
} else {
node* x = kth(root, l - 1);
splay(x, NULL);
root = x;
if (r == root->size)
return x->son[1];
else {
node* y = kth(root, r + 1);
splay(y, x);
return y->son[0];
}
}
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%d", &ty);
if (ty == 0) {
scanf("%d%d", &pos, &val);
if (root == NULL) {
root = new node;
root->sum = root->val = val;
root->size = 1;
} else {
node* x = insert(root, pos, val);
splay(x, NULL);
root = x;
}
} else if (ty == 1) {
scanf("%d%d", &l, &r);
node* x = getinterval(l, r);
printf("%lld\n", x->sum);
} else if (ty == 2) {
scanf("%d%d%d", &l, &r, &val);
node* x = getinterval(l, r);
x->sum += 1LL * val * x->size;
x->val += val;
x->tag += val;
}
}
return 0;
}
平衡树的讲解就到这里,希望大家能有所收获,再见。