题意:对一个长度为n的条染色,只能黑或者白,对第i块染色染成黑或白分别加分arr[i]和brr[i],给出m条线段,对这些线段同时染成某种颜色有额外的加分,求最大的加分数。
解析:比赛时用着明显错误的假算法冲过了中等难度,感觉题目的数据有点过水,不能保证代码完全正确 (⊙﹏⊙)
现在想来中等难度的O(n+m^2)的正确算法应该是:把黑白区间分别按照右边界从小到大排列,之后在dp的过程中到达右边界时进行dp,状态转移方程就是dp[i]=max(dp[i],dp[j-1]+pre[i]-pre[j-1]+sum),其中j是每一个(需要O(m)次)之前跑过的线段的左边界,sum是跑过的左边界大于等于j的线段加分和,右边那项代表的意思是把j到i全涂成某个颜色的最右解。
而当m到达3e5时,还是对黑白区间分别进行处理,但每次还是一个个线段来求就会超时,所以应该把取每一线段对结果的加分用线段树来表示体现,这里使用的线段树意义不好理解,是一个随着i的变化而不断更新的线段树,线段树第k个节点的值代表把k到i之间全部涂成颜色id时,从1到i的加分最大值,id只有0和1两种状态,代表全染成黑色和全染成白色两棵线段树。
一直觉得dp题看状态转移方程和题解解析很难看懂,还是直接看代码比较好理解转换逻辑。。。。
代码:
#include <bits/stdc++.h>
#define x first
#define y second
#define mid (l+r>>1)
#define lo (o<<1)
#define ro (o<<1|1)
using namespace std;
typedef long long ll;
typedef vector<int>vi;
typedef pair<int,int> pii;
struct tri{int x,y,z;ll dp;};
const int inf=0x3f3f3f3f;
const ll linf=0x3f3f3f3f3f3f3f3f;
const int maxn=3e5+10;
const ll mod=1e9+7;
const double PI=acos(0)*2;
bool cmp(tri a,tri b)
{
return a.y<b.y;
}
ll arr[maxn],brr[maxn];
int n,m;
vector<tri>sa,sb;
ll ma[2][maxn<<2],lazy[2][maxn<<2];
void build(int o=1,int l=1,int r=n)
{
ma[0][o]=ma[1][o]=-linf;
if(l==r)
{
return;
}
build(lo,l,mid);
build(ro,mid+1,r);
}
void pushdown(int id,int o,int l,int )
{
if(!lazy[id][o])return;
ma[id][lo]+=lazy[id][o];
ma[id][ro]+=lazy[id][o];
lazy[id][lo]+=lazy[id][o];
lazy[id][ro]+=lazy[id][o];
lazy[id][o]=0;
}
void add(int id,int ql,int qr,ll v,int o=1,int l=1,int r=n)
{
if(ql<=l&&r<=qr)
{
ma[id][o]+=v;
lazy[id][o]+=v;
return;
}
pushdown(id,o,l,r);
if(ql<=mid)add(id,ql,qr,v,lo,l,mid);
if(qr>mid)add(id,ql,qr,v,ro,mid+1,r);
ma[id][o]=max(ma[id][lo],ma[id][ro]);
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
// freopen("in.txt","r",stdin);
cin>>n>>m;
build();
for(int i=1;i<=n;i++)cin>>arr[i];
for(int i=1;i<=n;i++)cin>>brr[i];
for(int i=0;i<m;i++)
{
int a,b,c,d;
cin>>a>>b>>c>>d;
if(a==1)sa.push_back({b,c,d});
else sb.push_back({b,c,d});
}
sa.push_back({0,n+1,0});
sb.push_back({0,n+1,0});//方便边界处理
sort(sa.begin(),sa.end(),cmp);
sort(sb.begin(),sb.end(),cmp);
auto a=sa.begin(),b=sb.begin();
ll ans=0;
for(int i=1;i<=n;i++)
{
add(0,1,i,arr[i]);//第i个一定是黑,这样才能保证从j到i都是黑
add(1,1,i,brr[i]);
add(0,i,i,linf+ans);//开始时为方便都设成了linf,这里补上去
add(1,i,i,linf+ans);//这里还要加上前i+1个的最优解
while(a->y==i)
{
add(0,1,a->x,a->z);//1到a->x之间任一值到i全是黑,那么就可以额外加分a->z
a++;
}
while(b->y==i)
{
add(1,1,b->x,b->z);
b++;
}
ans=max(ma[0][1],ma[1][1]);//ans代表到i的最右解,相当与dp[i]
}
cout<<ans<<endl;
return 0;
}