一道算是挺难想的线段树。
需要维护两个数组,一个d和一个f。f的定义参见体题面,前后最大值中最小的,再与自己的d取一个大的。然后有三种操作,1是求f的和;2是求区间[l,r]中,有多少个f的数值大于给定的数字v;3是修改某一个位置的d的值,使其增加v。在这里,f随着d的数值而变化,但是这个变化的关系没有什么直观的联系。
下面我们仔细考虑一下,这两个到底有什么关系。
首先我们得出几个显然的推论:
推论一:如果某个位置x的d增加了v之后,存在一个y>x,使得d[y]>d[x],那么有max(d[k])>max(d[j]),k和j的定义见题面。
推论二:推论一反过来,如果存在y<x,使得d[y]>d[x],那么有max(d[j])>max(d[k])
然后开始推导。
考虑当一个位置x增加了v,不妨假设增加了之后 d[x]>d[i] (1<=i<x) 并且满足推论一的条件,那么根据推论一,在区间[x+1,y-1]中,y为满足条件的最小值,f的数值只受max(d[j])影响,而此时max(d[j])=d[x],所以这一区间的f增大为d[x]。
当满足前一条件,但是不满足推论一的条件的时候,也即在x的右边找不到比当前d[x]大的d[y],那么我去找区间[x+1,n]中最左边的一个最大值所在的位置y。我们会发现,对于一个在区间[x+1,y-1]中的点,当更新之前,满足max(d[j])>max(d[k])时,更新之后这个关系还是不变,f还是可以用d[k]更新;但当更新之前是max(d[j])<max(d[k])时,f取值是d[j],但此时已经是max(d[j])>max(d[k])了,所以f的取值应该是与d[k]关联。综上,两种情况用d[k]去更新f都不会出错。所以此时对于区间[x+1,y-1]我们用d[k]去更新。
同理,如果d[x]>d[i] (x<i<=n)并且满足推论二的条件,那么根据推论二,在区间[y+1,x-1]中,y为满足条件的最大值,f的数值只会受到max(d[k])的影响,而此时max(d[k])=d[x],所以这一区间的f增大为d[x]。
当不满足推论二的条件的时候也是类似的道理。找到区间[1,x-1]中最右边的一个最大值所在的位置y,对于区间[y+1,x-1]的f用d[x]去更新即可。
这样,动态维护f的操作我们就完成了。接下来考虑这个区间内大于某一个数字的f的个数怎么求。
经过观察,我们可以发现,那些f的数值大于某些数字的区域一定是连续的一个区间,而且这个区间的端点l和r,满足l左边的所有d[i]都小于对应的数字,r右边的所有d[i]都小于对应的数字。这个很容易去证明,如果存在两个位置l和r,d[l]和d[r]都大于对应的数字,那么显然他们之间的所有位置的f都要大于对应的数字。因为d[j]和d[k]都大于等于对应的数字,而f是二者最小值与d的最大值,因此也一定大于对应的数字。所以只需要找到d中最左边和最右边的大于对应数字的位置即可。那么第二个操作也可以解决了。
最后总结一下。我们建立两棵线段树,一个维护d,另一个维护f。对于一个修改操作,我们需要在d中查找左右第一个大于更新之后d[x]的位置,如果不存在则对应找左边做靠右的最大值和右边最靠左的最大值,这样我们就可以知道需要修改的区间。然后在f的对应区间进行更新。对于查询区间内大于某一个数字的个数,只需要在d中找大于对应数字的最左边和最右边的数字的位置,两个之间的数字都满足条件。然后查询f的和直接输出f那棵树的和即可。具体操作见代码:
#include<bits/stdc++.h>
#define INF 0x3f3f3f3f
#define pi 3.141592653589793
#define mod 998244353
#define LL long long
#define pb push_back
#define lb lower_bound
#define ub upper_bound
#define sf(x) scanf("%lld",&x)
#define sc(x,y,z) scanf("%lld%lld%lld",&x,&y,&z)
using namespace std;
const int N = 300010;
LL f[N],d[N],mx1[N],mx2[N],n,m;
typedef pair<LL,int> P;
struct ST1
{
#define ls i<<1
#define rs i<<1|1
struct node
{
int lpos,rpos,l,r;
LL max;
} T[N<<2];
inline void push_up(int i)
{
T[i].max=max(T[ls].max,T[rs].max);
T[i].lpos=T[T[ls].max>=T[rs].max?ls:rs].lpos;
T[i].rpos=T[T[rs].max>=T[ls].max?rs:ls].rpos;
}
void build(int i,int l,int r)
{
T[i]=node{0,0,l,r,0};
if (l==r)
{
T[i].lpos=T[i].rpos=l;
T[i].max=d[l]; return;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
push_up(i);
}
void update(int i,int pos,LL x)
{
if (T[i].l==T[i].r)
{
T[i].max+=x;
return;
}
int mid=(T[i].l+T[i].r)>>1;
if (mid>=pos) update(ls,pos,x);
else if (mid<pos) update(rs,pos,x);
push_up(i);
}
int query1(int i,int l,int r,LL x)
{
if (l>r||T[i].max<x) return 0;
if (T[i].l==T[i].r) return T[i].max>=x?T[i].l:0;
int mid=(T[i].l+T[i].r)>>1,res=0;
if (l<=mid) res=query1(ls,l,r,x);
if (r>mid&&!res) res=query1(rs,l,r,x);
return res;
}
int query2(int i,int l,int r,LL x)
{
if (l>r||T[i].max<x) return 0;
if (T[i].l==T[i].r) return T[i].max>=x?T[i].l:0;
int mid=(T[i].l+T[i].r)>>1,res=0;
if (r>mid&&!res) res=query2(rs,l,r,x);
if (l<=mid&&!res) res=query2(ls,l,r,x);
return res;
}
P getmax1(int i,int l,int r)
{
if (l<=T[i].l&&T[i].r<=r) return P(T[i].max,T[i].lpos);
int mid=(T[i].l+T[i].r)>>1; P res={0,0};
if (l<=mid) res=getmax1(ls,l,r);
if (r>mid)
{
P tmp=getmax1(rs,l,r);
if (tmp.first>res.first) res=tmp;
}
return res;
}
P getmax2(int i,int l,int r)
{
if (l<=T[i].l&&T[i].r<=r) return P(T[i].max,T[i].rpos);
int mid=(T[i].l+T[i].r)>>1; P res={0,0};
if (r>mid) res=getmax2(rs,l,r);
if (l<=mid)
{
P tmp=getmax2(ls,l,r);
if (tmp.first>res.first) res=tmp;
}
return res;
}
} seg1;
struct ST2
{
#define ls i<<1
#define rs i<<1|1
struct node
{
LL lazy,sum;
int l,r;
} T[N<<2];
void build(int i,int l,int r)
{
T[i]=node{0,0,l,r};
if (l==r)
{
T[i].sum=f[l];
return;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
T[i].sum=T[ls].sum+T[rs].sum;
}
inline void push_down(int i)
{
LL lazy=T[i].lazy;
T[ls].lazy=lazy; T[rs].lazy=lazy;
T[ls].sum=(T[ls].r-T[ls].l+1)*lazy;
T[rs].sum=(T[rs].r-T[rs].l+1)*lazy;
T[i].lazy=0;
}
void update(int i,int l,int r,LL x)
{
if (l>r) return;
if (T[i].l==l&&T[i].r==r)
{
T[i].sum=(T[i].r-T[i].l+1)*x;
T[i].lazy=x; return;
}
if (T[i].lazy) push_down(i);
int mid=(T[i].l+T[i].r)>>1;
if (mid>=r) update(ls,l,r,x);
else if (mid<l) update(rs,l,r,x);
else
{
update(ls,l,mid,x);
update(rs,mid+1,r,x);
}
T[i].sum=T[ls].sum+T[rs].sum;
}
} seg2;
int main()
{
sf(n); sf(m);
for(int i=1;i<=n;i++) sf(d[i]);
seg1.build(1,1,n);
for(int i=1;i<=n;i++)
mx1[i]=max(mx1[i-1],d[i]);
for(int i=n;i>=1;i--)
mx2[i]=max(mx2[i+1],d[i]);
for(int i=1;i<=n;i++)
f[i]=max(d[i],min(mx1[i-1],mx2[i+1]));
seg2.build(1,1,n);
while(m--)
{
LL op,l,r,v;
sf(op);
if (op==1) printf("%lld\n",seg2.T[1].sum);
else if (op==2)
{
sc(l,r,v);
int ll=seg1.query1(1,1,l-1,v);
int rr=seg1.query2(1,r+1,n,v);
if (ll) ll=l; else ll=seg1.query1(1,l,r,v);
if (rr) rr=r; else rr=seg1.query2(1,l,r,v);
if (ll*rr==0) puts("0"); else printf("%d\n",rr-ll+1);
} else
{
sf(l); sf(v);
seg1.update(1,l,v); d[l]+=v;
int ll=seg1.query2(1,1,l-1,d[l]);
int rr=seg1.query1(1,l+1,n,d[l]);
if (ll&&rr) continue;
seg2.update(1,l,l,d[l]);
if (ll) seg2.update(1,ll+1,l-1,d[l]);
else if (l!=1)
{
P tmp=seg1.getmax2(1,1,l-1);
seg2.update(1,tmp.second,l-1,tmp.first);
}
if (rr) seg2.update(1,l+1,rr-1,d[l]);
else if (l!=1)
{
P tmp=seg1.getmax1(1,l+1,n);
seg2.update(1,l+1,tmp.second,tmp.first);
}
}
}
return 0;
}