题目大意:
给一个数组,有两种操作
1:改变数组中的一个值
2: 求区间中成非递减的子序列有多少个
刚开始拿到题感觉:这不就是个板子吗?
然后这道板子题就卡了我两天,直到看了别人的代码才知道错哪儿了
节点需要维护的信息:
左右区间范围
当前区间中,合法子序列的数量
包含左端点的最长合法子序列的长度(llen)
包含右端点的最长合法子序列的长度(rlen)
让我们合并两个区间的时候,需要分情况讨论
如果左儿子最右端的数大于右儿子最左端的数:
父区间的合法子序列的数量就是两个子区间合法子序列的数量
如果左儿子最右端的数小于或等于右儿子最左端的数
此时父亲区间的的合法子序列的数量处理时两儿子的合之外,还要加上由于区间合并,多出来的合法子序列,其多出来的数量就是左儿子的rlen乘上右儿子的llen
此外,也可以用来以下代码来计算
int get(int x) {
//计算:长度为x的区间中有多少个子区间
return (x * (x + 1) / 2);
}
int a = get(tr[u << 1].rlen + tr[u << 1 | 1].llen);
int b = get(tr[u << 1].rlen);
int c = get(tr[u << 1 | 1].llen);
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum + (a - b - c);
题中一个重要的易错点就是query函数
当我们递归了两个子儿子后,如果左儿子最右端的数小于右儿子最左端的数,就还要考虑两儿子组合后多出来的结果,似乎套用上面的公式就行。
但是!
如果我们直接套用会导致:就算问询区间中不包含左儿子,我们也没有递归左儿子,也会导致在计算时额外多加了组合时多出来的数
这是最开始的写法
if (w[tr[u << 1].r] <= w[tr[u << 1 | 1].l]) {
int a = get(tr[u << 1].rlen + tr[u << 1 | 1].llen);
int b = get(tr[u << 1].rlen);
int c = get(tr[u << 1 | 1].llen);
//cout << a << " " << b << " " << c << endl;
res += (a - b - c);
}
如果我们在if上面添加一个限制条件:mid >= l && mid < r
会导致:我们递归了一个左儿子的一个右儿子,但在计算时会把整个左儿子的rlen拿过来算(实际上只需要算左儿子的右儿子的rlen)
所以用这个代码才能避免
if (w[tr[u << 1].r] <= w[tr[u << 1 | 1].l] ) {
int lsum = min(mid - l + 1, tr[u << 1].rlen);
int rsum = min(r - mid, tr[u << 1 | 1].llen);
if (lsum > 0 && rsum > 0)
res += lsum * rsum;
}
return res;
以下是全部代码:
#include <bits/stdc++.h>
#define int long long
using namespace std;
int n, q;
const int N = 2e5 + 10;
int w[N];
struct node {
//计算:区间中最大的连续子串长度 与中点向连的最大连续子串长度
int l, r;
int sum;//区间总个数
int llen, rlen;//以边界的最长单调
} tr[4 * N];
int get(int x) {
//计算:长度为x的区间中有多少个子区间
return (x * (x + 1) / 2);
}
void pushup(int u) {
tr[u].llen = tr[u << 1].llen;
tr[u].rlen = tr[u << 1 | 1].rlen;
if (w[tr[u << 1].r] > w[tr[u << 1 | 1].l]) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
} else {
int a = get(tr[u << 1].rlen + tr[u << 1 | 1].llen);
int b = get(tr[u << 1].rlen);
int c = get(tr[u << 1 | 1].llen);
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum + (a - b - c);
if (tr[u << 1].llen == tr[u << 1].r - tr[u << 1].l + 1) {
tr[u].llen = tr[u << 1].llen + tr[u << 1 | 1].llen;
}
if (tr[u << 1 | 1].rlen == tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) {
tr[u].rlen = tr[u << 1 | 1].rlen + tr[u << 1].rlen;
}
}
}
void build(int u, int l, int r) {
tr[u] = {
l, r, 1, 1, 1};
if (l == r)
return;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int x, int v) {
if (tr[u].r == x && tr[u].l == x) {
w[x] = v;
return;
}
int mid = tr[u].l + tr[u].r >> 1;
if (mid >= x)
modify(u << 1, x, v);
else
modify(u << 1 | 1, x, v);
pushup(u);
}
int query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) {
return tr[u].sum;
}
int mid = tr[u].l + tr[u].r >> 1;
int res = 0;
if (mid >= l)
res = query(u << 1, l, r);
if (mid < r)
res += query(u << 1 | 1, l, r);
/*
这里是一个错误的写法
if (w[tr[u << 1].r] <= w[tr[u << 1 | 1].l] && mid >= l && mid < r) {
int a = get(tr[u << 1].rlen + tr[u << 1 | 1].llen);
int b = get(tr[u << 1].rlen);
int c = get(tr[u << 1 | 1].llen);
//cout << a << " " << b << " " << c << endl;
res += (a - b - c);
}
*/
if (w[tr[u << 1].r] <= w[tr[u << 1 | 1].l] ) {
int lsum = min(mid - l + 1, tr[u << 1].rlen);
int rsum = min(r - mid, tr[u << 1 | 1].llen);
if (lsum > 0 && rsum > 0)
res += lsum * rsum;
}
return res;
}
void solve() {
cin >> n >> q;
for (int i = 1; i <= n; i++)
cin >> w[i];
build(1, 1, n);
while (q--) {
int a, b, c;
cin >> a >> b >> c;
if (a == 1) {
modify(1, b, c);
} else {
cout << query(1, b, c) << endl;
}
}
}
signed main() {
ios::sync_with_stdio(false);
solve();
}