题意:给定$t_{1\cdots n}$,你要确定一个$01$数组$c_{1\cdots n}$使得$\left(\sum\limits_{1\leq i\leq j\leq n}\prod\limits_{k=i}^jc_k\right)-\sum\limits_{i=1}^nc_it_i$最大,现在有$m$次询问,形如“假如把$t_p$改为$x$,答案是多少?”
相当于选择一些互不相邻的区间,每个长度为$k$的区间有$\frac{k(k+1)}2$的贡献,一个位置$i$被选中则要付出$t_i$的代价
首先看没有询问怎么做,设$f_i$表示前$i$位的答案,$s$为$t$的前缀和,则$f_i=\max\left(f_{i-1},\max\limits_{0\leq j\lt i}f_j+\frac{(i-j)(i-j+1)}2-(s_i-s_j)\right)$
把后面的$\max$里的式子拆开:$\left(-ji+f_j+\frac{j^2-j}2+s_j\right)+\frac{i^2+i}2-s_i$,如果固定$j$,那么括号里是关于$i$的一次函数,整个DP的过程相当于不断地加直线和询问所有直线在$x=i$处的最大值,因为加入的直线斜率越来越小,询问的$x$越来越大,所以用栈维护直线形成的下凸壳即可
现在考虑有询问怎么做,如果一个方案选择的区间不包含$p$,那么问题可以拆成在$[1,p-1],[p+1,n]$中的子问题,把上面的DP正反各做一次得到正的$f_i$和反的$g_i$,这一部分的答案就是$f_{p-1}+g_{p+1}$
如果一个方案选择的区间包含$p$,我们想要求出$h_i$表示选择的某个区间包含$i$的最大答案
最朴素的想法是枚举每一个区间$[l,r]$,用$f_{l-1}+g_{r+1}+\frac{(r-l+1)(r-l+2)}2-\sum\limits_{i=l}^rt_i$更新$h_{l\cdots r}$,这样做太慢,考虑分治
假设当前分治区间为$[l,r]$,我们想要算出所有跨越$[mid,mid+1]$的区间对答案的贡献
枚举右端点$i\in[mid+1,r]$并用$g_{i+1}+\max\limits_{l-1\leq j\lt mid}f_j+\frac{(i-j)(i-j+1)}2-(s_i-s_j)$更新$h_{mid+1\cdots i}$,枚举左端点的左边$i\in[l-1,mid-1]$并用$f_i+\max\limits_{mid\lt j\leq r}g_{j+1}+\frac{(j-i)(j-i+1)}2-(s_j-s_i)$更新$h_{i+1\cdots mid}$,这两个DP可以用和前面一样的方法来完成(先加完直线再单调地查询最大值)
预处理完$f,g,h$后,答案就是$\max(f_{p-1}+g_{p+1},h_p+t_p-x)$
总时间复杂度$O(n\log n)$
#include<stdio.h> #include<algorithm> using namespace std; typedef long long ll; typedef double du; const ll inf=9223372036854775807ll; struct line{ ll k,b;//y=kx+b line(ll k=0,ll b=0):k(k),b(b){} ll get(ll x){return k*x+b;} }st[300010],u; int tp; du its(line a,line b){ return(b.b-a.b)/du(a.k-b.k); } int t[300010],n; ll s[300010],f[300010],g[300010]; void dp(ll*f){ int i; for(i=1;i<=n;i++)s[i]=s[i-1]+t[i]; tp=1; st[1]=line(0,0); for(i=1;i<=n;i++){ while(tp>1&&st[tp].get(i)<=st[tp-1].get(i))tp--; f[i]=max(st[tp].get(i)+((ll)i*i+i)/2-s[i],f[i-1]); u=line(-i,f[i]+((ll)i*i-i)/2+s[i]); while(tp>1&&its(u,st[tp-1])>=its(st[tp],st[tp-1]))tp--; st[++tp]=u; } } ll h[300010]; void solve(int l,int r){ if(l==r){ h[l]=1ll-t[l]; return; } int mid,i,now; ll mx; mid=(l+r)>>1; solve(l,mid); solve(mid+1,r); tp=0; for(i=l-1;i<mid;i++){ u=line(-i,f[i]+((ll)i*i-i)/2+s[i]); while(tp>1&&its(u,st[tp-1])>=its(st[tp],st[tp-1]))tp--; st[++tp]=u; } now=1; mx=-inf; for(i=r;i>mid;i--){ while(now<tp&&its(st[now],st[now+1])>=i)now++; mx=max(mx,st[now].get(i)+((ll)i*i+i)/2-s[i]+g[i+1]); h[i]=max(h[i],mx); } tp=0; for(i=mid+1;i<=r;i++){ u=line(-i,g[i+1]+((ll)i*i+i)/2-s[i]); while(tp>1&&its(u,st[tp-1])>=its(st[tp],st[tp-1]))tp--; st[++tp]=u; } now=tp; mx=-inf; for(i=l-1;i<mid;i++){ while(now>1&&its(st[now],st[now-1])<=i)now--; mx=max(mx,st[now].get(i)+((ll)i*i-i)/2+s[i]+f[i]); h[i+1]=max(h[i+1],mx); } } int main(){ int i,m,x,y; scanf("%d",&n); for(i=1;i<=n;i++)scanf("%d",t+i); dp(f); reverse(t+1,t+n+1); dp(g); reverse(t+1,t+n+1); reverse(g+1,g+n+1); for(i=1;i<=n;i++)s[i]=s[i-1]+t[i]; solve(1,n); scanf("%d",&m); while(m--){ scanf("%d%d",&x,&y); printf("%lld\n",max(f[x-1]+g[x+1],h[x]+t[x]-y)); } }