Splay和一般的BST的区别大概就是:每次插入一个元素后,把它旋转到根
其核心操作就是$splay$和$rotate$
Rotate:把$x$转到$x$的父亲的位置上
$rotate$很好理解,自己画个图玩玩就会了,代码要注意细节
void pushup(int x) { sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x]; } void rotate(int x) { int y=fa[x],z=fa[y],k=(x == ch[y][1]); if(z) ch[z][y == ch[z][1]]=x; fa[x]=z; ch[y][k]=ch[x][k^1], fa[ch[x][k^1]]=y; ch[x][k^1]=y, fa[y]=x; pushup(y); pushup(x); }
Splay:把$x$转到目标$goal$的儿子的位置上(若$goal=0$,则把$x$转到根节点)
分三种情况:
1.$x$与$x$的父亲所属的儿子的种类相同(例如$x$是$y=fa[x]$的左儿子,$fa[x]$是$z=fa[fa[x]]$的左儿子)
2.与上面一种相反
3.$fa[fa[x]]=goal$
对于第一种情况,先把$y$转到$z$,再把$x$转到$y$
对于第二种情况,把$x$一路转上去
对于第三种情况,转一次$x$即可
void splay(int x,int goal) { while(fa[x] != goal) { int y=fa[x],z=fa[y]; if(z != goal) (x == ch[y][1]) == (y == ch[z][1]) ? rotate(y) : rotate(x); rotate(x); } if(!goal) rt=x; }
除了以上操作,$Splay$还有一些附带操作,大都很好理解,但是$Delete$和$Find_kth$要注意一下
$Findkth$:找到当前第$k$大的数
细节很多,要想清楚
int Find_kth(int x) { int cur=rt; while(sz[ch[cur][0]] >= x || sz[ch[cur][0]]+cnt[cur] < x) if(sz[ch[cur][0]] >= x) cur=ch[cur][0]; else x-=sz[ch[cur][0]]+cnt[cur], cur=ch[cur][1]; return cur; }
$Delete$:删除某个数$x$(若有多个只删除一个)
先把$x$的前驱转到根节点,在把它的后继转到根节点的儿子上
这样要删的数$x$就只可能独立地挂在它的后继的左儿子上
然后直接删就好了
void Delete(int x) { int lst=Pre(x),nxt=Suf(x); splay(lst,0); splay(nxt,lst); int tar=ch[nxt][0]; if(cnt[tar]>1) --cnt[tar], splay(tar,0); else ch[nxt][0]=0, splay(nxt,0); //这里splay是为了更新size }
注意:为了防止边界上出现一些奇奇怪怪的错误,开局先在平衡树内插入$inf$和$-inf$
#include<bits/stdc++.h> using namespace std; const int N=110000,inf=2e9+7; int rt,node,fa[N],ch[N][2],cnt[N],sz[N],val[N]; void pushup(int x) { sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+cnt[x]; } void rotate(int x) { int y=fa[x],z=fa[y],k=(x == ch[y][1]); if(z) ch[z][y == ch[z][1]]=x; fa[x]=z; ch[y][k]=ch[x][k^1], fa[ch[x][k^1]]=y; ch[x][k^1]=y, fa[y]=x; pushup(y); pushup(x); } void splay(int x,int goal) { while(fa[x] != goal) { int y=fa[x],z=fa[y]; if(z != goal) (x == ch[y][1]) == (y == ch[z][1]) ? rotate(y) : rotate(x); rotate(x); } if(!goal) rt=x; } void Insert(int x) { if(!rt) { rt=++node,val[node]=x,++cnt[node],++sz[node]; return; } int cur=rt; while(233) { ++sz[cur]; if(val[cur] == x) {++cnt[cur]; return;} int z=ch[cur][val[cur]<x]; if(!z) { ch[cur][val[cur]<x]=z=++node,fa[z]=cur; val[z]=x,++cnt[z],++sz[z]; splay(z,0); return ; } cur=z; } } int Find(int x) { int cur=rt; while(val[cur] != x) cur=ch[cur][val[cur]<x]; return cur; } int Pre(int x) { int cur=rt,ret; while(cur) if(val[cur]<x) ret=cur, cur=ch[cur][1]; else cur=ch[cur][0]; return ret; } int Suf(int x) { int cur=rt,ret; while(cur) { if(val[cur]>x) ret=cur, cur=ch[cur][0]; else cur=ch[cur][1]; } return ret; } void Delete(int x) { int lst=Pre(x),nxt=Suf(x); splay(lst,0); splay(nxt,lst); int tar=ch[nxt][0]; if(cnt[tar]>1) --cnt[tar], splay(tar,0); else ch[nxt][0]=0, splay(nxt,0); } int Find_kth(int x) { int cur=rt; while(sz[ch[cur][0]] >= x || sz[ch[cur][0]]+cnt[cur] < x) if(sz[ch[cur][0]] >= x) cur=ch[cur][0]; else x-=sz[ch[cur][0]]+cnt[cur], cur=ch[cur][1]; return cur; } int main() { //freopen(" .in","r",stdin); freopen(" .out","w",stdout); int n,opt,x; Insert(inf); Insert(-inf); scanf("%d",&n); while(n--) { scanf("%d%d",&opt,&x); if(opt == 1) Insert(x); if(opt == 2) Delete(x); if(opt == 3) {x=Find(x); splay(x,0); printf("%d\n",sz[ch[x][0]]);} if(opt == 4) printf("%d\n",val[Find_kth(++x)]); if(opt == 5) printf("%d\n",val[Pre(x)]); if(opt == 6) printf("%d\n",val[Suf(x)]); } }