题目
给出N个数,在里面选出不超过K段连续的子序列,使其两两不相交,求总和的最大值(可以一段都不选)
数据范围
N,K<= 100000
对于一个数a满足 -1000000000 <= a <= 100000000
题解
首先看到这道题很容易想到是dp,然后再加上一个优化
可是这里的N,K太大O(NK)是会超时的
所以换方法,然后用了一种不知道为什么的算法:
用线段树维护,然后每次选出最大的一段子序列,加入答案中,再将这段区间选的数都乘上-1,一直重复K次或者是选到了最大的子序列是负数时,就退出
好像是用到了费用流的思想,然而我并想不出来
所以只会简单的证明一下
证明:每次取得的区间只有两种可能:
1.被之前选的区间之内
2.选择一个还未选择过的区间(这里指的是从左端点到右端点一直没有被选过)
证明
首先一段区间两端一定都是正数(对于当前来说)
对于当前情况
如果有一个区间i与之前选过的区间j相交,那么有一个端点是在之前选过的区间之内,且这个端点的值一定是正数
也不难发现,与之前选过的区间相交的这一段区间的和一定是大于0的,那就是说之前这个端点的值是负数,且相交区间和的应该是负数,那么在第一次选这个j的时候时就可以选择不要相交的区间,会使和最大,反证法可得命题不成立所以不会相交
同时i包含j也可以用反证法证明
代码
#include <iostream>//对于这道题,首先是要维护区间最大值最小值,同时也要维护从最左端开始的最大
#include <cstdio>//与最小区间,右端点也要这样,才可以进行转移
#include <cstring>
#include <cmath>
#include <vector>
using namespace std;
#define ll long long
const int MAXN = 1e5 + 3;
struct edge{
int l ,r;
ll sum;
edge(){}
edge( int L , int R , ll S ){
l = L;r = R;sum = S;
}
};
struct node{
int l , r;
edge zsum , lsum ,rsum;
edge zmin , lmin , rmin;
int lazy;
}tre[MAXN*4];
int n , K;
ll a[MAXN];
edge min_( edge a , edge b ){
edge ans;
if( a.sum >= b.sum )
ans = a;
else
ans = b;
return ans;
}
void push_( int i ){
tre[i].zsum = min_( tre[i*2].zsum , tre[i*2+1].zsum );
if( tre[i*2].rsum.sum + tre[i*2+1].lsum.sum > tre[i].zsum.sum ){
tre[i].zsum.sum= tre[i*2].rsum.sum + tre[i*2+1].lsum.sum;
tre[i].zsum.l = tre[i*2].rsum.l , tre[i].zsum.r = tre[i*2+1].lsum.r;
}
tre[i].lsum = tre[i*2].lsum;
if( tre[i*2].lsum.r - tre[i*2].lsum.l + 1 == tre[i*2].r - tre[i*2].l + 1 ){
if( tre[i].lsum.sum < tre[i*2].lsum.sum + tre[i*2+1].lsum.sum )
tre[i].lsum.sum = tre[i*2].lsum.sum + tre[i*2+1].lsum.sum , tre[i].lsum.l = tre[i].l , tre[i].lsum.r = tre[i*2+1].lsum.r;
}
tre[i].rsum = tre[i*2+1].rsum;
if( tre[i*2+1].rsum.r - tre[i*2+1].rsum.l + 1 == tre[i*2+1].r - tre[i*2+1].l + 1 ){
if( tre[i].rsum.sum < tre[i*2].rsum.sum + tre[i*2+1].rsum.sum )
tre[i].rsum.sum = tre[i*2].rsum.sum + tre[i*2+1].rsum.sum , tre[i].rsum.l = tre[i*2].rsum.l , tre[i].rsum.r =tre[i].r ;
}//
if( tre[i*2].zmin.sum <= tre[i*2+1].zmin.sum )
tre[i].zmin = tre[i*2].zmin;
else
tre[i].zmin = tre[i*2+1].zmin;
if( tre[i*2].rmin.sum + tre[i*2+1].lmin.sum < tre[i].zmin.sum ){
tre[i].zmin.sum = tre[i*2].rmin.sum + tre[i*2+1].lmin.sum;
tre[i].zmin.l = tre[i*2].rmin.l , tre[i].zmin.r = tre[i*2+1].lmin.r;
}
tre[i].lmin = tre[i*2].lmin;
if( tre[i*2].lmin.r - tre[i*2].lmin.l + 1 == tre[i*2].r - tre[i*2].l + 1 ){
if( tre[i].lmin.sum > tre[i*2].lmin.sum + tre[i*2+1].lmin.sum )
tre[i].lmin.sum = tre[i*2].lmin.sum + tre[i*2+1].lmin.sum , tre[i].lmin.l = tre[i].l , tre[i].lmin.r = tre[i*2+1].lmin.r;
}
tre[i].rmin = tre[i*2+1].rmin;
if( tre[i*2+1].rmin.r - tre[i*2+1].rmin.l + 1 == tre[i*2+1].r - tre[i*2+1].l + 1 ){
if( tre[i].rmin.sum > tre[i*2].rmin.sum + tre[i*2+1].rmin.sum )
tre[i].rmin.sum = tre[i*2].rmin.sum + tre[i*2+1].rmin.sum , tre[i].rmin.l = tre[i*2].rmin.l , tre[i].rmin.r =tre[i].r ;
}
}
void build( int i , int l , int r ){
tre[i].l = l , tre[i].r = r;tre[i].lazy = 1;
if( l == r ){
tre[i].zsum =edge( l , r , a[l] );tre[i].lsum =edge( l , r , a[l] );tre[i].rsum =edge( l , r , a[l] );
tre[i].zmin =edge( l , r , a[l] );tre[i].lmin =edge( l , r , a[l] );tre[i].rmin =edge( l , r , a[l] );
return ;
}
int mid = ( l + r ) / 2;
build( i * 2 , l , mid );
build( i * 2 + 1 , mid + 1, r );
push_( i );
}
void pushdown( int i ){
if( tre[i].lazy == -1 ){
tre[i*2].zsum.sum *= -1;tre[i*2].zmin.sum *= -1;
tre[i*2].lsum.sum *= -1;tre[i*2].lmin.sum *= -1;
tre[i*2].rsum.sum *= -1;tre[i*2].rmin.sum *= -1;
swap( tre[i*2].zsum , tre[i*2].zmin );
swap( tre[i*2].lsum , tre[i*2].lmin );
swap( tre[i*2].rsum , tre[i*2].rmin );
tre[i*2+1].zsum.sum *= -1;tre[i*2+1].zmin.sum *= -1;
tre[i*2+1].lsum.sum *= -1;tre[i*2+1].lmin.sum *= -1;
tre[i*2+1].rsum.sum *= -1;tre[i*2+1].rmin.sum *= -1;
swap( tre[i*2+1].zsum , tre[i*2+1].zmin );
swap( tre[i*2+1].lsum , tre[i*2+1].lmin );
swap( tre[i*2+1].rsum , tre[i*2+1].rmin );
tre[i].lazy = 1;
tre[i*2].lazy *= -1;tre[i*2+1].lazy *= -1;
}
}
void change( int i , int l , int r ){
if( tre[i].l > r || tre[i].r < l )return ;
if( l <= tre[i].l && r >= tre[i].r ){
tre[i].lazy *= -1;
tre[i].zsum.sum *= -1;tre[i].zmin.sum *= -1;
tre[i].lsum.sum *= -1;tre[i].lmin.sum *= -1;
tre[i].rsum.sum *= -1;tre[i].rmin.sum *= -1;
swap( tre[i].zsum , tre[i].zmin );
swap( tre[i].lsum , tre[i].lmin );
swap( tre[i].rsum , tre[i].rmin );
return ;
}
pushdown( i );
change( i * 2 , l , r );
change( i * 2 + 1 , l , r );
push_( i );
}
void read( ll &x ){
char s = getchar();int f=1;x= 0 ;
while( s < '0' || s > '9' ){
if( s == '-' )
f = -1;
s = getchar();
}
while( s >= '0' && s <= '9' ){
x = x * 10 + s - '0';
s = getchar();
}
x *= f;
}
int main()
{
freopen( "maxksum.in" , "r" , stdin );
freopen( "maxksum.out" , "w" , stdout );
scanf( "%d%d" , &n , &K );
for( int i = 1 ; i <= n ; i ++ ){
read( a[i] );
}
build( 1 , 1 , n );
ll ans = 0;
while( K -- ){
ll tot = tre[1].zsum.sum;
if( tot <= 0 ) break;
ans += tot;
change( 1 , tre[1].zsum.l , tre[1].zsum.r );
}
printf( "%lld" , ans );
return 0;
}