题目链接:点击查看
题目大意:给出一个长度为 n 的数列 a ,规定当 a[ i ] == i 时,位置 i 可以被删除掉,后面位置会合并上来,现在需要回答 m 次询问,每次询问会问禁用掉后面 x 个数字和后面的 y 个数字后,最多可以删除掉多少个数字,每个询问都相互独立,举个例子比较好看:
这是样例下面的样例解释,应该不难看懂
题目分析:首先 x 和 y 可以转换成 l 和 r ,也就是询问在某段区间 [ l , r ] 中最多可以删除掉多少个数,其次因为每个位置的元素都会与其位置比较,所以不妨直接做差,将其转换为与 0 进行比较,也就是将 a[ i ] 用 i - a[ i ] 代替,此时按照正负分三种情况讨论:
- i - a[ i ] == 0:当前位置可以直接删除
- i - a[ i ] < 0:当前位置永远不可能被删除
- i - a[ i ] > 0:当前位置之前删除掉 i - a[ i ] 个数字后,当前位置可以被删除
注意每次删除掉位置 i 后,下标 [ i + 1 , n ] 的位置由于会和 [ 1 , i - 1 ] 进行合并,所以 [ i + 1 , n ] 的 i - a[ i ] 会减少一
以上讲的是暴力模拟的做法,但显然不能直接暴力模拟去做,时间复杂度会退化为 n^2logn
接下来考虑递推,先不考虑左区间的限制,设 f( r ) 为前缀区间 [ 1 , r ] 内最多可以删除多少个数,递推的转移也非常简单:
- 如果 i - a[ i ] < 0 或 i - a[ i ] > f( i - 1 ):f( i ) = f( i - 1 )
- 否则 f( i ) = f( i - 1 ) + 1
仅对于前缀来说,这个递推式还是蛮好想的,接下来我们需要用相同的思想扩展到每个前缀区间的后缀来表示所有的子区间,也就是在固定 r 后,将所维护的区间扩展到:s[ i ] = 区间 [ i , r ] 内最多可以删除多少个数,i ∈ [ 1 , r ] 时答案正常计算,i > r 时 s[ i ] = 0,首先比较明显的就是,随着 r 的递增,s 数组的每个元素相对于之前,迭代后一定是不减的,因为假如 [ l , r ] 内最多可以删除掉 x 个数,现在任取一个 r2 > r ,那么在区间 [ l , r2 ] 中,完全可以先删除掉 [ l , r ] 内的数,然后顺便删除掉 [ r + 1 , r2 ] 中的数,所以 s 数组的每个元素迭代后一定是不减的
再观察一下当 r 确定后,s[ i ] 的单调性,上面那一段说的是 s 中的每个元素(相互独立)随着 r 的迭代呈现的趋势,注意区分,随着 i 的增长,s[ i ] 所表示的区间长度也随之减少,简单来说,任取 l1 < l2 < r ,位于区间 [ l2 , r ] 内的某个元素,可能在区间 [ l2 , r ] 内无法删除掉,但是在区间 [ l1 , r ] 内就可以删除掉了,感性理解一下当 r 确定后 s[ i ] 是递减的
到此总结一下,s[ i ] 代表的是当 r 确定后,[ i , r ] 内最多可以删除多少个数,上面说的那么复杂,其实就是为了证明以下两点:
- 当 r 固定时,s[ i ] 呈非递增的趋势
- 随着 r 的迭代,s 中的每个元素(视为相互独立),迭代后呈非递减的趋势
接下来考虑当前加入第 i 个元素后的影响:
- 如果 a[ i ] < 0 或 i - a[ i ] > s[ 1 ],对应着上面的 “ i - a[ i ] < 0 或 i - a[ i ] > f( i - 1 ) ”,此时贡献不变,s 数组也不变
- 否则对于每个满足 s[ j ] >= i - a[ i ] 的 j 来说,都是可以顺便删掉第 i 个数的,故贡献 s[ j ] ++,其余位置不变
结合上面说的,当右端点 r 确定时,s[ i ] 呈非递增的趋势,所以我们可以找到最右边的一个 pos 满足 s[ pos ] >= i - a[ i ] ,这样显然有:s[ 1 ] >= s[ 2 ] >= ... >= s[ pos ] >= i - a[ i ] > s[ pos + 1 ] >= ... >= s[ n ]
根据上面的第二个结论,只需要将 s[ 1 ] ~ s[ pos ] 区间加一即可,至于寻找 pos 位置的过程,因为 s 数组满足单调性,故可以用二分去快速找到,说到这里,区间修改和单点查询,可以借助线段树来维护 s 数组的迭代,外层再套一层循环用来迭代右端点 r 即可,时间复杂度为 nlogn^2,需要先将所有的询问离线然后进行上述操作,最后统一输出即可
有点细节就是,为了方便处理,我将 i - a[ i ] < 0 的部分设置为了无穷大,可以少一些特判
2020.9.8更新:
昨晚上 rsb 学长提供了一种非常优秀的思路,可以将时间复杂度降低到 nlogn ,具体就是说,对于每个 i - a[ i ] 来说,如果 i - a[ i ] < 0 显然是永远无法删除的,不会对答案造成贡献,忽略即可,而对于每个大于等于 0 的 i - a[ i ] 来说,前面必须要删除掉 i - a[ i ] 个数后才可以将当前位置的数删掉,所以我们不妨枚举右端点的 r - a[ r ],找到其可行的左区间,也就是其前面必须至少要有 r - a[ r ] 个数才行,找到这个位置记为 l[ r ] 作为其左端点可行的最大位置,其意义就是,如果右端点为 r 时,左端点位于 [ 1 , l[ r ] ] 时,那么 a[ r ] 这个数会对其贡献加一,位于区间 [ l[ r ] + 1 , n ] 的话就不做贡献
再说一下确定 r - a[ r ] 后如何找到 l[ r ] ,因为前面需要有 r - a[ r ] 个数,设当前有 cnt 个数,也就是上文中的 cnt = f( r - 1 ),这样我们只需要在 f( r - 1 ) 内找到区间中第 r - a[ r ] 大的数即可,注意这里要求的是第 r - a[ r ] 大而不是第 r - a[ r ] 小
同样用线段树维护一下就好
代码:
二分+线段树
#include<iostream>
#include<cstdio>
#include<string>
#include<ctime>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<stack>
#include<climits>
#include<queue>
#include<map>
#include<set>
#include<sstream>
#include<cassert>
#include<bitset>
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
const int inf=0x3f3f3f3f;
const int N=3e5+100;
int a[N],ans[N];
vector<pair<int,int>>q[N];
struct Node
{
int l,r,len;
LL sum,lazy;
}tree[N<<2];
void pushup(int k)
{
tree[k].sum=tree[k<<1].sum+tree[k<<1|1].sum;
}
void pushdown(int k)
{
if(tree[k].lazy)
{
LL lz=tree[k].lazy;
tree[k].lazy=0;
tree[k<<1].sum+=tree[k<<1].len*lz;
tree[k<<1|1].sum+=tree[k<<1|1].len*lz;
tree[k<<1].lazy+=lz;
tree[k<<1|1].lazy+=lz;
}
}
void build(int k,int l,int r)
{
tree[k].l=l;
tree[k].r=r;
tree[k].len=r-l+1;
tree[k].sum=tree[k].lazy=0;
if(l==r)
return;
int mid=l+r>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
pushup(k);
}
void update(int k,int l,int r)
{
if(l>r)
return;
if(tree[k].l>r||tree[k].r<l)
return;
if(tree[k].l>=l&&tree[k].r<=r)
{
tree[k].sum+=tree[k].len;
tree[k].lazy++;
return;
}
pushdown(k);
update(k<<1,l,r);
update(k<<1|1,l,r);
pushup(k);
}
LL query(int k,int pos)
{
if(tree[k].l==tree[k].r)
return tree[k].sum;
pushdown(k);
int mid=tree[k].l+tree[k].r>>1;
if(pos<=mid)
return query(k<<1,pos);
else
return query(k<<1|1,pos);
}
int get_pos(int i)
{
int l=1,r=i,ans=-1;
while(l<=r)
{
int mid=l+r>>1;
if(query(1,mid)>=a[i])
{
ans=mid;
l=mid+1;
}
else
r=mid-1;
}
return ans;
}
int main()
{
#ifndef ONLINE_JUDGE
// freopen("data.in.txt","r",stdin);
// freopen("data.out.txt","w",stdout);
#endif
// ios::sync_with_stdio(false);
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
scanf("%d",a+i);
a[i]=i-a[i];
if(a[i]<0)
a[i]=inf;
}
for(int i=1;i<=m;i++)
{
int l,r;
scanf("%d%d",&l,&r);
l++;
r=n-r;
q[r].emplace_back(l,i);
}
build(1,1,n);
for(int i=1;i<=n;i++)//枚举右端点然后迭代
{
int pos=get_pos(i);//找到最大的l满足s[pos]>=a[i]
update(1,1,pos);
for(int j=0;j<q[i].size();j++)
ans[q[i][j].second]=query(1,q[i][j].first);
}
for(int i=1;i<=m;i++)
printf("%d\n",ans[i]);
return 0;
}
线段树
#include<iostream>
#include<cstdio>
#include<string>
#include<ctime>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<stack>
#include<climits>
#include<queue>
#include<map>
#include<set>
#include<sstream>
#include<cassert>
#include<bitset>
using namespace std;
typedef long long LL;
typedef unsigned long long ull;
const int inf=0x3f3f3f3f;
const int N=3e5+100;
int a[N],ans[N];
vector<pair<int,int>>pos[N];
struct Node
{
int l,r;
int sum;
}tree[N<<2];
void build(int k,int l,int r)
{
tree[k].l=l;
tree[k].r=r;
tree[k].sum=0;
if(l==r)
return;
int mid=l+r>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
}
void update(int k,int pos)
{
if(tree[k].l==tree[k].r)
{
tree[k].sum++;
return;
}
int mid=tree[k].l+tree[k].r>>1;
if(pos<=mid)
update(k<<1,pos);
else
update(k<<1|1,pos);
tree[k].sum=tree[k<<1].sum+tree[k<<1|1].sum;
}
int query_k(int k,int x)//区间第x大
{
if(tree[k].l==tree[k].r)
return tree[k].l;
if(tree[k<<1|1].sum>=x)
return query_k(k<<1|1,x);
else
return query_k(k<<1,x-tree[k<<1|1].sum);
}
int query(int k,int l,int r)//区间和
{
if(tree[k].l>r||tree[k].r<l)
return 0;
if(tree[k].l>=l&&tree[k].r<=r)
return tree[k].sum;
return query(k<<1,l,r)+query(k<<1|1,l,r);
}
int main()
{
#ifndef ONLINE_JUDGE
// freopen("data.in.txt","r",stdin);
// freopen("data.out.txt","w",stdout);
#endif
// ios::sync_with_stdio(false);
int n,m;
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
scanf("%d",a+i);
a[i]=i-a[i];
if(a[i]<0)
a[i]=inf;
}
for(int i=1;i<=m;i++)
{
int l,r;
scanf("%d%d",&l,&r);
l++;
r=n-r;
pos[r].emplace_back(l,i);
}
build(1,1,n);
int num=0;
for(int i=1;i<=n;i++)
{
if(a[i]==0)
{
update(1,i);
num++;
}
else if(num>=a[i])
{
update(1,query_k(1,a[i]));
num++;
}
for(pair<int,int>t:pos[i])
ans[t.second]=query(1,t.first,n);
}
for(int i=1;i<=m;i++)
printf("%d\n",ans[i]);
return 0;
}