原题链接:
裁剪序列
dp分析
这道题数据范围是1e5
朴素版代码,基于最基本的理解
(但只能过6个数据, n三方的复杂度)
#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>
using namespace std;
typedef long long ll;
const int N = 1e5 +10;
int n;
int a[N];
ll m;
ll f[N], sum[N];
int main(){
cin >> n >> m;
for(int i = 1; i <= n; i ++){
scanf("%d", &a[i]);
sum[i] = sum[i-1] + a[i];
}
memset(f, 0x3f, sizeof(f));
f[0] = 0;
for(int i = 1; i <= n; i ++){
for(int j = 1; j <= i; j ++)
{
if(sum[i] - sum[j - 1] > m) continue;
int maxx = 0;
//每次循环都需要查找一个区间最大值
//这里可以用ST表进行O(1)优化
for(int k = j; k <= i; k ++)
{
maxx = max(maxx, a[k]);
}
f[i] = min(f[i], f[j-1] + maxx);
}
}
if(f[n] == 0) puts("-1");
else
printf("%lld\n", f[n]);
return 0;
}
加入了快读 和 ST表查区间最大值可以过11个样例,但时间复杂度仍是n方还是AC不了
#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>
#include<cmath>
using namespace std;
typedef long long ll;
const int N = 1e5 +10;
int n;
int a[N], p[N][30];
ll m;
ll f[N], sum[N];
inline int read()
{
int x = 0, f = 1;
char ch = getchar();
while(ch < '0' || ch > '9')
{
if(ch == '-')
{
f = -1;
}
ch = getchar();
}
while(ch >= '0' && ch <= '9')
{
x = x*10 + ch - 48;
ch = getchar();
}
return x * f;
}
int main(){
cin >> n >> m;
for(int i = 1; i <= n; i ++){
a[i] = read();
p[i][0] = a[i];
sum[i] = sum[i-1] + a[i];
}
memset(f, 0x3f, sizeof(f));
f[0] = 0;
//实现ST表
for(int j = 1; j <= log2(n); j ++)
{
for(int i = 1; i + (1 << j) - 1 <= n; i++)
{
p[i][j] = max(p[i][j-1], p[i+(1 << (j-1))][j-1]);
}
}
for(int i = 1; i <= n; i ++){
for(int j = 1; j <= i; j ++)
{
if(sum[i] - sum[j - 1] > m) continue;
int len = i - j + 1;
int lo = log2(len);
int maxx = max(p[j][lo], p[i - (1 << lo) + 1][lo]);
f[i] = min(f[i], f[j-1] + maxx);
}
}
if(f[n] == 0) puts("-1");
else
printf("%lld\n", f[n]);
return 0;
}
- 接下来就是最难的地方了,做题需要具体问题具体分析,分析出来这个题独有的性质,然后再转变思路,因为上面的思路极限就是n^2了,我们还需要寻找新的优化!
单调队列维护可能更新f[i]的所有可能
时间复杂度最坏O(n^2)即原数列是单调递减的并且和不超过m
#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>
#include<cmath>
using namespace std;
typedef long long ll;
const int N = 1e5 +10;
int n;
int a[N], q[N];
ll m;
ll f[N];
inline int read()
{
int x = 0, f = 1;
char ch = getchar();
while(ch < '0' || ch > '9')
{
if(ch == '-')
{
f = -1;
}
ch = getchar();
}
while(ch >= '0' && ch <= '9')
{
x = x*10 + ch - 48;
ch = getchar();
}
return x * f;
}
int main(){
cin >> n >> m;
for(int i = 1; i <= n; i ++){
a[i] = read();
}
memset(f, 0x3f, sizeof(f));
f[0] = 0;
//单调队列实现
int hh = 0, tt = -1;
ll cnt = 0;
//因为a数组下标从1开始,所以先++t;
q[++tt] = 0;
for(int i = 1, j = 1; i <= n; i ++)
{
//i是前i个数, j是选择的要更新f[i]的f[j],
//f[i]=max(f[i],f[j]+max(a[j+1...i])
// hh tt //代表单调队列的下标,他的值是a的下标
cnt += a[i];//维护区间和小于m的(j...i)
while(cnt > m) {
cnt -= a[j];
j ++;
}
//j在cnt<=m的最左边
//这里是把q[hh]中,也就是对头超出范围的去掉
//保留在j到i之间的q[hh]
while(hh <= tt && q[hh] < j)
{
hh ++;
}
//维护单调队列中的不减序列,将a[i]加到其中
while(hh <= tt && a[q[tt]] <= a[i]) tt--;
q[++tt] = i;
f[i] = f[j-1] + max(a[q[hh]], a[i]);
//... 2 ... (j)1 ..."8" .1.3. "7" ..1. "6" ..4. "5" ... (i)2
//就是q始终维护最后区间和小于m的区间上的不增序列,但无法决断出哪个
//f[j]+max(a[j...i])是最小的,只能说最小的都在维护的序列中
//因为f[7所在下标]+6肯定小于f[3所在下标]+7,但无法知道哪个最小,所以还需要枚举
for(int k = hh; k <= tt; k ++)
{
f[i] = min(f[i], f[q[k]] + max(a[q[k+1]], a[i]));
}
}
if(f[n] == 0) puts("-1");
else
printf("%lld\n", f[n]);
return 0;
}
最终优化:
我们发现,我们只需要对序列的插入,特定值删除,还有求最小值,所以平衡树可以完成这些操作,并且时间是log(n)
总的时间就是nlog(n) 就可以ac了
引入平衡树set
但这个序列可能会有重复元素,所以使用multiset
#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>
#include<cmath>
#include<set>
using namespace std;
typedef long long ll;
const int N = 1e5 +10;
int n;
int a[N], q[N];
ll m;
ll f[N];
multiset<ll> se;
inline ll get(int pos)
{
return f[q[pos]] + a[q[pos + 1]];
}
inline void del(int pos)
{
auto t = se.find(get(pos));
se.erase(t);
}
inline int read()
{
int x = 0, f = 1;
char ch = getchar();
while(ch < '0' || ch > '9')
{
if(ch == '-')
{
f = -1;
}
ch = getchar();
}
while(ch >= '0' && ch <= '9')
{
x = x*10 + ch - 48;
ch = getchar();
}
return x * f;
}
int main(){
cin >> n >> m;
for(int i = 1; i <= n; i ++){
a[i] = read();
}
memset(f, 0x3f, sizeof(f));
f[0] = 0;
int hh = 0, tt = -1;
ll cnt = 0;
q[++tt] = 0;
for(int i = 1, j = 1; i <= n; i ++)
{
cnt += a[i];
while(cnt > m) {
cnt -= a[j];
j ++;
}
while(hh <= tt && q[hh] < j)
{
if(tt > hh)
{
del(hh);
}
hh ++;
}
while(hh <= tt && a[q[tt]] <= a[i]) {
if(tt > hh)
{
del(tt - 1);
}
tt--;
}
q[++tt] = i;
if(tt > hh)
{
se.insert(get(tt-1));
}
f[i] = f[j-1] + a[q[hh]];
if(se.size())
f[i] = min(f[i], *se.begin());
}
if(f[n] == 0) puts("-1");
else
printf("%lld\n", f[n]);
return 0;
}