Splay
上一篇:平衡树学习笔记(2)-------Treap
Splay是一个实用而且灵活性很强的平衡树
效率上也比较客观,但是一定要一次性写对
debug可能不是那么容易
Splay作为平衡树,它的平衡方式就是旋转
暴力旋转,赤裸裸的旋转,各种旋转
就是依靠玄学的旋转来保证自己的复杂度
不废话,上主题
\(\color{#9900ff}{定义}\)
struct node
{
node *fa,*ch[2];
int val,num,siz;
node() {val=num=siz=0;}
inline void clr() {val=num=siz=0;}
inline bool isr() {return this==fa->ch[1];}
inline void upd() {siz=ch[0]->siz+ch[1]->siz+num;}
};
siz为子树大小,val为点权,num为点的个数(重复数字存在一个点上)
clr 清空节点,isr 判断是否为自己父亲的右孩子
\(\color{#9900ff}{基本操作}\)
1、rotate
其实这个就是第一节说的旋转
rot(x)代表把x转到它父亲的位置上去
这也是Splay维护平衡的基础
下面是重点了!!
把x转到它父亲y上
以下代码中字母对应,其中那个R是代码中的w(因为为中间量,要特殊对待)
inline void rot(nod x)
{
nod y=x->fa,z=y->fa;
//找到y,z(注意,x转上去后,z的孩子变成x,所以要涉及到z)
int k=x->isr(); nod w=x->ch[!k];
//isr是bool型的,看看是不是自己父亲的右孩子,这个旋转针对的是所有情况,不仅仅是上图的情况
if(y!=root) z->ch[y->isr()]=x;
else root=x;
//x转上去,就要考虑y是不是根的问题
//如果y是根,x转上去后,自然成为了根
//如果不是根,就要让x替换y的位置,原来y是z的哪个孩子,现在x就是z的哪个孩子
x->ch[!k]=y;
y->ch[k]=w;
//该认孩子的认孩子
w->fa=y,y->fa=x,x->fa=z;
//该认父亲的认父亲
y->upd(),x->upd();
//因为x在y的上一层,x的upd要基于y,所以y先来
}
以上部分一定要理解透彻!!!
2、Splay
这个操作使基于rotate的
Splay(x),作用是把x转到根节点的位置上
显然要转好多次的qwq
因为一些玄学的东西(雾
平衡树中,每次用到谁转谁(反正不影响性质,说白了貌似还是瞎转)
这样玄学的操作可以使Splay平衡
inline void splay(nod x)
{
while(x!=root)
{
if(x->fa!=root) rot(x->isr()^x->fa->isr()? x:x->fa);
rot(x);
}
}
上面if那一行是啥意思呢?
我们要考虑一条链的情况
这种情况我们要先转父亲,再转自己
否则直接转自己就行
至此,基本操作已经结束qwq
\(\color{#9900ff}{其它操作}\)
1、插入
这个是真的暴力插。。。。。。
inline void ins(int x)
{
if(root==null)
{
//空树则对根节点操作
root=newnode();
root->siz=root->num=1;
root->val=x;
return;
}
//从根开始暴力插♂
nod fa=null;
nod o=root;
while(1)
{
if(o->val==x)
{
//刚刚说重复的节点存在一起,这就是重复的情况
o->num++;
//玄学操作,转上去
splay(o);
return;
}
//一直往下跳(注意方向)
fa=o;
o=o->ch[x>o->val];
if(o==null)
{
//跳到了空节点上,那么申请新节点
fa->ch[x>fa->val]=o=newnode();
//千万不要忘记父子互认
o->fa=fa;
o->num=o->siz=1;
o->val=x;
splay(o);
return;
}
}
}
2、删除
这个有点。。鬼畜
一般来说,(我所知道的)有两种删除方式,某崔性男子说可以merge(雾
第一种一般在数组版写
找到要删节点的前驱和后继
前驱转到根,后继转到根的右孩子
R的左子树一定是我们要删的,直接删就行了(父子不互认,其他变量清空)
第二种就是我在指针写的
需要两个函数(好像有点麻烦吧qwq)
inline nod lst()
{
nod o=root->ch[0];
while(o->ch[1]!=null) o=o->ch[1];
return o;
}
返回根的前驱
下面的是真正的删除
首先把要删的节点转到根并记录一下
找到根的前驱
把根的前驱转到根
那么一定是这种情况
原根,也就是要删的点,一定是没有左孩子的!!!!
所以类似于链表的操作,把该删的删掉
inline void del(int x)
{
rnk(x);
//删一个,还有
if(root->num>=2) {root->num--; root->upd(); return;}
//删一个不够了
nod l=lst(),rt=root;
splay(l);
//类似于链表的操作,使得被删点隔绝于此树之外
l->ch[1]=rt->ch[1];
l->ch[1]->fa=l;
rt->clr();
l->upd();
//清空与维护
}
3、查询数x的排名
暴力找
inline int rnk(int x)
{
//rank来记录排名
//从根开始暴力求
int rank=0;
nod o=root;
while(1)
{
//应该往左跳
if(o->ch[0]!=null&&x<o->val) o=o->ch[0];
else
{
//到这里说明左子树所有点的值都<x,所以统计
rank+=o->ch[0]->siz;
//刚好等于
if(x==o->val)
{
splay(o);
//因为初始插了个极小值极大值,所以不用+1
return rank;
}
//x比当前点还要大,所以+=num
rank+=o->num;
//往右跳
o=o->ch[1];
}
}
}
4、查询第k大的数
其实跟上面差不多
inline int kth(int x)
{
nod o=root;
while(1)
{
if(o->ch[0]!=null&&x<=o->ch[0]->siz) o=o->ch[0];
else
{
int y=o->ch[0]->siz+o->num;
if(x<=y) return o->val;
x-=y;
o=o->ch[1];
}
}
}
5、6、前驱,后继
这两个为什么一块写?
因为他们几乎一样
inline int pre(nod o,int x)
{
if(o==null) return -0x7fffffff;
if(x>o->val) return nmr::max(o->val,pre(o->ch[1],x));
//当前点成立,但递归下去可能不成立了,所以去max
else return pre(o->ch[0],x);
//当前点本来就不成立,直接递归
}
inline int nxt(nod o,int x)
{
//同上
if(o==null) return 0x7fffffff;
if(x<o->val) return nmr::min(o->val,nxt(o->ch[0],x));
else return nxt(o->ch[1],x);
}
至此,Splay完
其实只要理解了,并不是想象那么难的
放一下完整代码
#include<cstdio>
#include<queue>
#include<vector>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cctype>
#define _ 0
#define LL long long
#define Space putchar(' ')
#define Enter putchar('\n')
#define fuu(x,y,z) for(int x=(y);x<=(z);x++)
#define fu(x,y,z) for(int x=(y);x<(z);x++)
#define fdd(x,y,z) for(int x=(y);x>=(z);x--)
#define fd(x,y,z) for(int x=(y);x>(z);x--)
#define mem(x,y) memset(x,y,sizeof(x))
template<typename T>inline void in(T &x)
{
char ch;x=0;
int f=1;
while(!isdigit(ch=getchar()))f=ch=='-'? -f:f;
while(isdigit(ch)) x=(x*10)+(ch^48),ch=getchar();
x*=f;
}
template<typename T>inline void out(T x)
{
if(x<0) putchar('-'),x=-x;
if(x>9) out(x/10);
putchar(x%10+'0');
}
namespace nmr
{
template<typename T>inline T abs(T a) {return a>0? a:-a;}
template<typename T>inline void swap(T &a,T &b) {T t=a; a=b; b=t;}
template<typename T>inline const T &min(const T &a,const T &b) {return a>b? b:a;}
template<typename T>inline const T &max(const T &a,const T &b) {return a>b? a:b;}
}
int n,cnt;
struct node
{
node *fa,*ch[2];
int val,num,siz;
node() {val=num=siz=0;}
inline void clr() {val=num=siz=0;}
inline bool isr() {return this==fa->ch[1];}
inline void upd() {siz=ch[0]->siz+ch[1]->siz+num;}
};
typedef node* nod;
node st[123456];
nod root,null;
inline nod newnode()
{
cnt++;
st[cnt].fa=st[cnt].ch[1]=st[cnt].ch[0]=null;
return &st[cnt];
}
inline void rot(nod x)
{
nod y=x->fa,z=y->fa;
int k=x->isr(); nod w=x->ch[!k];
if(y!=root) z->ch[y->isr()]=x;
else root=x;
x->ch[!k]=y;
y->ch[k]=w;
w->fa=y,y->fa=x,x->fa=z;
y->upd(),x->upd();
}
inline void splay(nod x)
{
while(x!=root)
{
if(x->fa!=root) rot(x->isr()^x->fa->isr()? x:x->fa);
rot(x);
}
}
inline int rnk(int x)
{
int rank=0;
nod o=root;
while(1)
{
if(o->ch[0]!=null&&x<o->val) o=o->ch[0];
else
{
rank+=o->ch[0]->siz;
if(x==o->val)
{
splay(o);
return rank;
}
rank+=o->num;
o=o->ch[1];
}
}
}
inline int kth(int x)
{
nod o=root;
while(1)
{
if(o->ch[0]!=null&&x<=o->ch[0]->siz) o=o->ch[0];
else
{
int y=o->ch[0]->siz+o->num;
if(x<=y) return o->val;
x-=y;
o=o->ch[1];
}
}
}
inline nod lst()
{
nod o=root->ch[0];
while(o->ch[1]!=null) o=o->ch[1];
return o;
}
inline int pre(nod o,int x)
{
if(o==null) return -0x7fffffff;
if(x>o->val) return nmr::max(o->val,pre(o->ch[1],x));
else return pre(o->ch[0],x);
}
inline int nxt(nod o,int x)
{
if(o==null) return 0x7fffffff;
if(x<o->val) return nmr::min(o->val,nxt(o->ch[0],x));
else return nxt(o->ch[1],x);
}
inline void ins(int x)
{
if(root==null)
{
root=newnode();
root->siz=root->num=1;
root->val=x;
return;
}
nod fa=null;
nod o=root;
while(1)
{
if(o->val==x)
{
o->num++;
splay(o);
return;
}
fa=o;
o=o->ch[x>o->val];
if(o==null)
{
fa->ch[x>fa->val]=o=newnode();
o->fa=fa;
o->num=o->siz=1;
o->val=x;
splay(o);
return;
}
}
}
inline void del(int x)
{
rnk(x);
if(root->num>=2) {root->num--; root->upd(); return;}
nod l=lst(),rt=root;
splay(l);
l->ch[1]=rt->ch[1];
l->ch[1]->fa=l;
rt->clr();
l->upd();
}
int main()
{
in(n);
null=&st[0];
null->ch[1]=null->ch[0]=null->fa=null;
root=null;
ins(0x7fffffff);
ins(-0x7fffffff);
int p,x;
while(n--)
{
in(p),in(x);
if(p==1) {ins(x);}
if(p==2) {del(x);}
if(p==3) {out(rnk(x));Enter;}
if(p==4) {out(kth(x+1));Enter;}
if(p==5) {out(pre(root,x));Enter;}
if(p==6) {out(nxt(root,x));Enter;}
}
return ~~(0^_^0);
}