题目描述
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
- 插入xx数
- 删除xx数(若有多个相同的数,因只删除一个)
- 查询xx数的排名(排名定义为比当前数小的数的个数+1+1。若有多个相同的数,因输出最小的排名)
- 查询排名为xx的数
- 求xx的前驱(前驱定义为小于xx,且最大的数)
- 求xx的后继(后继定义为大于xx,且最小的数)
输入输出格式
输入格式:
第一行为nn,表示操作的个数,下面nn行每行有两个数optopt和xx,optopt表示操作的序号( 1 \leq opt \leq 61≤opt≤6 )
输出格式:
对于操作3,4,5,6每行输出一个数,表示对应答案
代码
ch[N][2]:ch[x][0]代表 xx 的左儿子,ch[x][1]代表 xx 的右儿子。
val[N]:val[x]代表 xx 存储的值。
cnt[N]:cnt[x]代表 xx 存储的重复权值的个数。
par[N]:par[x]代表 xx 的父节点。
size[N]:size[x]代表 xx 子树下的储存的权值数(包括重复权值)。
#include<bits/stdc++.h> #define inf 1<<30 using namespace std; const int maxn=2e5+10; int ch[maxn][2],par[maxn],val[maxn],cnt[maxn],size[maxn]; int ncnt,root;//ncnt新建结点位置,root 表示根节点 int n; inline int read() { int x=0,f=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();} while(ch<='9'&&ch>='0'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();} return x*f; } bool chk(int x) { return ch[par[x]][1]==x; } void pushup(int x) { size[x]=size[ch[x][0]]+size[ch[x][1]]+cnt[x]; } void rotate(int x) { int y=par[x],z=par[y],k=chk(x),w=ch[x][k^1]; ch[y][k]=w;par[w]=y; ch[z][chk(y)]=x;par[x]=z; ch[x][k^1]=y;par[y]=x; pushup(y);pushup(x); } void splay(int x,int goal=0) { while(par[x]!=goal) { int y=par[x],z=par[y]; if(z!=goal) { //两个结点位置相同 if(chk(x)==chk(y))rotate(y); else rotate(x); } rotate(x); } if(!goal)root=x; } void insert(int x) { int cur=root,p=0;//p记录当前节点 while(cur&&val[cur]!=x) p=cur,cur=ch[cur][x>val[cur]]; if(cur)cnt[cur]++; else { cur=++ncnt; if(p)ch[p][x>val[p]]=cur; ch[cur][0]=ch[cur][1]=0; par[cur]=p;val[cur]=x; cnt[cur]=size[cur]=1; } splay(cur); } void find(int x)//把某点旋到根节点 { int cur=root; while(ch[cur][x>val[cur]]&&x!=val[cur]) cur=ch[cur][x>val[cur]];//找到该节点 splay(cur); } int kth(int k) { int cur=root; while(1) { if(ch[cur][0]&&k<=size[ch[cur][0]]) cur=ch[cur][0]; else if(k>size[ch[cur][0]]+cnt[cur]) k-=size[ch[cur][0]]+cnt[cur],cur=ch[cur][1]; else return cur; } } int pre(int x) { find(x); if(val[root]<x)return root; int cur=ch[root][0]; while(ch[cur][1])cur=ch[cur][1]; return cur; } int succ(int x) { find(x); if(val[root]>x)return root; int cur=ch[root][1]; while(ch[cur][0])cur=ch[cur][0]; return cur; } void remove(int x) { int last=pre(x),next=succ(x); splay(last);splay(next,last); int del=ch[next][0];//表示要删的点 if(cnt[del]>1) cnt[del]--,splay(del);//更新size标记 else ch[next][0]=0; } int main() { n=read(); insert(inf); insert(-inf); for(int i=1;i<=n;i++) { int op=read(),x=read(); if(op==1)insert(x); else if(op==2)remove(x); else if(op==3)find(x),printf("%d\n",size[ch[root][0]]); else if(op==4)printf("%d\n",val[kth(x+1)]); else if(op==5)printf("%d\n",val[pre(x)]); else if(op==6)printf("%d\n",val[succ(x)]); } return 0; }