平衡树学习笔记(3)-------Splay

Splay

上一篇:平衡树学习笔记(2)-------Treap

Splay是一个实用而且灵活性很强的平衡树

效率上也比较客观,但是一定要一次性写对

debug可能不是那么容易

Splay作为平衡树,它的平衡方式就是旋转

暴力旋转,赤裸裸的旋转,各种旋转

就是依靠玄学的旋转来保证自己的复杂度

不废话,上主题

\(\color{#9900ff}{定义}\)

struct node
{
    node *fa,*ch[2];
    int val,num,siz;
    node() {val=num=siz=0;}
    inline void clr() {val=num=siz=0;}
    inline bool isr() {return this==fa->ch[1];}
    inline void upd() {siz=ch[0]->siz+ch[1]->siz+num;}
};

siz为子树大小,val为点权,num为点的个数(重复数字存在一个点上)

clr 清空节点,isr 判断是否为自己父亲的右孩子

\(\color{#9900ff}{基本操作}\)

1、rotate

其实这个就是第一节说的旋转

rot(x)代表把x转到它父亲的位置上去

这也是Splay维护平衡的基础

下面是重点了!!

把x转到它父亲y上

以下代码中字母对应,其中那个R是代码中的w(因为为中间量,要特殊对待)

inline void rot(nod x)
{
    nod y=x->fa,z=y->fa; 
    //找到y,z(注意,x转上去后,z的孩子变成x,所以要涉及到z)
    int k=x->isr(); nod w=x->ch[!k];
    //isr是bool型的,看看是不是自己父亲的右孩子,这个旋转针对的是所有情况,不仅仅是上图的情况
    if(y!=root) z->ch[y->isr()]=x;
    else root=x;
    //x转上去,就要考虑y是不是根的问题
    //如果y是根,x转上去后,自然成为了根
    //如果不是根,就要让x替换y的位置,原来y是z的哪个孩子,现在x就是z的哪个孩子
    x->ch[!k]=y;
    y->ch[k]=w;
    //该认孩子的认孩子
    w->fa=y,y->fa=x,x->fa=z;
    //该认父亲的认父亲
    y->upd(),x->upd();
    //因为x在y的上一层,x的upd要基于y,所以y先来
}

以上部分一定要理解透彻!!!

2、Splay

这个操作使基于rotate的

Splay(x),作用是把x转到根节点的位置上

显然要转好多次的qwq

因为一些玄学的东西(雾

平衡树中,每次用到谁转谁(反正不影响性质,说白了貌似还是瞎转)

这样玄学的操作可以使Splay平衡

inline void splay(nod x)
{   
    while(x!=root)
    {
        if(x->fa!=root) rot(x->isr()^x->fa->isr()? x:x->fa);
        rot(x);
    }
}

上面if那一行是啥意思呢?

我们要考虑一条链的情况

这种情况我们要先转父亲,再转自己

否则直接转自己就行

至此,基本操作已经结束qwq

\(\color{#9900ff}{其它操作}\)

1、插入

这个是真的暴力插。。。。。。

inline void ins(int x)
{
    if(root==null)    
    {
        //空树则对根节点操作
        root=newnode();
        root->siz=root->num=1;
        root->val=x;
        return;
    }
    //从根开始暴力插♂
    nod fa=null;
    nod o=root;
    while(1)
    {
        if(o->val==x)
        {
            //刚刚说重复的节点存在一起,这就是重复的情况
            o->num++;
            //玄学操作,转上去
            splay(o);
            return;
        }
        //一直往下跳(注意方向)
        fa=o;
        o=o->ch[x>o->val];
        if(o==null)
        {
            //跳到了空节点上,那么申请新节点
            fa->ch[x>fa->val]=o=newnode();
            //千万不要忘记父子互认
            o->fa=fa;
            o->num=o->siz=1;
            o->val=x;
            splay(o);
            return;
        }
    }
}

2、删除

这个有点。。鬼畜

一般来说,(我所知道的)有两种删除方式,某崔性男子说可以merge(雾

第一种一般在数组版写

找到要删节点的前驱和后继

前驱转到根,后继转到根的右孩子

R的左子树一定是我们要删的,直接删就行了(父子不互认,其他变量清空)

第二种就是我在指针写的

需要两个函数(好像有点麻烦吧qwq)

inline nod lst()
{
    nod o=root->ch[0];
    while(o->ch[1]!=null) o=o->ch[1];
    return o; 
}

返回根的前驱

下面的是真正的删除

首先把要删的节点转到根并记录一下

找到根的前驱

把根的前驱转到根

那么一定是这种情况

原根,也就是要删的点,一定是没有左孩子的!!!!

所以类似于链表的操作,把该删的删掉

inline void del(int x)
{
    rnk(x);
    //删一个,还有
    if(root->num>=2) {root->num--; root->upd(); return;}
    //删一个不够了
    nod l=lst(),rt=root;
    splay(l);
    //类似于链表的操作,使得被删点隔绝于此树之外
    l->ch[1]=rt->ch[1];
    l->ch[1]->fa=l;
    rt->clr();
    l->upd();
    //清空与维护
}

3、查询数x的排名

暴力找

inline int rnk(int x)
{
    //rank来记录排名
    //从根开始暴力求
    int rank=0;
    nod o=root;
    while(1)
    {
        //应该往左跳
        if(o->ch[0]!=null&&x<o->val) o=o->ch[0];
        else
        {
            //到这里说明左子树所有点的值都<x,所以统计
            rank+=o->ch[0]->siz;
            //刚好等于
            if(x==o->val)
            {
                splay(o);
                //因为初始插了个极小值极大值,所以不用+1
                return rank;
            }
            //x比当前点还要大,所以+=num
            rank+=o->num;
            //往右跳
            o=o->ch[1];
        }
    }
}

4、查询第k大的数

其实跟上面差不多

inline int kth(int x)
{
    nod o=root;
    while(1)
    {
        if(o->ch[0]!=null&&x<=o->ch[0]->siz) o=o->ch[0];
        else
        {
            int y=o->ch[0]->siz+o->num;
            if(x<=y) return o->val;
            x-=y;
            o=o->ch[1];
        }
    }
}

5、6、前驱,后继

这两个为什么一块写?

因为他们几乎一样

inline int pre(nod o,int x)
{
    if(o==null) return -0x7fffffff;
    if(x>o->val) return nmr::max(o->val,pre(o->ch[1],x));
    //当前点成立,但递归下去可能不成立了,所以去max
    else return pre(o->ch[0],x);
    //当前点本来就不成立,直接递归
}
inline int nxt(nod o,int x)
{    
    //同上
    if(o==null) return 0x7fffffff;
    if(x<o->val) return nmr::min(o->val,nxt(o->ch[0],x));
    else return nxt(o->ch[1],x);
}

至此,Splay完

其实只要理解了,并不是想象那么难的

放一下完整代码

#include<cstdio>
#include<queue>
#include<vector>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<cctype>
#define _ 0
#define LL long long
#define Space putchar(' ')
#define Enter putchar('\n')
#define fuu(x,y,z) for(int x=(y);x<=(z);x++)
#define fu(x,y,z)  for(int x=(y);x<(z);x++)
#define fdd(x,y,z) for(int x=(y);x>=(z);x--)
#define fd(x,y,z)  for(int x=(y);x>(z);x--)
#define mem(x,y)   memset(x,y,sizeof(x))
template<typename T>inline void in(T &x)
{
    char ch;x=0;
    int f=1;
    while(!isdigit(ch=getchar()))f=ch=='-'? -f:f;
    while(isdigit(ch)) x=(x*10)+(ch^48),ch=getchar();
    x*=f;
}
template<typename T>inline void out(T x)
{
    if(x<0) putchar('-'),x=-x;
    if(x>9) out(x/10);
    putchar(x%10+'0');
}
namespace nmr
{
    template<typename T>inline  T abs(T a) {return a>0? a:-a;}
    template<typename T>inline void swap(T &a,T &b) {T t=a; a=b; b=t;}
    template<typename T>inline const T &min(const T &a,const T &b) {return a>b? b:a;}
    template<typename T>inline const T &max(const T &a,const T &b) {return a>b? a:b;}
}
int n,cnt;
struct node
{
    node *fa,*ch[2];
    int val,num,siz;
    node() {val=num=siz=0;}
    inline void clr() {val=num=siz=0;}
    inline bool isr() {return this==fa->ch[1];}
    inline void upd() {siz=ch[0]->siz+ch[1]->siz+num;}
};
typedef node* nod;
node st[123456];
nod root,null;
inline nod newnode()
{
    cnt++;
    st[cnt].fa=st[cnt].ch[1]=st[cnt].ch[0]=null;
    return &st[cnt];
}
inline void rot(nod x)
{
    nod y=x->fa,z=y->fa; 
    int k=x->isr(); nod w=x->ch[!k];
    if(y!=root) z->ch[y->isr()]=x;
    else root=x;
    x->ch[!k]=y;
    y->ch[k]=w;
    w->fa=y,y->fa=x,x->fa=z;
    y->upd(),x->upd();
}
inline void splay(nod x)
{   
    while(x!=root)
    {
        if(x->fa!=root) rot(x->isr()^x->fa->isr()? x:x->fa);
        rot(x);
    }
}
inline int rnk(int x)
{
    int rank=0;
    nod o=root;
    while(1)
    {
        if(o->ch[0]!=null&&x<o->val) o=o->ch[0];
        else
        {
            rank+=o->ch[0]->siz;
            if(x==o->val)
            {
                splay(o);
                return rank;
            }
            rank+=o->num;
            o=o->ch[1];
        }
    }
}
inline int kth(int x)
{
    nod o=root;
    while(1)
    {
        if(o->ch[0]!=null&&x<=o->ch[0]->siz) o=o->ch[0];
        else
        {
            int y=o->ch[0]->siz+o->num;
            if(x<=y) return o->val;
            x-=y;
            o=o->ch[1];
        }
    }
}
inline nod lst()
{
    nod o=root->ch[0];
    while(o->ch[1]!=null) o=o->ch[1];
    return o; 
}
inline int pre(nod o,int x)
{
    if(o==null) return -0x7fffffff;
    if(x>o->val) return nmr::max(o->val,pre(o->ch[1],x));
    else return pre(o->ch[0],x);
}
inline int nxt(nod o,int x)
{
    if(o==null) return 0x7fffffff;
    if(x<o->val) return nmr::min(o->val,nxt(o->ch[0],x));
    else return nxt(o->ch[1],x);
}
inline void ins(int x)
{
    if(root==null)
    {
        root=newnode();
        root->siz=root->num=1;
        root->val=x;
        return;
    }
    nod fa=null;
    nod o=root;
    while(1)
    {
        if(o->val==x)
        {
            o->num++;
            splay(o);
            return;
        }
        fa=o;
        o=o->ch[x>o->val];
        if(o==null)
        {
            fa->ch[x>fa->val]=o=newnode();
            o->fa=fa;
            o->num=o->siz=1;
            o->val=x;
            splay(o);
            return;
        }
    }
}
inline void del(int x)
{
    rnk(x);
    if(root->num>=2) {root->num--; root->upd(); return;}
    nod l=lst(),rt=root;
    splay(l);
    l->ch[1]=rt->ch[1];
    l->ch[1]->fa=l;
    rt->clr();
    l->upd();
}
int main()
{
    in(n);
    null=&st[0];
    null->ch[1]=null->ch[0]=null->fa=null;
    root=null;
    ins(0x7fffffff);
    ins(-0x7fffffff);
    int p,x;
    while(n--)
    {
        in(p),in(x);
        if(p==1) {ins(x);}
        if(p==2) {del(x);}
        if(p==3) {out(rnk(x));Enter;}
        if(p==4) {out(kth(x+1));Enter;}
        if(p==5) {out(pre(root,x));Enter;}
        if(p==6) {out(nxt(root,x));Enter;}
    }
    return ~~(0^_^0);
}

猜你喜欢

转载自www.cnblogs.com/olinr/p/10012901.html