线段树套平衡树是什么脑残东西,复杂度就是假的, 让人感觉非常不靠谱。所以我们为什么不用更好写的树状数组代替线段树,更好写的主席树(权值线段树)代替平衡树呢?而且,不仅是好写,复杂度也是很对的 啊。
我们来简单理解一下树套树是什么:
思想其实很简单,树状数组的每个节点都是一颗权值线段树,并且每颗权值线段树维护的信息都是树状数组式的累加,这样每次查询和修改都只需要对logn个线段树进行操作。
对比一下普通的树状数组和树套树:
额,,,差不多就是这样了,确实很简单吧。
下面这个代码是二逼平衡树的板子,除了修改和查询k大还有求rank和前驱后继(所以我在讲平衡树的时候说过维护权值的事线段树也能做嘛)
code
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
const int maxn = 5e4 + 7;
const int inf = 0x7fffffff;
using namespace std;
int n, m;
int a[maxn];
int val[maxn << 1], tot;
int opt[maxn];
int qa[maxn];
int qb[maxn];
int qc[maxn];
struct node {
int sum;
int l, r;
} st[maxn * 400];
int cnt;
int root[maxn];
int xx[20], cnt1;
int yy[20], cnt2;
inline int read()
{
int X = 0; char ch = getchar();
while (ch < '0' || ch > '9') ch = getchar();
while (ch >= '0' && ch <= '9') X = X * 10 + ch - '0', ch = getchar();
return X;
}
inline int lowbit(int x)
{
return x & -x;
}
void update(int num, int &rt, int l, int r, int x)
{
st[++cnt] = st[rt];
rt = cnt;
st[rt].sum += x;
if (l == r) return;
int mid = l + r >> 1;
if (num <= mid) update(num, st[rt].l, l, mid, x);
else update(num, st[rt].r, mid + 1, r, x);
}
int get_rnk(int l, int r, int x)
{
if (l == r) return 0;
int mid = l + r >> 1;
if (x <= mid) {
for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].l;
for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].l;
return get_rnk(l, mid, x);
}
int d = 0;
for (int i = 1; i <= cnt1; i++) d -= st[st[xx[i]].l].sum;
for (int i = 1; i <= cnt2; i++) d += st[st[yy[i]].l].sum;
for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].r;
for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].r;
return get_rnk(mid + 1, r, x) + d;
}
int get_kth(int l, int r, int k)
{
if (l == r) return val[l];
int d = 0, mid = l + r >> 1;
for (int i = 1; i <= cnt1; i++) d -= st[st[xx[i]].l].sum;
for (int i = 1; i <= cnt2; i++) d += st[st[yy[i]].l].sum;
if (k <= d) {
for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].l;
for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].l;
return get_kth(l, mid, k);
}
for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].r;
for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].r;
return get_kth(mid + 1, r, k - d);
}
inline void add(int i, int x)
{
int k = lower_bound(val + 1, val + tot + 1, a[i]) - val;
for (; i <= n; i += lowbit(i)) update(k, root[i], 1, tot, x);
}
inline void init_query(int i)
{
cnt1 = cnt2 = 0;
for (int j = qa[i] - 1; j; j -= lowbit(j)) xx[++cnt1] = root[j];
for (int j = qb[i]; j; j -= lowbit(j)) yy[++cnt2] = root[j];
}
int main(void)
{
cin >> n >> m;
tot = n;
for (int i = 1; i <= n; i++) a[i] = val[i] = read();
for (int i = 1; i <= m; i++) {
opt[i] = read();
qa[i] = read();
qb[i] = read();
if (opt[i] != 3) {
qc[i] = read();
if (opt[i] != 2) val[++tot] = qc[i];
}
else val[++tot] = qb[i];
}
sort(val + 1, val + tot + 1);
tot = unique(val + 1, val + tot + 1) - val - 1;
st[0] = {0, 0, 0};
for (int i = 1; i <= n; i++) add(i, 1);
for (int i = 1; i <= m; i++) {
if (opt[i] != 3) init_query(i);
if (opt[i] != 2 && opt[i] != 3) qc[i] = lower_bound(val + 1, val + tot + 1, qc[i]) - val;
if (opt[i] == 1) printf("%d\n", get_rnk(1, tot, qc[i]) + 1);
else if (opt[i] == 2) printf("%d\n", get_kth(1, tot, qc[i]));
else if (opt[i] == 3) {
add(qa[i], -1);
a[qa[i]] = qb[i];
add(qa[i], 1);
}
else if (opt[i] == 4) {
int k = get_rnk(1, tot, qc[i]);
if (!k) printf("%d\n", -inf);
else {
init_query(i);
printf("%d\n", get_kth(1, tot, k));
}
}
else {
int k = get_rnk(1, tot, qc[i] + 1);
if (k > qb[i] - qa[i]) printf("%d\n", inf);
else {
init_query(i);
printf("%d\n", get_kth(1, tot, k + 1));
}
}
}
return 0;
}