题目描述:
给定长度为N的数列A,以及M条指令,每条指令可能是以下两种之一:
1、“1 x y”,查询区间 [x,y] 中的最大连续子段和,即
2、“2 x y”,把 A[x] 改成 y。
对于每个查询指令,输出一个整数表示答案。
思路:
首先可以看出来这是一个区间查询,单点更新,线段树搞一下嘛。题目要求的是,最大连续字段和,可以用线段树维护四个变量,我直接写出线段树的结构
struct p{
ll l, r, lmax, rmax, max, sum;
void init() {
this->lmax = this->max = this->rmax = -inf;
this->sum = 0;
}
} ;
lmax : 从最左端开始的最大连续字段和
rmax : 同上,从最右边开始的最大连续字段和
max : 区间内的最大连续字段和
sum : 区间和
想一下它们的联系,其实很简单。
- 区间和就是(左右两个子区间的区间和)的和
- lmax就是(左区间的lmax)和(左边区间和与右边lmax 的和)的最大值 (只要左边,和左边全要右边要一部分)
- rmax就是(右区间的rmax)和(右边区间和与左边rmax的和)的最大值 (只要右边,和右边全要左边要一部分)
- max就是(左区间的最大值)和(右区间最大值)以及(左边区间的rmax + 右边区间的lmax)的最大值
其实就是一个区间合并
有一个小问题,查询的时候,需要用到左右区间的最大值,和左区间的rmax,右区间的lmax,但是函数返回值只能有一个,因此可以返回一个结构体类型,将上述信息写入结构体中返回出来就可以用了,每一次返回的前,将里面的lmax,rmax,max,sum值一并更新即可,当然单点更新的时候也是要在回溯的时候更新的
最后查询出来的结果就是一个结构体,它的max属性就是答案啦
代码:
#include <bits/stdc++.h>
#define mem(a, b) memset(a, b, sizeof a)
#pragma warning (disable:6031)
#pragma warning (disable:4996)
#define inf 0x3f3f3f3f
using namespace std;
const int N = 310;
typedef long long ll;
struct p{
ll l, r, lmax, rmax, max, sum;
void init() {
this->lmax = this->max = this->rmax = -inf;
this->sum = 0;
}
} c[N * 4];
void build(ll l, ll r, ll k) {
c[k].l = l;
c[k].r = r;
if (l == r) {
scanf("%lld", &c[k].max);
c[k].lmax = c[k].rmax = c[k].sum = c[k].max;
return;
}
ll mid = (l + r) / 2;
build(l, mid, k * 2);
build(mid + 1, r, k * 2 + 1);
c[k].lmax = max(c[k << 1].lmax, c[k << 1].sum + c[k << 1 | 1].lmax);
c[k].rmax = max(c[k << 1 | 1].rmax, c[k << 1 | 1].sum + c[k << 1].rmax);
c[k].sum = c[k << 1].sum + c[k << 1 | 1].sum;
c[k].max = max(c[k << 1].max, max(c[k << 1 | 1].max, c[k << 1].rmax + c[k << 1 | 1].lmax));
}
void update(ll ind, ll k, ll d) {
if (c[k].l == c[k].r) {
c[k].lmax = c[k].max = c[k].rmax = c[k].sum = d;
return;
}
ll mid = (c[k].l + c[k].r) / 2;
if (ind <= mid) {
update(ind, k * 2, d);
}
else update(ind, k * 2 + 1, d);
c[k].sum = c[k << 1].sum + c[k << 1 | 1].sum;
c[k].lmax = max(c[k << 1].lmax, c[k << 1].sum + c[k << 1 | 1].lmax);
c[k].rmax = max(c[k << 1 | 1].rmax, c[k << 1 | 1].sum + c[k << 1].rmax);
c[k].max = max(c[k << 1].max, max(c[k << 1 | 1].max, c[k << 1].rmax + c[k << 1 | 1].lmax));
}
p query(ll l, ll r, ll k) {
if (c[k].l >= l && c[k].r <= r) {
return c[k];
}
ll mid = (c[k].l + c[k].r) / 2;
p res;
p ta, tb;
res.init();
ta.init();
tb.init();
if (l <= mid) {
ta = query(l, r, k << 1);
}
if (r > mid) {
tb = query(l, r, k << 1 | 1);
}
res.sum = ta.sum + tb.sum;
res.lmax = max(ta.lmax, ta.sum + tb.lmax);
res.rmax = max(tb.rmax, tb.sum + ta.rmax);
res.max = max(ta.max, max(tb.max, ta.rmax + tb.lmax));
return res;
}
int main()
{
ll n, m;
scanf("%lld %lld", &n, &m);
build(1, n, 1);
while (m--) {
int op;
ll x, y;
scanf("%d %lld %lld", &op, &x, &y);
if (op == 1) {
if (x > y)swap(x, y);
printf("%lld\n", query(x, y, 1).max);
}
else {
update(x, 1, y);
}
}
return 0;
}