替罪羊树是一种重量平衡树。我们知道大部分平衡树都是要旋转的,只有非旋Treap和替罪羊不用旋转。替罪羊的思路很简单,就是当我们发现子树不平衡之后,就重构一遍。
替罪羊重构:
重构是替罪羊的精华,替罪羊的重构其实很暴力,比如我们要重构x的子树,就先将x的子树里的点全部回收起来。然后重构一棵二叉树,每次选节点的时候都选当前序列的中点。这样二分建树即可。
回收点:每次回收点时都优先向左边的点回收,然后再回收右边,这样就可以保证回收成的数组单调递增(不下降)。
void recycle(int k){
if(!k) return;
recycle(son[k][0]);
cur[++sum]=k;
recycle(son[k][1]);
}
重构:直接二分重构
int build(int l,int r){
if(l>r) return 0;
int mid=(l+r)>>1,id=cur[mid];
fa[son[id][0]=build(l,mid-1)]=id;
fa[son[id][1]=build(mid+1,r)]=id;
siz[id]=siz[son[id][0]]+siz[son[id][1]]+1;
return id;
}
void rebuild(int k){
sum=0;recycle(k);
int f=fa[k],d=(son[f][1]==k),id=build(1,sum);
fa[son[f][d]=id]=f;
if(k==root) root=id;
}
插入:
替罪羊的插入类似于其它的平衡树,但是在插入之后我们要往上回溯它的祖先,然后找它的祖先中深度最浅的。然后重构这个点的子树。
void insert(int x){
if(!root){val[root=++sz]=x,siz[root]=1,fa[root]=0;return;}
int tmp=root;
while(1){
siz[tmp]++;
int f=tmp,d=(x>=val[tmp]);tmp=son[tmp][d];
if(!tmp){
siz[++sz]=1,val[sz]=x,fa[sz]=f,son[f][d]=sz;
break;
}
}
int flag=0;
for(int i=sz;i;i=fa[i]) if(!isbad(i)) flag=i;
if(flag) rebuild(flag);
}
删除:
替罪羊的删除其实不同于其他的平衡树,我们可以直接打一个标记,标记这个点已经被删除了,但是这个点实际上是存在的。然后我们操作的时候就不管已经标记过的点。然后如果标记过的点太多了,就直接暴力重构整棵树,并且回收的时候不回收标记的节点。这样就完成了删除。
我打的并不是这个版本,我打的还是类似于splay的删除。直接将点x赋值为它的前驱,然后再删掉这个前驱即可。
int getpl(int x){
int tmp=root;
while(tmp){
if(x==val[tmp]) return tmp;
int d=(x>=val[tmp]);
tmp=son[tmp][d];
}
return tmp;
}
void erase(int k){
if(son[k][0]&&son[k][1]){
int tmp=son[k][0];
while(son[tmp][1]) tmp=son[tmp][1];
val[k]=val[tmp];k=tmp;
}
int s=son[k][0]?son[k][0]:son[k][1],f=fa[k],d=(son[f][1]==k);
son[f][d]=s;fa[s]=f;
for(int i=f;i;i=fa[i]) siz[i]--;
if(root==k) root=s;
}
其它的操作类似于普通平衡树。
模板:BZOJ3224
#include<bits/stdc++.h>
#define MAXN 800005
#define INF 2e9
using namespace std;
int read(){
char c;int x=0,y=1;while(c=getchar(),(c<'0'||c>'9')&&c!='-');
if(c=='-') y=-1;else x=c-'0';while(c=getchar(),c>='0'&&c<='9')
x=x*10+c-'0';return x*y;
}
const double alpha=0.7;
int n,sz,root;
struct Scapegoat{
int sum;
int son[MAXN][2],siz[MAXN],val[MAXN],ct[MAXN],cur[MAXN],fa[MAXN];
int isbad(int x){
return (double)siz[x]*alpha>max(siz[son[x][0]],siz[son[x][1]]);
}
void recycle(int k){
if(!k) return;
recycle(son[k][0]);
cur[++sum]=k;
recycle(son[k][1]);
}
int build(int l,int r){
if(l>r) return 0;
int mid=(l+r)>>1,id=cur[mid];
fa[son[id][0]=build(l,mid-1)]=id;
fa[son[id][1]=build(mid+1,r)]=id;
siz[id]=siz[son[id][0]]+siz[son[id][1]]+1;
return id;
}
void rebuild(int k){
sum=0;recycle(k);
int f=fa[k],d=(son[f][1]==k),id=build(1,sum);
fa[son[f][d]=id]=f;
if(k==root) root=id;
}
void insert(int x){
if(!root){val[root=++sz]=x,siz[root]=1,fa[root]=0;return;}
int tmp=root;
while(1){
siz[tmp]++;
int f=tmp,d=(x>=val[tmp]);tmp=son[tmp][d];
if(!tmp){
siz[++sz]=1,val[sz]=x,fa[sz]=f,son[f][d]=sz;
break;
}
}
int flag=0;
for(int i=sz;i;i=fa[i]) if(!isbad(i)) flag=i;
if(flag) rebuild(flag);
}
int getpl(int x){
int tmp=root;
while(tmp){
if(x==val[tmp]) return tmp;
int d=(x>=val[tmp]);
tmp=son[tmp][d];
}
return tmp;
}
void erase(int k){
if(son[k][0]&&son[k][1]){
int tmp=son[k][0];
while(son[tmp][1]) tmp=son[tmp][1];
val[k]=val[tmp];k=tmp;
}
int s=son[k][0]?son[k][0]:son[k][1],f=fa[k],d=(son[f][1]==k);
son[f][d]=s;fa[s]=f;
for(int i=f;i;i=fa[i]) siz[i]--;
if(root==k) root=s;
}
int findRANK(int x){
int tmp=root,res=0;
while(tmp){
if(x<=val[tmp]) tmp=son[tmp][0];
else res+=siz[son[tmp][0]]+1,tmp=son[tmp][1];
}
return res+1;
}
int findNUM(int x){
int tmp=root;
while(tmp){
if(x<=siz[son[tmp][0]]){
tmp=son[tmp][0];continue;
}
x-=siz[son[tmp][0]];
if(x==1) return val[tmp];x--;
tmp=son[tmp][1];
}
}
int findPRE(int x){
int ret=-INF,tmp=root;
while(tmp){
if(val[tmp]<x) ret=max(ret,val[tmp]),tmp=son[tmp][1];
else tmp=son[tmp][0];
}
return ret;
}
int findSUF(int x){
int ret=INF,tmp=root;
while(tmp){
if(val[tmp]>x) ret=min(ret,val[tmp]),tmp=son[tmp][0];
else tmp=son[tmp][1];
}
return ret;
}
}T;
int main()
{
n=read();
for(int i=1;i<=n;i++){
int type=read(),x=read();
if(type==1) T.insert(x);
if(type==2){int p=T.getpl(x);if(p) T.erase(p);}
if(type==3) printf("%d\n",T.findRANK(x));
if(type==4) printf("%d\n",T.findNUM(x));
if(type==5) printf("%d\n",T.findPRE(x));
if(type==6) printf("%d\n",T.findSUF(x));
}
return 0;
}