Splay 指针版题目&代码

题目1

Luogu P3369【模板】普通平衡树

#include <cstdio>

// Key of the two sentinels the Splay constructor inserts (+INF and -INF);
// pre()/nxt() also fall back to -INF/+INF when no candidate key is found.
#define INF 0X7FFFFFFF
// Size of a subtree that may be empty (a NULL pointer counts as 0).
#define SIZE(u) ((u)?(u)->size:0)

// Number of operations, read first in main().
int n;

// Splay tree over int keys, used as an ordered multiset. Equal keys share one
// node whose cnt is their multiplicity; size is the total multiplicity stored
// in a subtree. The constructor inserts the sentinel keys +INF and -INF, which
// is why rnk()/select() account for one extra element on the left.
struct Splay{
    struct Node{
        int val,cnt,size;   // key, multiplicity of key, total multiplicity of subtree
        Node *ch[2],*fa;    // ch[0] = left child, ch[1] = right child, fa = parent
        
        inline Node(int _val,Node *_fa){
            val=_val,fa=_fa,cnt=size=1;
            ch[0]=ch[1]=NULL;
        }
        
        // 1 if this node is its parent's right child, 0 if left; fa must be non-null.
        inline int relation(){
            return this==fa->ch[1];
        }
        
        // Recompute size from the children and this node's own multiplicity.
        inline void maintain(){
            size=SIZE(ch[0])+SIZE(ch[1])+cnt;
        }
    } *root;
    
    inline Splay(){
        root=NULL;
        insert(+INF);
        insert(-INF);
    }
    
    // Rotate u above its parent, fixing parent links and subtree sizes.
    // Updates root when u ends up with no parent.
    inline void rotate(Node *&u){
        int x=u->relation();
        Node *old=u->fa;
        
        u->fa=old->fa;
        if (old->fa){
            old->fa->ch[old->relation()]=u;
        }
        
        // u's inner subtree (the side facing old) moves over to old.
        old->ch[x]=u->ch[x^1];
        if (u->ch[x^1]){
            u->ch[x^1]->fa=old;
        }
        
        u->ch[x^1]=old;
        old->fa=u;
        
        old->maintain(),u->maintain();
        if (!u->fa) root=u;
    }
    
    // Splay u upward until its parent is target (target==NULL: u becomes the root).
    inline void splay(Node *&u,Node *target=NULL){
        while (u->fa!=target){
            Node *p=u->fa;
            if (p->fa==target){
                rotate(u);                  // zig
            } else if (p->relation()==u->relation()){
                rotate(p),rotate(u);        // zig-zig: rotate the parent first
            } else{
                rotate(u),rotate(u);        // zig-zag
            }
        }
    }
    
    // Locate the node holding val and splay it to the root.
    // Returns NULL (and splays nothing) when val is not stored.
    inline Node *find(int val){
        Node *u=root;
        while (u && u->val!=val){
            if (val<u->val){
                u=u->ch[0];
            } else{
                u=u->ch[1];
            }
        }
        
        if (!u) return NULL;
        
        splay(u);
        
        return u;
    }
    
    // Add one copy of val and splay its node to the root. Sizes on the search
    // path are bumped during the descent because exactly one element is added
    // below each of those nodes.
    inline Node *insert(int val){
        Node *u=root,*fa=NULL; int d=-1;
        while (u && u->val!=val){
            fa=u; u->size++;
            if (val<u->val){
                u=u->ch[0],d=0;
            } else{
                u=u->ch[1],d=1;
            }
        }
        
        if (u){
            // Existing key: bump its multiplicity. Splay it too; otherwise
            // repeatedly inserting a deep duplicate costs O(depth) every time
            // and the amortised O(log n) bound no longer holds.
            u->cnt++;
            u->maintain();
            splay(u);
            return u;
        }
        
        u=new Node(val,fa);
        if (d==-1) root=u;          // tree was empty
        else fa->ch[d]=u,splay(u);
        
        return u;
    }
    
    // Remove one copy of val; does nothing if val is not stored.
    inline void erase(int val){
        Node *u=find(val);
        if (u) erase(u);
    }
    
    // Remove one copy of the key held by u.
    inline void erase(Node *u){
        splay(u);
        if (u->cnt>1){
            u->cnt--;
            u->maintain();
            return;
        }
        
        if (!u->ch[0]){
            // No smaller key: the right subtree becomes the whole tree.
            root=u->ch[1];
            if (root) root->fa=NULL;
            delete u;
            return;
        }
        
        // Bring u's predecessor v to the top of the left subtree; v then has
        // no right child, so u's right subtree can hang there directly.
        Node *v=u->ch[0];
        while (v->ch[1]) v=v->ch[1];
        
        splay(v,u);
        
        v->ch[1]=u->ch[1];
        if (v->ch[1]) v->ch[1]->fa=v;
        v->fa=NULL;
        // v just gained a whole subtree, so its size must be recomputed.
        v->maintain();
        root=v;
        delete u;
    }
    
    // Number of stored keys smaller than val, including the -INF sentinel,
    // i.e. the 1-based rank val has (or would have) among the real keys.
    // Works whether or not val is stored; the last visited node is splayed.
    inline int rnk(int val){
        int res=0;
        Node *u=root,*last=NULL;
        
        while (u){
            last=u;
            if (val==u->val){
                res+=SIZE(u->ch[0]);
                break;
            } else if (val<u->val){
                u=u->ch[0];
            } else{
                res+=SIZE(u->ch[0])+u->cnt;
                u=u->ch[1];
            }
        }
        
        if (last) splay(last);
        
        return res;
    }
    
    // Node holding the x-th smallest real key (1-based), splayed to the root.
    // x++ skips the -INF sentinel. x must be in range: otherwise u becomes
    // NULL inside the loop and is dereferenced.
    inline Node *select(int x){
        x++;
        Node *u=root;
        
        while (1){
            if (x<=SIZE(u->ch[0])){
                u=u->ch[0];
            } else if (x>SIZE(u->ch[0])+u->cnt){
                x-=SIZE(u->ch[0])+u->cnt;
                u=u->ch[1];
            } else break;
        }
        
        splay(u);
        
        return u;
    }
    
    // Largest stored key strictly less than val (-INF if there is none);
    // the node holding it is splayed to the root.
    inline int pre(int val){
        Node *u=root,*v=NULL;
        int res=-INF;
        while (u){
            if (u->val<val){
                if (res<u->val) res=u->val,v=u;
                u=u->ch[1];
            } else{
                u=u->ch[0]; 
            }
        }
        if (v) splay(v);
        return res;
    }
    
    // Smallest stored key strictly greater than val (+INF if there is none);
    // the node holding it is splayed to the root.
    inline int nxt(int val){
        Node *u=root,*v=NULL;
        int res=+INF;
        while (u){
            if (u->val>val){
                if (res>u->val) res=u->val,v=u;
                u=u->ch[0];
            } else{
                u=u->ch[1];
            }
        }
        if (v) splay(v);
        return res;
    }
} tr;


// Reads n operations "cmd x" and applies them to the global tree tr:
// 1 insert, 2 erase, 3 rank of x, 4 x-th smallest, 5 predecessor, 6 successor.
// Unknown commands are ignored.
int main(){
    if (scanf("%d",&n)!=1) return 0;
    for (int i=1;i<=n;i++){
        int cmd,x;
        // Stop on truncated input instead of acting on uninitialised cmd/x.
        if (scanf("%d%d",&cmd,&x)!=2) break;
        switch (cmd){
            case 1: tr.insert(x); break;
            case 2: tr.erase(x); break;
            case 3: printf("%d\n",tr.rnk(x)); break;
            case 4: printf("%d\n",tr.select(x)->val); break;
            case 5: printf("%d\n",tr.pre(x)); break;
            case 6: printf("%d\n",tr.nxt(x)); break;
            default: break;
        }
    }
    return 0;
}

题目2

Luogu P3165 [CQOI2014] 排序机械臂

#include <cstdio>
#include <algorithm>

// Size of a subtree that may be empty (a NULL pointer counts as 0).
#define SIZE(u) ((u)?(u)->size:0)
// Not referenced by the second program's code.
#define INF 0X7FFFFFFF

// Array bound. build() reads a[0..n+1] and pos[] is indexed by rank (1..n),
// so MAXN must exceed n+1.
#define MAXN 100005

// n: element count. a[i]: input value at position i, later replaced by its
// stable rank 1..n in main(). b[]: positions 1..n sorted with cmp.
int n,b[MAXN],a[MAXN];

// Node of the implicit-key splay tree (in-order position = sequence index).
struct Node{
    int rev,size;      // rev: pending "mirror this subtree" flag; size: nodes in subtree
    Node *ch[2],*fa;   // ch[0] = left child, ch[1] = right child, fa = parent
    
    // New leaf whose parent link is _fa (the parent's child slot is set by the caller).
    inline Node(Node *_fa=NULL):rev(0),size(1),fa(_fa){
        ch[0]=NULL;
        ch[1]=NULL;
    }
    
    // 1 if this node sits in its parent's right slot, 0 if left; fa must be non-null.
    inline int relation(){
        return fa->ch[1]==this;
    }
    
    // Apply this node's pending reversal: swap its children and hand the
    // flag on to each of them.
    inline void pushdown(){
        if (!rev) return;
        for (int k=0;k<2;k++){
            if (ch[k]) ch[k]->rev^=1;
        }
        Node *tmp=ch[1];
        ch[1]=ch[0];
        ch[0]=tmp;
        rev=0;
    }
    
    // Recompute size from the children.
    inline void update(){
        size=1+SIZE(ch[0])+SIZE(ch[1]);
    }
    
} *root,*pos[MAXN];

// Build a balanced subtree over array indices l..r into u, with parent fa,
// splitting at the midpoint, and record in pos[] the node created for each
// a[] value.
inline void build(Node *&u,Node *fa,int l,int r){
    if (l>r) return;             // empty range: u is left as it was
    const int mid=(l+r)>>1;
    u=new Node(fa);
    pos[a[mid]]=u;
    build(u->ch[0],u,l,mid-1);
    build(u->ch[1],u,mid+1,r);
    u->update();
}

// Rotate u above its parent p. Pending reversals on the grandparent, p and u
// are pushed down first (top-down) so the child slots read below are real.
// Updates the global root when u ends up with no parent.
inline void rotate(Node *&u){
    Node *p=u->fa;
    Node *g=p->fa;
    
    if (g) g->pushdown();
    p->pushdown();
    u->pushdown();
    
    const int side=u->relation();    // 0: u is p's left child, 1: right
    Node *inner=u->ch[side^1];       // u's subtree facing p; it moves under p
    
    // u takes p's slot under g (p->relation() still reads g at this point).
    u->fa=g;
    if (g) g->ch[p->relation()]=u;
    
    p->ch[side]=inner;
    if (inner) inner->fa=p;
    
    u->ch[side^1]=p;
    p->fa=u;
    
    p->update();
    u->update();
    
    if (!g) root=u;
}

// Splay u upward until its parent is target (target==NULL: u becomes the root).
// Pending reversal flags are pushed down along the way; rotate() additionally
// pushes grandparent, parent and u (in that order) before relinking.
// NOTE: the zig-zig / zig-zag test reads the grandparent's child slots before
// that grandparent is pushed down. This is still correct: a pending flag on the
// grandparent (or anything above it) mirrors both the parent's and u's
// orientation together, so the equality being tested does not change.
inline void splay(Node *&u,Node *target=NULL){
    // Make u's own children real even when the loop does not run; main()
    // reads pos[i]->ch[0] right after splaying.
    u->pushdown();
    
    while (u->fa!=target){
        u->fa->pushdown();
        if (u->fa->fa==target){
            // zig: parent is directly below target.
            rotate(u);
        } else if (u->fa->relation()==u->relation()){
            // zig-zig: rotate the parent above the grandparent first.
            u->fa->fa->pushdown();
            rotate(u->fa);
            rotate(u);
        } else{
            // zig-zag: rotate u twice.
            u->fa->fa->pushdown();
            rotate(u);
            rotate(u);
        }
    }
}

// Return the node at 0-based in-order index x (index 0 is the node built from
// array index 0, i.e. the left sentinel in main()) and splay it to the root.
// x must be in range; otherwise cur becomes NULL and is dereferenced.
inline Node *find(int x){
    int k=x+1;                    // 1-based rank of the wanted node
    Node *cur=root;
    
    for (;;){
        cur->pushdown();          // children must be in real order before reading sizes
        const int left=SIZE(cur->ch[0]);
        if (k<=left){
            cur=cur->ch[0];
            continue;
        }
        if (k==left+1) break;
        k-=left+1;
        cur=cur->ch[1];
    }
    
    if (cur) splay(cur);
    return cur;
}

// Order positions by their value in a[]; equal values keep input order, so
// sorting with this comparator is stable.
inline bool cmp(const int &x,const int &y){
    if (a[x]!=a[y]) return a[x]<a[y];
    return x<y;
}

// Read the next integer from stdin into x. Any non-digit characters before it
// are skipped, and each '-' among them toggles the sign.
// Returns false (and sets x to 0) if EOF is reached before any digit; the
// original skip loop treated EOF (-1) as "not a digit" and never terminated.
template <typename T>
inline bool read(T &x){
    int fl=0,ch;
    while (ch=getchar(),ch<48 || 57<ch){
        if (ch==EOF){
            x=0;
            return false;
        }
        fl^=(ch=='-');
    }
    x=(ch&15);
    while (ch=getchar(),47<ch && ch<58) x=x*10+(ch&15);
    if (fl) x=-x;
    return true;
}

// Print the integer x in decimal to stdout (no trailing separator).
// Negative values are handled without computing -x, which would overflow for
// the most negative value of T (e.g. INT_MIN). Instead the code uses x/10 and
// x%10, whose magnitudes always fit: since C++11 both truncate toward zero, so
// x%10 lies in (-10, 0] for negative x.
template <typename T>
inline void write(T x){
    if (x<0){
        putchar('-');
        T q=x/10;
        int last=-(int)(x%10);
        if (q) write(-q);
        putchar(last+48);
        return;
    }
    if (x>9) write(x/10);
    putchar(x%10+48);
}

int main(){
    read(n);
    for (int i=1;i<=n;i++) read(a[i]),b[i]=i;
    std::sort(b+1,b+1+n,cmp);
    for (int i=1;i<=n;i++) a[b[i]]=i;
    build(root,NULL,0,n+1);
    for (int i=1,res;i<n;i++){
        splay(pos[i]); res=SIZE(pos[i]->ch[0]);
        Node *u=find(i-1);
        Node *v=find(res+1);
        splay(u),splay(v,u);
        root->ch[1]->ch[0]->rev^=1;
        write(res),putchar(' ');
    }
    write(n),putchar('\n');
    return 0;
}

猜你喜欢

转载自 https://www.cnblogs.com/xuanyi/p/9460256.html