Introduction
Splay 在维护数集方面效率并不算突出,但可以高效维护数列并进行数列的一些操作,甚至是线段树所做不了的(比如区间翻转)。
所以掌握 Splay 还是很重要的,而且理解了就不难了。
Range Selection
Splay 的最经典的应用莫过于 区间翻转 了。
很显然,数列原顺序是给定的,说明我们的 Splay 应该是 \(L \text{的编号} <X \text{的编号} <R \text{的编号}\),所以应该采用 建树 的方式,而不是一个个插入。时间复杂度显然 \(\Theta(n)\) 。
那么区间翻转怎么做呢?我们应该先找出这一段区间。
怎么找?这就是许多区间操作的共同点了。
假设现在要翻转的区间为 \([l,r]\)。
我们先用一个查找函数找出数列中位置为 \(l-1,r+1\) 的结点编号,记作 \(x,y\)。
然后将 \(x\) 上旋(splay操作)至根,将 \(y\) 上旋至 \(x\) 的右儿子处。
那么这个神奇的 Splay 成了这个样子:
然后就是在中间这个子树的根上 do sth.
由于操作区间都会向前/后延申一个位置,所以我们一般在最前、最后面插入 \(-\infty\) 作为虚拟结点。
inline void f(int l,int r) {
l = select(l), r = select(r + 2);
/*位置 1 和 n + 2 都是虚拟结点*/
splay(l, 0), splay(r, l);
/*do sth.*/
}
Modification
对于区间的修改操作,如区间翻转,通常不直接做,而是像线段树一样打标记。
那么我们也要像线段树一样有一个标记下传,那么什么时候下传呢?在 select
函数中边找边下传即可。
inline int select(int k) {
for (register int rt = root; rt;) {
pushdown(rt); /*<- 在这里下传*/
if (k == size[ch[rt][0]] + 1) return rt;
if (k < size[ch[rt][0]] + 1) rt = ch[rt][0];
else k -= size[ch[rt][0]] + 1, rt = ch[rt][1];
}
}
Query
这个非常简单,直接先选取区间,然后在中间的子树上获取信息即可。
Problem: 「Luogu P2042」 [NOI2005]维护数列
Code Sample
#include <iostream>
#include <algorithm>
#include <string>
#include <stack>
using namespace std;
const int N = 5e5 + 5;
const long long inf = 1e18;
typedef long long LL;
int ch[N][2], fa[N], size[N];
LL val[N], sum[N], lmax[N], rmax[N], mmax[N];
bool cov[N], rev[N];
LL ary[N];
int n, q, total = 0, root;
stack<int> rec;
inline int create() {
if (rec.empty()) return ++total;
int ret = rec.top();
return rec.pop(), ret;
}
inline void clear(int rt) {
ch[rt][0] = ch[rt][1] = fa[rt] = 0;
val[rt] = lmax[rt] = rmax[rt] = 0;
mmax[rt] = -inf;
cov[rt] = rev[rt] = 0;
}
inline void maintain(int rt) {
size[rt] = size[ch[rt][0]] + size[ch[rt][1]] + 1;
sum[rt] = sum[ch[rt][0]] + sum[ch[rt][1]] + val[rt];
lmax[rt] = max(lmax[ch[rt][0]], lmax[ch[rt][1]] + sum[ch[rt][0]] + val[rt]);
rmax[rt] = max(rmax[ch[rt][1]], rmax[ch[rt][0]] + sum[ch[rt][1]] + val[rt]);
mmax[rt] = max(max(mmax[ch[rt][0]], mmax[ch[rt][1]]), lmax[ch[rt][1]] + rmax[ch[rt][0]] + val[rt]);
}
inline void pushdown(int rt) {
int &l = ch[rt][0], &r = ch[rt][1];
if (cov[rt]) {
rev[rt] = cov[rt] = 0;
if (l) cov[l] = 1, val[l] = val[rt], sum[l] = val[rt] * size[l];
if (r) cov[r] = 1, val[r] = val[rt], sum[r] = val[rt] * size[r];
if (val[rt] >= 0) {
if (l) lmax[l] = rmax[l] = mmax[l] = sum[l];
if (r) lmax[r] = rmax[r] = mmax[r] = sum[r];
} else {
if (l) lmax[l] = rmax[l] = 0, mmax[l] = val[l];
if (r) lmax[r] = rmax[r] = 0, mmax[r] = val[r];
}
}
if (rev[rt]) {
rev[rt] = 0;
if (l) rev[l] ^= 1;
if (r) rev[r] ^= 1;
swap(lmax[l], rmax[l]);
swap(lmax[r], rmax[r]);
swap(ch[l][0], ch[l][1]);
swap(ch[r][0], ch[r][1]);
}
}
inline int get(int rt) {
return rt == ch[fa[rt]][1];
}
inline void rotate(int x) {
int y = fa[x], z = fa[y], c = get(x);
ch[y][c] = ch[x][c ^ 1];
fa[ch[x][c ^ 1]] = y;
ch[x][c ^ 1] = y;
fa[y] = x, fa[x] = z;
if (z) ch[z][y == ch[z][1]] = x;
maintain(y), maintain(x);
}
inline void splay(int x,int g) {
for (register int f = fa[x]; (f = fa[x]) != g; rotate(x))
if(fa[f] != g) rotate(get(f) == get(x) ? f : x);
if (!g) root = x;
}
inline int select(int k) {
for (register int rt = root; rt;) {
pushdown(rt);
if (k == size[ch[rt][0]] + 1) return rt;
if (k < size[ch[rt][0]] + 1) rt = ch[rt][0];
else k -= size[ch[rt][0]] + 1, rt = ch[rt][1];
}
throw;
}
int build(int l, int r, int f) {
if (l>r) return 0;
int rt = create(), mid = (l + r) >> 1;
if (l == r) {
val[rt] = sum[rt] = ary[mid];
size[rt] = 1;
ch[rt][0] = ch[rt][1] = 0;
fa[rt] = f;
lmax[rt] = rmax[rt] = max(0ll, val[rt]);
mmax[rt] = val[rt];
return rt;
}
val[rt] = ary[mid], fa[rt] = f;
ch[rt][0] = build(l, mid - 1, rt);
ch[rt][1] = build(mid + 1, r, rt);
return maintain(rt), rt;
}
inline void insert(int pos, int tot) {
if (!tot) return;
int l = select(pos + 1), r = select(pos + 2);
splay(l, 0), splay(r, l);
ch[r][0] = build(1, tot, r);
maintain(r), maintain(l);
}
void recycle(int rt) {
if (!rt) return;
rec.push(rt);
recycle(ch[rt][0]);
recycle(ch[rt][1]);
clear(rt);
}
inline void erase(int pos, int tot) {
if (!tot) return;
int l = select(pos), r = select(pos + tot + 1);
splay(l, 0), splay(r, l);
recycle(ch[r][0]), ch[r][0] = 0;
maintain(r), maintain(l);
}
inline void assign(int pos, int tot, LL c) {
if (!tot) return;
int l = select(pos), r = select(pos + tot + 1);
splay(l, 0), splay(r, l);
int x = ch[r][0];
val[x] = c, sum[x] = size[x] * c;
if (c >= 0) lmax[x] = rmax[x] = mmax[x] = sum[x];
else lmax[x] = rmax[x] = 0, mmax[x] = val[x];
cov[x] = 1;
maintain(r), maintain(l);
}
inline void reverse(int pos, int tot) {
if (!tot) return;
int l = select(pos), r = select(pos + tot + 1);
splay(l, 0), splay(r, l);
int x = ch[r][0];
if (cov[x]) return;
swap(ch[x][0], ch[x][1]);
swap(lmax[x], rmax[x]);
rev[x] ^= 1;
maintain(r), maintain(l);
}
inline LL getSum(int pos, int tot) {
if (!tot) return 0ll;
int l = select(pos), r = select(pos + tot + 1);
splay(l, 0), splay(r, l);
return sum[ch[r][0]];
}
inline LL getMaxSum() {
int l = select(1), r = select(size[root]);
splay(l, 0), splay(r, l);
return mmax[ch[r][0]];
}
void out(int rt) {
if (!rt) return;
pushdown(rt);
out(ch[rt][0]);
cout << val[rt] << ' ';
out(ch[rt][1]);
}
signed main() {
ios::sync_with_stdio(0);
cin >> n >> q;
for (register int i = 1; i <= n; i++)
cin >> ary[i + 1];
ary[1] = ary[n + 2] = -inf;
clear(0);
root = build(1, n + 2, 0);
while (q--) {
string cmd;
int pos, tot;
LL c;
cin >> cmd;
if (cmd == "INSERT") {
cin >> pos >> tot;
for (register int i = 1; i <= tot; i++)
cin >> ary[i];
insert(pos, tot);
}
if (cmd == "DELETE") {
cin >> pos >> tot;
erase(pos, tot);
}
if (cmd == "MAKE-SAME") {
cin >> pos >> tot >> c;
assign(pos, tot, c);
}
if (cmd == "REVERSE") {
cin >> pos >> tot;
reverse(pos, tot);
}
if (cmd == "GET-SUM") {
cin >> pos >> tot;
cout << getSum(pos, tot) << endl;
}
if (cmd == "MAX-SUM")
cout << getMaxSum() << endl;
}
}