题意:给定n(n<=500000)个数,A[1],A[2],...,A[n]。求所有子区间的 (最大值*最小值*长度)之和,对 10^9 取余数。
思路一:暴力 时间复杂度O(n^3) 超时
枚举所有子区间,然后遍历一遍来求最大值和最小值,然后累加答案。
思路二:稍加优化 时间复杂度O(n^2) 超时
枚举R,然后对于每个L<=R,用Max[L]表示区间[L,R]的最大值,Min[L]表示区间[L..R]的最小值。
在R增加1之后,每个L<=旧R的Max[L]和Min[L]只需要用A[R]去更新就行了。
代码:
//Simple solve
int Min[maxn],Max[maxn];
int SimpleSolve(){
LL ANS=0;
for(int R=1;R <=n;++R){
Min[R]=Max[R]=A[R];
for(int L=1;L <= R;++L){
//更新Min[L]和Max[L]
Min[L]=min(Min[L],A[R]);
Max[L]=max(Max[L],A[R]);
//累加区间[L,R]的答案
ANS=(ANS+(LL)Min[L]*Max[L]%MOD*(R-L+1))%MOD;
}
}
return ANS;
}
思路三是思路二的线段树优化,所以在看思路三之前,请确保看懂了思路二。
思路二中,对于每个R,用A[R]更新了[1..R]的最大值和最小值,然后累加了[1..R]到R的答案。
如果更新和求和都使用线段树的话,时间复杂度就变成了O(n*log(n))。
首先,如何用线段树维护最大值。
对于每个R,要将Max[1..R]的数组中,所有小于A[R]的都更新成A[R]。
注意到Max[1],Max[2],...,Max[R]是单调非增的数列。
(Max[1]代表区间[1..R]的最大值,Max[2]代表区间[2..R]的最大值,明显Max[1]>=Max[2].)
所以,实际上,从某下标L开始,Max[L..R]都小于A[R],于是变成了线段树的区间修改:将[L..R]的数变成A[R]。
那么,怎么找到L值呢?在线段树的节点上多维护一个最大值的最大值,然后就可以判断了,具体看代码。
最小值同理。
然后来谈一谈线段树每个节点需要的变量,以下用m表示最小值,用M表示最大值,用L表示长度。
变量分为:标记量和统计量。
先来谈标记量,需要最小值标记,最大值标记,和长度标记,记做m,M,L。
毕竟是区间修改,所以需要三种标记。
统计量:
然后明显要有(最大值*最小值*长度)的和,记做 smML。
在修改了最大值或最小值或长度时,要能够直接更新smML这个变量,
所以需要记录(最大值*最小值)的和,(最大值*长度)的和,(最小值*长度)的和,分别记做smM,sML,smL。
然后,在修改了最大值或最小值或长度时,要能够直接更新smM,sML,smL这三个变量,
需要记录最小值之和,最大值之和,长度之和,分别记做sm,sM,sL。
于是,s开头的求和的统计量需要7个。
在增加一个线段树区间的长度时,要更新sL,需要sL加上该区间的长度,所以用变量n表示该线段树区间的长度。
为了在更新最大值的时候可以找到左边界,需要记录最大值的最大值,记为MM。
同理,记录最小值的最小值,记为mm。
于是,线段树的一个节点,需要3个标记量 和 10个统计量。
节点的定义见下面代码:
//Segment Tree Node
struct Node{
int m,M,L;//min max Len 3个标记量
int mm,MM,n;//min of min,max of max , number of intervals
int sm,sM,sL,smM,smL,sML,smML;//sum of products
Node(){m=M=L=0;}
Node operator+(const Node &B){//节点的统计量的更新
Node &A = *this,C;
C.mm = min(A.mm,B.mm);
C.MM = max(A.MM,B.MM);
C.n = A.n + B.n;
C.sm = (A.sm + B.sm ) % MOD;
C.sM = (A.sM + B.sM ) % MOD;
C.sL = (A.sL + B.sL ) % MOD;
C.smM = (A.smM + B.smM ) % MOD;
C.smL = (A.smL + B.smL ) % MOD;
C.sML = (A.sML + B.sML ) % MOD;
C.smML = (A.smML + B.smML) % MOD;
return C;
}
void SetMax(int Max){//将该区间的最大值改为Max,修改最大值标记,以及对应的统计量
M = MM = Max;
sM = (LL)n * M % MOD;// 最大值的和 = 数量 * 最大值
smM = (LL)sm * M % MOD;//(最小值*最大值)的和 = 最小值的和 * 最大值
sML = (LL)sL * M % MOD;//(最大值*长度)的和 = 长度的和 * 最大值
smML = (LL)smL * M % MOD;//(最大值*最小值*长度)的和 = (最小值*长度)的和 * 最大值
}
void SetMin(int Min){//将该区间的最小值改为Min,修改最小值标记,以及对应的统计量
m = mm = Min;
sm = (LL)n * m % MOD;
smM = (LL)sM * m % MOD;
smL = (LL)sL * m % MOD;
smML = (LL)sML * m % MOD;
}
void AddLen(LL k){//该区间的长度增加k
L += k;
sL = (sL + k*n )%MOD ;
smL = (smL + k*sm )%MOD ;
sML = (sML + k*sM )%MOD ;
smML = (smML + k*smM)%MOD;
}
void SetValue(LL V){//设置叶节点的值
sL=n=1;
smL=sML=sm=sM=mm=MM=V;
smML=smM=V*V%MOD;
}
};
还剩下最大值的更新的左端点怎么找的问题:
首先,要在[1..R]这个区间中,将所有的比A[R]小的值变成A[R]。
第一部分:常规的线段树区间判断,可以得到所有在[1..R]之内的区间。
第二部分:如果本区间的最大值小于等于V,直接将整个区间的最大值设置为V即可。
如果不是叶节点,那么进行递归调用,右区间一定要递归调用,左区间根据条件。
如果右区间的最大值小于等于V,那么左侧可能也有要更新的区间。所以要递归调用左侧。
具体见UpdateMax函数:
void UpdateMax(int X,int V,int l,int r,int rt){//[1,X]
int L = rt << 1 , R = rt << 1 | 1 , m = (l + r) >> 1;
if(r <= X){//第二部分,得到了[1..X]的区间之后
if(D[rt].MM <= V){//如果本区间的最大值小于等于V,直接将本区间的最大值设置为V
D[rt].SetMax(V);
return;
}
if(l==r) return;//如果是叶节点,直接返回
PushDown(rt);
UpdateMax(X,V,rs);//更新右侧
//如果右侧的最大值大于V,那么左侧不可能有需要更新的值,所以不需要递归左侧
//否则,需要递归左侧
if(D[R].MM <= V) UpdateMax(X,V,ls);
PushUp(rt);
return;
}
//第一部分:常规线段树区间判断,可以得到所有[1..X]之内的区间
PushDown(rt);
UpdateMax(X,V,ls);
if(X > m) UpdateMax(X,V,rs);
PushUp(rt);
}
最后总结一下:
空间的问题:50万数据量,需要的线段树元素个数是1048576个,要是按一般的做法,直接四倍的话,就超出空间范围了。
时间的问题:开始写的是,先用线段树搜索出更新最大值的左侧下标L,再区间修改[L,R],然后超时了。
后来改成了在修改的时候顺便寻找修改边界,就快了许多。
----------------------------------------------------------------------------- 分割线 ------------------------------------------------------------------
上面说的方法,用时9.4秒。经过三个优化之后,可以达到1.5秒。
优化一:可以发现,对于每个R,如果A[R]>A[R-1]那么只需要更新最大值数组,最小值都不需要更新。节省了一半的线段树操作。
优化二:对于L值,前面的做法是把它当做数据来维护,其实不需要。可以直接在Query函数中计算,省掉了更新长度的操作,以及长度的懒惰标记。
优化三:前面的做法是在更新最大值的时候,利用统计量MM来找到左边界。其实可以用单调栈直接维护左边界。省去了左边界判断时间,以及mm,MM两个变量。
第一份代码如下:
/*
Problem 1618 - Magic Array
56520KB 9420ms
*/
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#define LL long long
#define MOD 1000000000
#define maxn 500007
#define ls l,m,rt<<1
#define rs m+1,r,rt<<1|1
using namespace std;
//Input
int n,A[maxn];
//Segment Tree Node
struct Node{
int m,M,L;//min max Len
int mm,MM,n;//min of min,max of max , number of intervals
int sm,sM,sL,smM,smL,sML,smML;//sum of products
Node(){m=M=L=0;}
Node operator+(const Node &B){
Node &A = *this,C;
C.mm = min(A.mm,B.mm);
C.MM = max(A.MM,B.MM);
C.n = A.n + B.n;
C.sm = (A.sm + B.sm ) % MOD;
C.sM = (A.sM + B.sM ) % MOD;
C.sL = (A.sL + B.sL ) % MOD;
C.smM = (A.smM + B.smM ) % MOD;
C.smL = (A.smL + B.smL ) % MOD;
C.sML = (A.sML + B.sML ) % MOD;
C.smML = (A.smML + B.smML) % MOD;
return C;
}
void SetMax(int Max){
M = MM = Max;
sM = (LL)n * M % MOD;
smM = (LL)sm * M % MOD;
sML = (LL)sL * M % MOD;
smML = (LL)smL * M % MOD;
}
void SetMin(int Min){
m = mm = Min;
sm = (LL)n * m % MOD;
smM = (LL)sM * m % MOD;
smL = (LL)sL * m % MOD;
smML = (LL)sML * m % MOD;
}
void AddLen(LL k){
L += k;
sL = (sL + k*n )%MOD ;
smL = (smL + k*sm )%MOD ;
sML = (sML + k*sM )%MOD ;
smML = (smML + k*smM)%MOD;
}
void SetValue(LL V){
sL=n=1;
smL=sML=sm=sM=mm=MM=V;
smML=smM=V*V%MOD;
}
}D[1048576];
void PushUp(int rt){D[rt] = D[rt<<1] + D[rt<<1|1];}
void PushDown(int rt){//Push down three marks
int L = rt << 1 , R = rt << 1 | 1;
if(D[rt].M){
D[L].SetMax(D[rt].M);
D[R].SetMax(D[rt].M);
D[rt].M=0;
}
if(D[rt].m){
D[L].SetMin(D[rt].m);
D[R].SetMin(D[rt].m);
D[rt].m=0;
}
if(D[rt].L){
D[L].AddLen(D[rt].L);
D[R].AddLen(D[rt].L);
D[rt].L=0;
}
}
void Build(int l,int r,int rt){
if(l==r){
D[rt].SetValue(A[l]);
return;
}
int m=(l+r)>>1;
Build(ls);
Build(rs);
PushUp(rt);
}
void UpdateMax(int X,int V,int l,int r,int rt){//[1,X]
int L = rt << 1 , R = rt << 1 | 1 , m = (l + r) >> 1;
if(r <= X){
if(D[rt].MM <= V){
D[rt].SetMax(V);
return;
}
if(l==r) return;
PushDown(rt);
UpdateMax(X,V,rs);
if(D[R].MM <= V) UpdateMax(X,V,ls);
PushUp(rt);
return;
}
PushDown(rt);
UpdateMax(X,V,ls);
if(X > m) UpdateMax(X,V,rs);
PushUp(rt);
}
void UpdateMin(int X,int V,int l,int r,int rt){//[1,X]
int L = rt << 1 , R = rt << 1 | 1 , m = (l + r) >> 1;
if(r <= X){
if(D[rt].mm >= V){
D[rt].SetMin(V);
return;
}
if(l==r) return;
PushDown(rt);
UpdateMin(X,V,rs);
if(D[R].mm >= V) UpdateMin(X,V,ls);
PushUp(rt);
return;
}
PushDown(rt);
UpdateMin(X,V,ls);
if(X > m) UpdateMin(X,V,rs);
PushUp(rt);
}
void UpdateLen(int X,int l,int r,int rt){//[1,X]
if(r <= X){
D[rt].AddLen(1);
return;
}
PushDown(rt);
int m = (l + r) >> 1;
UpdateLen(X,ls);
if(X > m) UpdateLen(X,rs);
PushUp(rt);
}
LL Query(int X,int l,int r,int rt){//求和
if(r <= X){
return D[rt].smML;
}
PushDown(rt);
int m = (l + r) >> 1;
LL ANS = Query(X,ls);
if(X > m) ANS = (ANS + Query(X,rs)) % MOD;
return ANS;
}
int main(void)
{
while(~scanf("%d",&n)){
for(int i=1;i<=n;++i) scanf("%d",&A[i]);
Build(1,n,1);
LL ANS = Query(1,1,n,1);
for(int R=2;R <= n;++R){
UpdateMax(R,A[R],1,n,1);//更新最大值
UpdateMin(R,A[R],1,n,1);//更新最小值
UpdateLen(R-1,1,n,1);//更新长度
ANS = (ANS + Query(R,1,n,1)) % MOD;//累加答案
}
printf("%d\n",(int)ANS);
}
return 0;
}
优化后代码如下:
/*
Problem 1618 - Magic Array
Memory: 44260KB Time: 1500ms
*/
#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#define LL long long
#define MOD 1000000000
#define maxn 500007
#define ls l,m,rt<<1
#define rs m+1,r,rt<<1|1
using namespace std;
//Input
int n,A[maxn];
int Min[maxn],IMin;
int Max[maxn],IMax;
//Segment Tree Node
struct Node{
int m,M;//min max
int n;//number of intervals
int sm,sM,sL,smM,smL,sML,smML;//sum of products
Node(){m=M=0;}
Node operator+(const Node &B)const{
const Node &A = *this;
Node C;
C.n = A.n + B.n;
C.sm = (A.sm + B.sm ) % MOD;
C.sM = (A.sM + B.sM ) % MOD;
C.smM = (A.smM + B.smM ) % MOD;
C.sL = (A.sL + (LL)A.n * B.n + B.sL ) % MOD;
C.smL = (A.smL + (LL)A.sm * B.n + B.smL ) % MOD;
C.sML = (A.sML + (LL)A.sM * B.n + B.sML ) % MOD;
C.smML = (A.smML + (LL)A.smM * B.n + B.smML) % MOD;
return C;
}
void SetMax(int Max){
M = Max;
sM = (LL)n * M % MOD;
smM = (LL)sm * M % MOD;
sML = (LL)sL * M % MOD;
smML = (LL)smL * M % MOD;
}
void SetMin(int Min){
m = Min;
sm = (LL)n * m % MOD;
smM = (LL)sM * m % MOD;
smL = (LL)sL * m % MOD;
smML = (LL)sML * m % MOD;
}
void SetValue(LL V){
sL=n=1;
smL=sML=sm=sM=V;
smML=smM=V*V%MOD;
}
}D[1048576];
void PushUp(int rt){D[rt] = D[rt<<1] + D[rt<<1|1];}
void PushDown(int rt){//Push down three marks
int L = rt << 1 , R = rt << 1 | 1;
if(D[rt].M){
D[L].SetMax(D[rt].M);
D[R].SetMax(D[rt].M);
D[rt].M=0;
}
if(D[rt].m){
D[L].SetMin(D[rt].m);
D[R].SetMin(D[rt].m);
D[rt].m=0;
}
}
void Build(int l,int r,int rt){
if(l==r){
D[rt].SetValue(A[l]);
return;
}
int m=(l+r)>>1;
Build(ls);
Build(rs);
PushUp(rt);
}
void UpdateMax(int L,int R,int V,int l,int r,int rt){
if(L <= l && r <= R){
D[rt].SetMax(V);
return;
}
PushDown(rt);
int m = (l + r) >> 1;
if(L <= m) UpdateMax(L,R,V,ls);
if(R > m) UpdateMax(L,R,V,rs);
PushUp(rt);
}
void UpdateMin(int L,int R,int V,int l,int r,int rt){
if(L <= l && r <= R){
D[rt].SetMin(V);
return;
}
PushDown(rt);
int m = (l + r) >> 1;
if(L <= m) UpdateMin(L,R,V,ls);
if(R > m) UpdateMin(L,R,V,rs);
PushUp(rt);
}
Node Query(int X,int l,int r,int rt){
if(r <= X){
return D[rt];
}
PushDown(rt);
int m = (l + r) >> 1;
Node ANS = Query(X,ls);
if(X > m) ANS = ANS + Query(X,rs);
return ANS;
}
int main(void)
{
while(~scanf("%d",&n)){
for(int i=1;i<=n;++i) scanf("%d",&A[i]);
Build(1,n,1);
Min[0]=Max[0]=IMin=IMax=0;
LL ANS = 0;
for(int R=1;R <= n;++R){
while(IMin && A[R]<=A[Min[IMin]]) --IMin;
Min[++IMin]=R;
while(IMax && A[R]>=A[Max[IMax]]) --IMax;
Max[++IMax]=R;
if(A[R]>A[R-1]) UpdateMax(Max[IMax-1]+1,R,A[R],1,n,1);
else UpdateMin(Min[IMin-1]+1,R,A[R],1,n,1);
ANS = (ANS + Query(R,1,n,1).smML) % MOD;
}
printf("%d\n",(int)ANS);
}
return 0;
}