裁剪序列(单调队列优化dp)

原题链接:
裁剪序列

dp分析

dp分析
这道题数据范围是1e5

朴素版代码,基于最基本的理解
(但只能过6个数据, n三方的复杂度)

#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>

using namespace std;
typedef long long ll;

const int N = 1e5 +10;

int n;
int a[N];
ll m;
ll f[N], sum[N];

int main(){
    
    
	cin >> n >> m;
	for(int i = 1; i <= n; i ++){
    
    
	    scanf("%d", &a[i]);
	    sum[i] = sum[i-1] + a[i];
	}
	
	memset(f, 0x3f, sizeof(f));
	f[0] = 0;
	
	for(int i = 1; i <= n; i ++){
    
    
	    for(int j = 1; j <= i; j ++)
	    {
    
    
	        if(sum[i] - sum[j - 1] > m) continue;
	        int maxx = 0;
	        //每次循环都需要查找一个区间最大值
	        //这里可以用ST表进行O(1)优化
	        for(int k = j; k <= i; k ++)
	        {
    
    
	            maxx = max(maxx, a[k]);
	        }
	        f[i] = min(f[i], f[j-1] + maxx);
	    }
	}
	
	if(f[n] == 0) puts("-1");
	else
	printf("%lld\n", f[n]);
	
	return 0;
} 

加入了快读 和 ST表查区间最大值可以过11个样例,但时间复杂度仍是n方还是AC不了

#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>
#include<cmath>

using namespace std;
typedef long long ll;

const int N = 1e5 +10;

int n;
int a[N], p[N][30];
ll m;
ll f[N], sum[N];

inline int read()
{
    
    
    int x = 0, f = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
    
    
        if(ch == '-')
        {
    
    
            f = -1;
        }
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
    
    
        x = x*10 + ch - 48;
        ch = getchar();
    }
    return x * f;
}

int main(){
    
    
	cin >> n >> m;
	for(int i = 1; i <= n; i ++){
    
    
	    a[i] = read();
	    p[i][0] = a[i];
	    sum[i] = sum[i-1] + a[i];
	}
	
	memset(f, 0x3f, sizeof(f));
	f[0] = 0;
	
	//实现ST表
	for(int j = 1; j <= log2(n); j ++)
	{
    
    
	    for(int i = 1; i + (1 << j) - 1 <= n; i++)
	    {
    
    
	        p[i][j] = max(p[i][j-1], p[i+(1 << (j-1))][j-1]);
	    }
	}
	
	for(int i = 1; i <= n; i ++){
    
    
	    for(int j = 1; j <= i; j ++)
	    {
    
    
	        if(sum[i] - sum[j - 1] > m) continue;
	        int len = i - j + 1;
	        int lo = log2(len);
	        int maxx = max(p[j][lo], p[i - (1 << lo) + 1][lo]);
	        
	        f[i] = min(f[i], f[j-1] + maxx);
	    }
	}
	
	if(f[n] == 0) puts("-1");
	else
	printf("%lld\n", f[n]);
	
	return 0;
} 
  • 接下来就是最难的地方了,做题需要具体问题具体分析,分析出来这个题独有的性质,然后再转变思路,因为上面的思路极限就是n^2了,我们还需要寻找新的优化!

单调队列维护可能更新f[i]的所有可能
时间复杂度最坏O(n^2)即原数列是单调递减的并且和不超过m

#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>
#include<cmath>

using namespace std;
typedef long long ll;

const int N = 1e5 +10;

int n;
int a[N], q[N];
ll m;
ll f[N];

inline int read()
{
    
    
    int x = 0, f = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
    
    
        if(ch == '-')
        {
    
    
            f = -1;
        }
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
    
    
        x = x*10 + ch - 48;
        ch = getchar();
    }
    return x * f;
}

int main(){
    
    
	cin >> n >> m;
	for(int i = 1; i <= n; i ++){
    
    
	    a[i] = read();
	}
	
	memset(f, 0x3f, sizeof(f));
	f[0] = 0;
	
	//单调队列实现
	int hh = 0, tt = -1;
	ll cnt = 0;
	//因为a数组下标从1开始,所以先++t;
	q[++tt] = 0;
	for(int i = 1, j = 1; i <= n; i ++)
	{
    
    
		//i是前i个数, j是选择的要更新f[i]的f[j],
		//f[i]=max(f[i],f[j]+max(a[j+1...i])
		//  hh  tt  //代表单调队列的下标,他的值是a的下标
		cnt += a[i];//维护区间和小于m的(j...i)
		while(cnt > m) {
    
    
			cnt -= a[j];
			j ++;
		}
		//j在cnt<=m的最左边
		
		//这里是把q[hh]中,也就是对头超出范围的去掉
		//保留在j到i之间的q[hh]
		while(hh <= tt && q[hh] < j)
		{
    
    
			hh ++;
		}
		
		//维护单调队列中的不减序列,将a[i]加到其中
		while(hh <= tt && a[q[tt]] <= a[i]) tt--;
		q[++tt] = i;
		
		f[i] = f[j-1] + max(a[q[hh]], a[i]);
		//... 2 ... (j)1 ..."8" .1.3. "7" ..1. "6" ..4. "5" ... (i)2
		//就是q始终维护最后区间和小于m的区间上的不增序列,但无法决断出哪个
		//f[j]+max(a[j...i])是最小的,只能说最小的都在维护的序列中
		//因为f[7所在下标]+6肯定小于f[3所在下标]+7,但无法知道哪个最小,所以还需要枚举
		
		for(int k = hh; k <= tt; k ++)
		{
    
    
			f[i] = min(f[i], f[q[k]] + max(a[q[k+1]], a[i]));
		}
		
	}

	if(f[n] == 0) puts("-1");
	else
	printf("%lld\n", f[n]);
	
	return 0;
} 

最终优化:
我们发现,我们只需要对序列的插入,特定值删除,还有求最小值,所以平衡树可以完成这些操作,并且时间是log(n)
总的时间就是nlog(n) 就可以ac了
引入平衡树set
但这个序列可能会有重复元素,所以使用multiset

#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>
#include<cmath>
#include<set>

using namespace std;
typedef long long ll;

const int N = 1e5 +10;

int n;
int a[N], q[N];
ll m;
ll f[N];
multiset<ll> se;

inline ll get(int pos)
{
    
    
    return f[q[pos]] + a[q[pos + 1]];
}

inline void del(int pos)
{
    
    
    auto t = se.find(get(pos));
    se.erase(t);
}

inline int read()
{
    
    
    int x = 0, f = 1;
    char ch = getchar();
    while(ch < '0' || ch > '9')
    {
    
    
        if(ch == '-')
        {
    
    
            f = -1;
        }
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9')
    {
    
    
        x = x*10 + ch - 48;
        ch = getchar();
    }
    return x * f;
}

int main(){
    
    
	cin >> n >> m;
	for(int i = 1; i <= n; i ++){
    
    
	    a[i] = read();
	}
	
	memset(f, 0x3f, sizeof(f));
	f[0] = 0;
	
	int hh = 0, tt = -1;
	ll cnt = 0;
	
	q[++tt] = 0;
	for(int i = 1, j = 1; i <= n; i ++)
	{
    
    
		cnt += a[i];
		while(cnt > m) {
    
    
			cnt -= a[j];
			j ++;
		}
		
		while(hh <= tt && q[hh] < j)
		{
    
    
		    if(tt > hh)
		    {
    
    
		        del(hh);
		    }
			hh ++;
		}
		
		
		while(hh <= tt && a[q[tt]] <= a[i]) {
    
    
		    if(tt > hh)
		    {
    
    
		        del(tt - 1);
		    }
		    tt--;
		}
		q[++tt] = i;
		
		if(tt > hh)
		{
    
    
		    se.insert(get(tt-1));
		}
		f[i] = f[j-1] + a[q[hh]];
		if(se.size())
		f[i] = min(f[i], *se.begin());
		
	}

	if(f[n] == 0) puts("-1");
	else
	printf("%lld\n", f[n]);
	
	return 0;
} 

猜你喜欢

转载自blog.csdn.net/qq_63092029/article/details/129578153