【学习笔记】平衡树(2)

Splay解决区间问题
Splay是什么?
和 Treap 一样, Splay 也是一种平衡树,不同的是, Splay 没有另外记录随机生成的权值,而只是记录了结点原本的值。
Treap 的旋转是为了保持堆的性质,而 Splay 的旋转是为了让树不会退化以至于复杂度过高。一般而言,我们会在对某个结点进行操作后把它旋转到根结点。
至于为什么这样进行旋转能够让树不退化,证明过于复杂,在此不再赘述。
Splay的双旋
考虑如何把一个结点旋转到根,一个最简单的想法就是每次把它向它的父亲结点旋转,直到它成为根结点。这就是单旋。
但是单旋有一个严重的问题,比如下图中我们想把结点 X 旋转到根。
在这里插入图片描述

注意到,在原 Splay 中,有 Z -> Y -> X -> b… 这样一条链,而旋转过后,仍有 Z -> Y -> X -> b… 这样的链,长度和原来相同,这就意味着 Splay 可能会退化。
对于这个问题,我们有一个解决的方法——双旋。当我们要旋转的结点 X 和它的父亲是同一种儿子(都为左儿子或都为右儿子)时,先旋转它的父亲结点,再旋转结点 X ,如下图所示。
在这里插入图片描述

原本结点 X 的两个儿子 a 和 b ,旋转后分别是 X 和 Y 的儿子,深度至少减少了 1 。
Splay 解决区间问题
考虑如下的问题,有一个序列,我们可以在某个数之后插入一个数,或者询问某个区间的数之和。
由于需要执行插入数的操作,所以我们需要使用平衡树来完成这一题。
假设我们需要查询区间 [l, r] 之间的数之和。为了方便叙述,先假设l != 1, r != n。
我们先查询第 l - 1 个数的编号 x 和第 r + 1 个数的编号 y 。然后将结点 x 旋转到根结点,将结点 y 旋转到结点 x 的右儿子处,那么 y 的左子树就是所有区间 [l, r] 之间的数。
在这里插入图片描述
我们只需要对于每个结点维护子树权值和即可。
Splay的实现

#include <cstdio>
int n, ty, pos, l, r, val;
long long ans;
struct node {
    
    
    node* son[2];
    node* fa;
    int val, size;
    long long sum;
} * root;
void update(node* x) {
    
    
    x->sum = x->val, x->size = 1;
    if (x->son[0]) x->sum += x->son[0]->sum, x->size += x->son[0]->size;
    if (x->son[1]) x->sum += x->son[1]->sum, x->size += x->son[1]->size;
}
void rotate(node* x, int dir) {
    
    
    node* y = x->fa;
    x->fa = y->fa;
    if (y->fa != NULL) {
    
    
        if (y == y->fa->son[0])
            y->fa->son[0] = x;
        else
            y->fa->son[1] = x;
    }
    if (x->son[1 - dir] != NULL) x->son[1 - dir]->fa = y;
    y->son[dir] = x->son[1 - dir];
    x->son[1 - dir] = y;
    y->fa = x;
    update(y), update(x);
}
void splay(node* x, node* fa) {
    
    
    node* y;
    while (x->fa != fa) {
    
    
        y = x->fa;
        if (y->fa == fa) {
    
    
            if (x == y->son[0])
                rotate(x, 0);
            else
                rotate(x, 1);
        } else {
    
    
            if (x == y->son[0])
                if (y == y->fa->son[0])
                    rotate(y, 0), rotate(x, 0);
                else
                    rotate(x, 0), rotate(x, 1);
            else if (y == y->fa->son[0])
                rotate(x, 1), rotate(x, 0);
            else
                rotate(y, 1), rotate(x, 1);
        }
    }
}
node* kth(node* x, int k) {
    
    
    if (x->son[0] != NULL) {
    
    
        if (k <= x->son[0]->size)
            return kth(x->son[0], k);
        else
            k -= x->son[0]->size;
    }
    if (k == 1)
        return x;
    else
        k--;
    return kth(x->son[1], k);
}
node* insert(node* x, int pos, int val) {
    
    
    if (x->son[0] != NULL) {
    
    
        if (pos <= x->son[0]->size) {
    
    
            node* y = insert(x->son[0], pos, val);
            update(x);
            return y;
        } else
            pos -= x->son[0]->size;
    }
    if (pos == 0) {
    
    
        node* y = new node;
        y->sum = y->val = val;
        y->size = 1;
        y->fa = x;
        x->son[0] = y;
        update(x);
        return y;
    }
    if (x->son[1] != NULL) {
    
    
        node* y = insert(x->son[1], pos - 1, val);
        update(x);
        return y;
    } else {
    
    
        node* y = new node;
        y->sum = y->val = val;
        y->size = 1;
        y->fa = x;
        x->son[1] = y;
        update(x);
        return y;
    }
}
int main() {
    
    
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
    
    
        scanf("%d", &ty);
        if (ty == 0) {
    
    
            scanf("%d%d", &pos, &val);
            if (root == NULL) {
    
    
                root = new node;
                root->fa = root->son[0] = root->son[1] = NULL;
                root->sum = root->val = val;
                root->size = 1;
            } else {
    
    
                node* x = insert(root, pos, val);
                splay(x, NULL);
                root = x;
            }
        } else {
    
    
            scanf("%d%d", &l, &r);
            if (l == 1) {
    
    
                if (r == root->size)
                    ans = root->sum;
                else {
    
    
                    node* y = kth(root, r + 1);
                    splay(y, NULL);
                    root = y;
                    ans = y->son[0]->sum;
                }
            } else {
    
    
                node* x = kth(root, l - 1);
                splay(x, NULL);
                root = x;
                if (r == root->size)
                    ans = x->son[1]->sum;
                else {
    
    
                    node* y = kth(root, r + 1);
                    splay(y, x);
                    ans = y->son[0]->sum;
                }
            }
            printf("%lld\n", ans);
        }
    }
    return 0;
}

Treap 和 Splay 的延迟标记
我们把之前的问题再增加难度,变成如下问题。
给定一个初始为空的序列,有三种操作

  1. 在第 pos 个数之后插入一个数 val。
  2. 查询区间 [l, r] 之间的数的和。
  3. 把区间 [l, r] 上的数都增加 。
    我们仍然考虑用平衡树来解决这个问题。与之前的问题区别在于增加了“区间加”这一操作。若果我们对每个点单独进行修改,那么时间复杂度到达了O(n log n),这和建树的复杂度一样,显然不是我们想要的结果。
    延迟标记
    区间修改和区间查询的思想类似。
    区间 [l, r] 可以划分为平衡树上的某些结点,显然这些结点在平衡树的中序遍历上是连续的。
    如果在修改时对于每个结点都进行修改,实际上区间修改的复杂度仍然和建树的复杂度一样。所以我们利用一种特殊的方法来解决——延迟标记。
    延迟标记又称为 lazy-tag ,即对于一个子树,我们在根结点打上一个标记,比如”这个区间的所有结点权值都要加上 val,而不是立即递归修改所有结点。同时,因为这棵子树中每个数都加上 val,所以这个子树中所有数的和增加了子树中结点个数个val,这样我们可以保证所有子树根结点的信息是正确的。
    同样,在子树向上传递信息的过程中,也可以保证它祖先结点信息的正确性。不过,这棵子树的非根结点的信息并不能保证正确性。
    对于 Treap ,我们需要递归修改,如果某个子树都要修改,那么直接在根结点打上延迟标记。而对于 Splay ,我们可以用旋转把区间放在一个子树内,这样只打一个延迟标记即可。
    向下传递标记
    虽然子树非根结点的信息不能保证正确性,但是我们暂时不需要用到这些结点。
    我们只需要保证用到这些结点时它们时正确的。
    我们在平衡树中递归时,如果遇到一个打过延迟标记的结点,我们需要将这个标记更新到它的子结点中,并且清空这个标记,这样就能保证经过的结点信息全部正确。注意,结点的延迟标记是可以累加的。
    延迟标记的思想就是,某个结点需要用到的时候再去维护其正确性。向上更新(update)是利用子结点的信息更新父结点,我们需要保证子结点信息的正确性,而向下更新(pushdown)是利用父结点的信息去更新子结点的信息,所以需要保证父结点的信息信息的正确性。
    这样,和区间查询一样,一次区间修改的时间复杂度也是 O(log n) 的。
    注意,在旋转时,需要先把相关的结点的标记向下传递完。
    延迟标记的实现
#include <cstdio>
int n, ty, pos, l, r, val;
long long ans;
struct node {
    
    
    node* son[2];
    node* fa;
    int val, size, tag;
    long long sum;
} * root;
void update(node* x) {
    
    
    x->sum = x->val, x->size = 1;
    if (x->son[0]) x->sum += x->son[0]->sum, x->size += x->son[0]->size;
    if (x->son[1]) x->sum += x->son[1]->sum, x->size += x->son[1]->size;
}
void push_down(node* x) {
    
    
    if (x->tag != 0) {
    
    
        if (x->son[0]) {
    
    
            x->son[0]->sum += 1LL * x->tag * x->son[0]->size;
            x->son[0]->val += x->tag;
            x->son[0]->tag += x->tag;
        }
        if (x->son[1]) {
    
    
            x->son[1]->sum += 1LL * x->tag * x->son[1]->size;
            x->son[1]->val += x->tag;
            x->son[1]->tag += x->tag;
        }
        x->tag = 0;
    }
}
void rotate(node* x, int dir) {
    
    
    node* y = x->fa;
    x->fa = y->fa;
    if (y->fa != NULL) {
    
    
        if (y == y->fa->son[0])
            y->fa->son[0] = x;
        else
            y->fa->son[1] = x;
    }
    if (x->son[1 - dir] != NULL) x->son[1 - dir]->fa = y;
    y->son[dir] = x->son[1 - dir];
    x->son[1 - dir] = y;
    y->fa = x;
    update(y), update(x);
}
void splay(node* x, node* fa) {
    
    
    static node* a[1000005];
    static int cnt;
    cnt = 0;
    node* y = x;
    while (y != NULL) a[++cnt] = y, y = y->fa;
    while (cnt) push_down(a[cnt]), cnt--;
    while (x->fa != fa) {
    
    
        y = x->fa;
        if (y->fa == fa) {
    
    
            if (x == y->son[0])
                rotate(x, 0);
            else
                rotate(x, 1);
        } else {
    
    
            if (x == y->son[0])
                if (y == y->fa->son[0])
                    rotate(y, 0), rotate(x, 0);
                else
                    rotate(x, 0), rotate(x, 1);
            else if (y == y->fa->son[0])
                rotate(x, 1), rotate(x, 0);
            else
                rotate(y, 1), rotate(x, 1);
        }
    }
}
node* kth(node* x, int k) {
    
    
    if (x->son[0] != NULL) {
    
    
        if (k <= x->son[0]->size)
            return kth(x->son[0], k);
        else
            k -= x->son[0]->size;
    }
    if (k == 1)
        return x;
    else
        k--;
    return kth(x->son[1], k);
}
node* insert(node* x, int pos, int val) {
    
    
    if (x->son[0] != NULL) {
    
    
        if (pos <= x->son[0]->size) {
    
    
            node* y = insert(x->son[0], pos, val);
            update(x);
            return y;
        } else
            pos -= x->son[0]->size;
    }
    if (pos == 0) {
    
    
        node* y = new node;
        y->sum = y->val = val;
        y->size = 1;
        y->fa = x;
        x->son[0] = y;
        update(x);
        return y;
    }
    if (x->son[1] != NULL) {
    
    
        node* y = insert(x->son[1], pos - 1, val);
        update(x);
        return y;
    } else {
    
    
        node* y = new node;
        y->sum = y->val = val;
        y->size = 1;
        y->fa = x;
        x->son[1] = y;
        update(x);
        return y;
    }
}
node* getinterval(int l, int r) {
    
    
    if (l == 1) {
    
    
        if (r == root->size)
            return root;
        else {
    
    
            node* y = kth(root, r + 1);
            splay(y, NULL);
            root = y;
            return y->son[0];
        }
    } else {
    
    
        node* x = kth(root, l - 1);
        splay(x, NULL);
        root = x;
        if (r == root->size)
            return x->son[1];
        else {
    
    
            node* y = kth(root, r + 1);
            splay(y, x);
            return y->son[0];
        }
    }
}
int main() {
    
    
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
    
    
        scanf("%d", &ty);
        if (ty == 0) {
    
    
            scanf("%d%d", &pos, &val);
            if (root == NULL) {
    
    
                root = new node;
                root->sum = root->val = val;
                root->size = 1;
            } else {
    
    
                node* x = insert(root, pos, val);
                splay(x, NULL);
                root = x;
            }
        } else if (ty == 1) {
    
    
            scanf("%d%d", &l, &r);
            node* x = getinterval(l, r);
            printf("%lld\n", x->sum);
        } else if (ty == 2) {
    
    
            scanf("%d%d%d", &l, &r, &val);
            node* x = getinterval(l, r);
            x->sum += 1LL * val * x->size;
            x->val += val;
            x->tag += val;
        }
    }
    return 0;
}

平衡树的讲解就到这里,希望大家能有所收获,再见。

猜你喜欢

转载自blog.csdn.net/yueyuedog/article/details/111176167