AH/HNOI2017 单旋 (splay)
前言
为什么说这道题叫做\(\text{splay}\)就不能用\(\text{splay}\)做?
不就是一个加一减一序列嘛干嘛用线段树\(+\text{set}\)做多烦啊
本人是一个\(\text{STL}\)的憎恨者,所以想出了一个可以使用\(\text {splay}\)实现的方法。
不过必须提一句……卡带着常数毁灭\(\text H\)国不是成功了吗
题意
要模拟一个只通过把\(x\)节点转到\(x\)节点的父亲的\(\text {splay}\)操作,并支持删根
怎么解
有一样东西普通的\(\text{splay}\)操作无法完成,那就是邪恶的深度。
因为普通的\(\text{splay}\)会把深度转乱,所以深度无法维护。
但是有一个很重要的发现:单旋最小值不会改变树的形状。
这时候,我们就可以把问题转换一下:
【敲黑板】\(\text{The next part is only for very smart students.}\)
我们可以把深度存储在一个序列之中,维护这个序列。
当我们执行\(\text{splay}\)操作的时候,最小值原来的右子树成为了他父亲的左子树,相应的最小值的深度不变,其它结点深度均\(+1\)。
我们需要维护以下操作:
- 查询
- 区间加一
- 前驱后继
- 插入
那么,喜欢线段树的同志们请注意了,你们维护\(3\)和\(4\)是很困难的。
\(\text{splay}\)闪亮登场。
上代码
本人代码风格极差,望谅解。(太匆忙了)。
喜欢线段树的同志们自己实现一下。
\(\text{splay}\)的代码:
//By Zhengjiarui, Copyright @2019, All rights preserved.
//Do not copy this code
//AHOI/HNOI2017 splay
//solution: use splay to stimulate the spaly tree
#include <bits/stdc++.h>
#define ls c[x][0]
#define rs c[x][1]
#define rep(i, l, r) for (int i = l; i <= r; i++)
using namespace std;
template <typename T>
inline void rd(T &x) {
int t;
char ch;
for (t = 0; !isdigit(ch = getchar()); t = (ch == '-'));
for (x = ch - '0'; isdigit(ch = getchar()); x = x * 10 + ch - '0');
if (t)
x = -x;
}
const int inf = 2000000000, N = 100100;
int n, cnt, rt, Q, fa[N], v[N], sz[N], c[N][2], s[N], dep[N], mn[N];
void push(int x) {
v[ls] += v[x];
dep[ls] += v[x];
mn[ls] += v[x];
v[rs] += v[x];
dep[rs] += v[x];
mn[rs] += v[x];
v[x] = 0;
}
void upd(int x) {
sz[x] = sz[ls] + sz[rs] + 1;
mn[x] = dep[x];
if (ls)
mn[x] = min(mn[x], mn[ls]);
if (rs)
mn[x] = min(mn[x], mn[rs]);
}
void rot(int &rt, int x) {
int y = fa[x], z = fa[y], w = (c[y][1] == x);
if (y == rt)
rt = x;
else
c[z][c[z][1] == y] = x;
fa[x] = z;
fa[y] = x;
fa[c[x][w ^ 1]] = y;
c[y][w] = c[x][w ^ 1];
c[x][w ^ 1] = y;
upd(y);
}
void splay(int &rt, int x) {
while (x != rt) {
int y = fa[x], z = fa[y];
if (y != rt) {
if ((c[z][1] == y) ^ (c[y][1] == x))
rot(rt, x);
else
rot(rt, y);
}
rot(rt, x);
}
upd(x);
}
void ins(int &x, int S, int d, int lst) {
if (!x) {
x = ++cnt, s[cnt] = S;
dep[cnt] = mn[cnt] = d;
sz[cnt] = 1;
fa[cnt] = lst;
return;
}
ins(c[x][S > s[x]], S, d, x);
upd(x);
}
int getpre(int x, int S) {
if (!x)
return 0;
if (v[x])
push(x);
if (s[x] > S)
return getpre(c[x][0], S);
Q = getpre(c[x][1], S);
if (Q)
return Q;
else
return x;
}
int getnxt(int x, int S) {
if (!x)
return 0;
if (v[x])
push(x);
if (s[x] < S)
return getnxt(rs, S);
Q = getnxt(ls, S);
if (Q)
return Q;
else
return x;
}
int find(int x, int k) {
if (v[x])
push(x);
if (sz[ls] + 1 == k)
return x;
if (sz[ls] + 1 < k)
return find(rs, k - sz[ls] - 1);
return find(ls, k);
}
int getl(int x, int d) {
if (!x)
return 0;
if (v[x])
push(x);
if (min(mn[ls], dep[x]) >= d)
return getl(rs, d) + sz[ls] + 1;
else
return getl(ls, d);
}
int getr(int x, int d) {
if (!x)
return 0;
if (v[x])
push(x);
if (min(mn[rs], dep[x]) >= d)
return getr(ls, d) + sz[rs] + 1;
else
return getr(rs, d);
}
int split(int l, int r) {
int t1 = find(rt, l - 1), t2 = find(rt, r + 1);
splay(rt, t1);
splay(c[rt][1], t2);
return c[c[rt][1]][0];
}
void mdf(int l, int r, int ad) {
int y = split(l, r);
v[y] += ad;
mn[y] += ad;
dep[y] += ad;
}
void change(int x, int S) {
if (v[x])
push(x);
if (s[x] == S)
dep[x] = 1;
else
change(c[x][S > s[x]], S);
upd(x);
}
int main() {
freopen("splay.in","r",stdin);
freopen("splay.out","w",stdout);
rd(n);
ins(rt, -inf, inf, 0);
ins(rt, inf, inf, 0);
mn[0] = inf;
rep(i, 1, n) {
int op, x;
rd(op);
if (op == 1) {
rd(x);
int t1 = getpre(rt, x), t2 = getnxt(rt, x);
int D = max(t1 > 2 ? dep[t1] : 0, t2 > 2 ? dep[t2] : 0) + 1;
ins(rt, x, D, 0);
splay(rt, cnt);
printf("%d\n", D);
}
if (!(op & 1)) {
int x = find(rt, 2), y = min(getl(rt, dep[x]), sz[rt] - 1) - 1;
printf("%d\n", dep[x]);
mdf(2, sz[rt] - 1, 1);
if (y > 1)
mdf(2, y + 1, -1);
change(rt, s[x]);
}
if ((op & 1) && (op > 1)) {
int x = find(rt, sz[rt] - 1), y = min(getr(rt, dep[x]), sz[rt] - 1) - 1;
printf("%d\n", dep[x]);
mdf(2, sz[rt] - 1, 1);
if (y > 1)
mdf(sz[rt] - y, sz[rt] - 1, -1);
change(rt, s[x]);
}
if (op >= 4) {
if (op == 4)
splay(rt, find(rt, 2));
else
splay(rt, find(rt, sz[rt] - 1));
int l = (op == 5), r = l ^ 1, y = c[rt][l];
c[y][r] = c[rt][r];
fa[y] = 0;
fa[c[rt][r]] = y;
rt = y;
v[rt] -= 1;
upd(rt);
}
}
return 0;
}
大佬博客里的线段树代码(不是我写的)
#include<bits/stdc++.h>
#define N 200010
using namespace std;
int m,tp,root;
int opt[N],v[N],q[N],ch[N][2],fa[N],dep[N*2];
set<int>st;
int read()
{
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f*=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void down(int rt)
{
if(dep[rt])
{
dep[rt<<1]+=dep[rt];
dep[rt<<1|1]+=dep[rt];
dep[rt]=0;
}
}
void modify(int rt,int l,int r,int pos,int val)
{
if(l==r){dep[rt]=val;return ;}
down(rt);
int mid=(l+r)>>1;
if(pos<=mid)modify(rt<<1,l,mid,pos,val);
else modify(rt<<1|1,mid+1,r,pos,val);
}
int query(int rt,int l,int r,int pos)
{
if(l==r)return dep[rt];
down(rt);
int mid=(l+r)>>1;
if(pos<=mid)return query(rt<<1,l,mid,pos);
else return query(rt<<1|1,mid+1,r,pos);
}
void update(int rt,int l,int r,int L,int R,int k)
{
if(L<=l&&R>=r){dep[rt]+=k;return ;}
down(rt);
int mid=(l+r)>>1;
if(L<=mid)update(rt<<1,l,mid,L,R,k);
if(R>mid)update(rt<<1|1,mid+1,r,L,R,k);
}
int insert(int x)
{
set<int>::iterator it=st.insert(x).first;//定义前向迭代器,下标从插入元素开始
if(!root){root=x;modify(1,1,tp,x,1);return 1;}//空树,插入结点深度为1
if(it!=st.begin())//如果插入元素有前驱
{
if(!ch[*--it][1])ch[fa[x]=*it][1]=x;//如果前驱没有右儿子
it++;
}
if(!fa[x])ch[fa[x]=*++it][0]=x;//要么成为后继的左儿子
int deep=query(1,1,tp,fa[x])+1;
modify(1,1,tp,x,deep);
return deep;
}
int findmax()
{
int x=*st.rbegin(),res=query(1,1,tp,x);
if(x==root)return 1;
if(x-1>=fa[x]+1)update(1,1,tp,fa[x]+1,x-1,-1);//x右树的深度都不变,所以先减1
update(1,1,tp,1,tp,1);
ch[fa[x]][1]=ch[x][0];
fa[ch[x][0]]=fa[x];
ch[fa[root]=x][0]=root;
root=x;
modify(1,1,tp,x,1);
return res;
}
int findmin()
{
int x=*st.begin(),res=query(1,1,tp,x);//找到最小值并询问深度
if(x==root)return 1;
if(x+1<=fa[x]-1)//x有右子数
update(1,1,tp,x+1,fa[x]-1,-1);//x右树的深度都不变,所以先减1
update(1,1,tp,1,tp,1);
ch[fa[x]][0]=ch[x][1];
fa[ch[x][1]]=fa[x];
ch[fa[root]=x][1]=root;
root=x;
modify(1,1,tp,x,1);
return res;
}
void delmax()
{
printf("%d\n",findmax());
update(1,1,tp,1,tp,-1);
st.erase(root);
root=ch[root][0];
fa[root]=0;
}
void delmin()
{
printf("%d\n",findmin());
update(1,1,tp,1,tp,-1);
st.erase(root);
root=ch[root][1];
fa[root]=0;
}
int main()
{
m=read();
for(int i=1;i<=m;i++)
{
opt[i]=read();
if(opt[i]==1)q[++tp]=v[i]=read();
}
sort(q+1,q+1+tp);
for(int i=1;i<=m;i++)
if(opt[i]==1)v[i]=lower_bound(q+1,q+1+tp,v[i])-q;//将插入的值离散化
for(int i=1;i<=m;i++)
{
if(opt[i]==1){printf("%d\n",insert(v[i]));}
else if(opt[i]==2)printf("%d\n",findmin());
else if(opt[i]==3)printf("%d\n",findmax());
else if(opt[i]==4)delmin();
else if(opt[i]==5)delmax();
}
return 0;
}