// 任重而道远 ("The burden is heavy and the road is long.")
#include<cstdio>
using namespace std;
// Capacity of the node pool and of the input array.
// NOTE(review): 1e5 + 5 is an assumed limit — adjust to the problem's constraints.
const int N = 100005;

// Input values read by build().
// NOTE(review): element type assumed int to match Node::val — confirm against the caller.
int a[N];

// Splay tree node: s[0]/s[1] are the left/right children, fa is the parent,
// val is the stored value and siz is the number of nodes in this subtree.
// Absent children/parents are meant to point at the `zero` sentinel (siz == 0)
// rather than NULL, so that update() and find() can read s[i]->siz unconditionally.
struct Node {
    Node *s[2], *fa;
    int val, siz;
    // Recomputes siz from the children; the +1 counts this node itself.
    void update() {
        siz = s[0]->siz + s[1]->siz + 1;
    }
} pool[N + 1],       // pool[0] is the sentinel; the remaining slots are handed out by newnode()
  *tail = pool + 1,  // next unused pool slot
  *root = pool,      // current tree root; the sentinel means "empty tree"
  *zero = pool;      // sentinel node: static storage zero-initializes it, so siz == 0
// Reports which child slot p occupies under its parent:
// 1 when p is the parent's right child, 0 otherwise.
int get (Node* p) {
    Node* parent = p->fa;
    return parent->s[1] == p ? 1 : 0;
}
// Takes the next unused slot from the node pool (advancing `tail`) and
// initializes it as a leaf with parent `fa` and value `val`.
// Children point at the `zero` sentinel rather than NULL, because update()
// and find() read s[i]->siz without a null check.
Node* newnode (Node* fa, int val) {
    Node* nd = tail++;
    nd->fa = fa;
    nd->val = val;
    nd->s[0] = nd->s[1] = zero;
    nd->siz = 1;
    return nd;
}
// Builds a height-balanced subtree holding a[lf..rg] (in-order) and attaches
// `rt` as its parent. Returns the subtree root, or the `zero` sentinel for an
// empty range.
Node* build (int lf, int rg, Node* rt) {
    if (lf > rg) return zero;
    int mid = (lf + rg) >> 1;
    Node* nd = newnode(rt, a[mid]);
    // Left half excludes mid; the original `build(lf, mid, ...)` recursed
    // forever whenever lf == mid.
    nd->s[0] = build(lf, mid - 1, nd);
    nd->s[1] = build(mid + 1, rg, nd);
    nd->update();
    return nd;
}
// Single rotation: lifts rt one level above its parent f, preserving the
// in-order sequence. Does nothing when rt is already the root.
// When f was the root (its parent is the `zero` sentinel), rt becomes the new root.
void rotate (Node* rt) {
    if (rt == root) return;
    int opt = get(rt);
    Node *f = rt->fa, *of = f->fa;
    // rt's inner child moves across to f, taking rt's old slot under f.
    f->s[opt] = rt->s[opt ^ 1];
    if (rt->s[opt ^ 1] != zero) rt->s[opt ^ 1]->fa = f;  // never write into the sentinel
    rt->s[opt ^ 1] = f;
    rt->fa = of;
    if (of != zero) of->s[get(f)] = rt;  // get(f) still sees f->fa == of here
    else root = rt;                      // f was the root, so rt replaces it
    f->fa = rt;
    // f is now below rt, so refresh f first.
    f->update();
    rt->update();
}
// Rotates rt upward until its parent is f. find() passes the `zero`
// sentinel as f to bring rt all the way up.
void splay (Node* rt, Node* f) {
    while (rt->fa != f) {
        Node* parent = rt->fa;
        if (parent->fa != f) {
            // Same-side chain (zig-zig): rotate the parent first.
            // Opposite sides (zig-zag): rotate rt itself first.
            if (get(rt) == get(parent)) rotate(parent);
            else rotate(rt);
        }
        rotate(rt);
    }
}
// Locates the k-th node (1-based, in-order) in the subtree rooted at rt,
// splays it up via splay(node, zero), and returns it.
// Returns the `zero` sentinel when k is out of range instead of
// dereferencing the sentinel's (null) children.
// Iterative so that a deep, degenerate splay tree cannot overflow the stack.
Node* find (Node* rt, int k) {
    while (rt != zero) {
        int left = rt->s[0]->siz;
        if (k <= left) {
            rt = rt->s[0];
        } else if (k == left + 1) {
            splay(rt, zero);
            return rt;
        } else {
            k -= left + 1;
            rt = rt->s[1];
        }
    }
    return zero;
}