线段树之多lazy标记(HDU-4578)
一、首先是单懒标记
在使用线段树的时候往往都配合着 lazy 懒标记来提高区间修改 change_interval 的效率。假设现在有一颗表示 1-10 之间元素和的线段树,如下图所示,现在我们需要对区间 [6, 8] 进行 add 1 的操作,若我们一路向下搜索到叶子节点 [6, 6] 、[7, 7] 和 [8, 8] 三个点进行 add 1,然后再逐层往上反馈,这肯定是正确的,**但是可能做了无用功!**因为如果进行完这个操作后,要查询区间 [6, 8] 的元素和,我们只搜索到 [6, 8] 这个结点就可以得到答案,没有必要继续搜索到叶子节点进行累加,因此我们在对区间 [6, 8] 区间进行 add 3 即可,因为这个区间中有 3 个元素,之后添加在这个节点添加 lazy 标记 f = 1,这样就不用继续向下修改了,那如果要用到下面节点的值呢?那不是就出错了吗,因此我们需要在访问子节点的时候将该结点的懒标记 down 下传下去,这样就保证正确性。
相信来看这个博客的你们肯定是了解了懒标记之后才想学习多个懒标记的,因此这部分就打住!
二、多个懒标记
在有的题目中,仅仅使用单个懒标记不足以满足我们的需求,比如这道题目:HDU 4578
题目大意:给定一个长度为 n 的整形序列 {a1, a2, a3, … , an} ,初始序列为 {0, 0, 0, … , 0},有 m 个操作,其中有四种操作:区间add、区间mul、区间set以及区间求和(p 指定每个元素的次幂)。
1 x y c : 表示区间 [x, y] 加上 c
2 x y c : 表示区间 [x, y] 乘上 c
3 x y c : 表示区间 [x, y] 赋值为 c
4 x y p : 表示区间 [x, y] 的 p 次幂和(sum = axp + ax+1p + … + ayp)
(1 <= n, m <= 1e5, 1 <= x <= y <= n, 1 <= c <= 1e4, 1 <= p <= 3)
这道题目为什么只使用一个 lazy 标记不行呢?因为 add、mul 和 set 三种操作的效果不同,优先级也不同,堆积在一个懒标记之中无法区分。
因此我们需要设计三个懒标记 fadd、fmul 和 fset,其中 fset 的优先级最高,并且在给 fset 赋值的时候需要把当前的 fadd 以及 fmul 都初始化;其次是 fmul,乘法的优先级自然比加法高,在改变 fmul 的时候如果发现 fadd 存在,需要先将 fadd 乘上本来的 fmul 的值,之后再改变 fmul 的值(想一想,为什么);最后就是加法 fadd 了,在改变 fadd 的时候直接加上即可。
因此在 down 下传懒标记的时候也要按照优先级顺序进行下传:fset > fmul > fadd
除此之外,还有一个坑爹的地方就是第四种操作,需要求区间的 p 次幂和,这就造成了区间修改的困难。
由于 p 最多只有 3,因此我们可以在每个结点都保存这个区间的 1-3 次幂的和,分别是 sum1、sum2 以及 sum3,这样每次询问的时候直接返回对应的次幂和即可。但是这样还是有问题,我们怎么在不达到叶子结点的情况下正确改变当前区间的各次幂之和?
动动手,可以退出下面的结论:
(a * c)1 = a * c
(a * c)2 = a2 * c2
(a * c)3 = a3 * c3
(a + c)1 = a + c
(a + c)2 = a2 + c2 + 2ac
(a + c)3 = a3 + c3 + 3ac(a+ c)
这是单点的变化,对于区间的话,需要进行相加。
假设 arr[] = {a, b, c},我们分别对 [1, 3] 进行 add d 与 mul d 的操作,则:
初始 sum1 = a + b + c, sum2 = a2 + b2 + c2 , sum3 = a3 + b3 + c3
对 [1, 3] 进行 add d 的操作:
sum1 = (a+d) + (b+d) + (c+d) = a + b + c + 3 * d = sum1 + 3 * d
即 sum1 += (right-left+1) * d
sum2 = (a+d)2 + (b+d)2 + (c+d)2 = a2 + b2 + c2 + 3 * d2 + 2 * d * (a + b + c)
= sum2 + 3 * d2 + 2 ^ d ^ sum1
即 sum2 += (right-left+1) * d2 + 2d * sum1
sum3 = (a+d)3 + (b+d)3 + (c+d)3 = a3 + b3 + c3 + 3 * d3 + 3 * d * (a2 + b2 + c2 + d * (a + b + c))
= sum3 + 3 * d^3^ + 3 * d * (sum2 + d * (sum1))
即 sum3 += (right-left+1) * d3 + 3d(sum2 + d * sum1)
解决了这些问题,我们就可以写代码了。
参考代码:(问题需要进行取模,所以看起来很多很多的%mod,影响阅读)
// 多lazy标记 线段树
#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
const int maxn = 1e5+10, maxm = 1e5+10, mod = 10007;
LL n, m, ans;
struct Node {
LL left, right;
LL sum1, sum2, sum3;
LL fset, fadd, fmul;
} tree[4*maxn];
void build(LL k, LL l, LL r) {
tree[k].left = l, tree[k].right = r;
tree[k].sum1 = tree[k].sum2 = tree[k].sum3 = 0;
tree[k].fset = tree[k].fadd = 0;
tree[k].fmul = 1;
if(l == r) return;
LL mid = (l+r)/2;
build(2*k, l, mid);
build(2*k+1, mid+1, r);
}
void downSet(LL k) {
if(tree[k].left == tree[k].right) return;
LL ffset = tree[k].fset;
if(ffset) {
tree[2*k].fset = tree[2*k+1].fset = ffset;
tree[2*k].fadd = tree[2*k+1].fadd = 0;
tree[2*k].fmul = tree[2*k+1].fmul = 1;
tree[2*k].sum1 = ((tree[2*k].right-tree[2*k].left+1)%mod*(ffset%mod))%mod;
tree[2*k].sum2 = (tree[2*k].sum1%mod*(ffset%mod))%mod;
tree[2*k].sum3 = (tree[2*k].sum2%mod*(ffset%mod))%mod;
tree[2*k+1].sum1 = ((tree[2*k+1].right-tree[2*k+1].left+1)%mod*(ffset%mod))%mod;
tree[2*k+1].sum2 = (tree[2*k+1].sum1%mod*(ffset%mod))%mod;
tree[2*k+1].sum3 = (tree[2*k+1].sum2%mod*(ffset%mod))%mod;
tree[k].fset = 0;
}
}
void downAdd(LL k) {
if(tree[k].left == tree[k].right) return;
LL ffadd = tree[k].fadd;
if(ffadd) {
tree[2*k].fadd = (tree[2*k].fadd%mod+ffadd%mod)%mod;
tree[2*k+1].fadd = (tree[2*k+1].fadd%mod+ffadd%mod)%mod;
LL len = tree[2*k].right-tree[2*k].left+1;
tree[2*k].sum3 = (tree[2*k].sum3%mod+len*ffadd*ffadd%mod*ffadd%mod+3*ffadd%mod*(tree[2*k].sum2%mod+tree[2*k].sum1%mod*ffadd%mod)%mod)%mod;
tree[2*k].sum2 = (tree[2*k].sum2%mod+len*ffadd*ffadd%mod+2*ffadd%mod*(tree[2*k].sum1)%mod)%mod;
tree[2*k].sum1 = (tree[2*k].sum1%mod+len*ffadd%mod)%mod;
len = tree[2*k+1].right-tree[2*k+1].left+1;
tree[2*k+1].sum3 = (tree[2*k+1].sum3%mod+len*ffadd*ffadd%mod*ffadd%mod+3*ffadd%mod*(tree[2*k+1].sum2%mod+tree[2*k+1].sum1%mod*ffadd%mod)%mod)%mod;
tree[2*k+1].sum2 = (tree[2*k+1].sum2%mod+len*ffadd*ffadd%mod+2*ffadd%mod*(tree[2*k+1].sum1)%mod)%mod;
tree[2*k+1].sum1 = (tree[2*k+1].sum1%mod+len*ffadd%mod)%mod;
tree[k].fadd = 0;
}
}
void downMul(LL k) {
if(tree[k].left == tree[k].right) return;
LL ffmul = tree[k].fmul;
if(ffmul != 1) {
if(tree[2*k].fadd) {
tree[2*k].fadd = (tree[2*k].fadd%mod*(ffmul%mod))%mod;
}
if(tree[2*k+1].fadd) {
tree[2*k+1].fadd = (tree[2*k+1].fadd%mod*(ffmul%mod))%mod;
}
tree[2*k].fmul = (tree[2*k].fmul%mod*(ffmul%mod))%mod;
tree[2*k+1].fmul = (tree[2*k+1].fmul%mod*(ffmul%mod))%mod;
tree[2*k].sum1 = (tree[2*k].sum1%mod*(ffmul%mod))%mod;
tree[2*k+1].sum1 = (tree[2*k+1].sum1%mod*(ffmul%mod))%mod;
ffmul = (ffmul%mod*(tree[k].fmul%mod))%mod;
tree[2*k].sum2 = (tree[2*k].sum2%mod*(ffmul%mod))%mod;
tree[2*k+1].sum2 = (tree[2*k+1].sum2%mod*(ffmul%mod))%mod;
ffmul = (ffmul%mod*(tree[k].fmul%mod))%mod;
tree[2*k].sum3 = (tree[2*k].sum3%mod*(ffmul%mod))%mod;
tree[2*k+1].sum3 = (tree[2*k+1].sum3%mod*(ffmul%mod))%mod;
tree[k].fmul = 1;
}
}
void down(LL k) {
if(tree[k].left == tree[k].right) return;
downSet(k);
downMul(k);
downAdd(k);
}
void pushUp(LL k) {
tree[k].sum1 = (tree[2*k].sum1%mod+tree[2*k+1].sum1%mod)%mod;
tree[k].sum2 = (tree[2*k].sum2%mod+tree[2*k+1].sum2%mod)%mod;
tree[k].sum3 = (tree[2*k].sum3%mod+tree[2*k+1].sum3%mod)%mod;
}
void add(LL k, LL l, LL r, LL c) {
if(tree[k].left >= l && tree[k].right <= r) {
tree[k].fadd = (tree[k].fadd%mod+c%mod)%mod;
LL len = tree[k].right-tree[k].left+1;
tree[k].sum3 = (tree[k].sum3%mod+len*c*c%mod*c%mod+3*c%mod*(tree[k].sum2%mod+tree[k].sum1%mod*c%mod)%mod)%mod;
tree[k].sum2 = (tree[k].sum2%mod+len*c*c%mod+2*c%mod*(tree[k].sum1%mod))%mod;
tree[k].sum1 = (tree[k].sum1%mod+len*c%mod)%mod;
return;
}
down(k);
LL mid = (tree[k].left+tree[k].right)/2;
if(mid >= l) add(2*k, l, r, c);
if(mid < r) add(2*k+1, l, r, c);
pushUp(k);
}
void mul(LL k, LL l, LL r, LL c) {
if(tree[k].left >= l && tree[k].right <= r) {
if(tree[k].fadd){
tree[k].fadd = (tree[k].fadd%mod*c)%mod;
}
tree[k].fmul = (tree[k].fmul%mod*c)%mod;
tree[k].sum1 = (tree[k].sum1%mod*c)%mod;
tree[k].sum2 = (tree[k].sum2%mod*c*c)%mod;
tree[k].sum3 = (tree[k].sum3%mod*c*c%mod*c)%mod;
return;
}
down(k);
LL mid = (tree[k].left+tree[k].right)/2;
if(mid >= l) mul(2*k, l, r, c);
if(mid < r) mul(2*k+1, l, r, c);
pushUp(k);
}
void change(LL k, LL l, LL r, LL c) {
if(tree[k].left >= l && tree[k].right <= r) {
tree[k].fset = c;
tree[k].fadd = 0;
tree[k].fmul = 1;
tree[k].sum1 = ((tree[k].right-tree[k].left+1)%mod*c)%mod;
tree[k].sum2 = (tree[k].sum1%mod*c)%mod;
tree[k].sum3 = (tree[k].sum2%mod*c)%mod;
return;
}
down(k);
LL mid = (tree[k].left+tree[k].right)/2;
if(mid >= l) change(2*k, l, r, c);
if(mid < r) change(2*k+1, l, r, c);
pushUp(k);
}
void ask(LL k, LL l, LL r, LL p) {
if(tree[k].left >= l && tree[k].right <= r) {
if(p == 1) {
ans = (ans%mod+tree[k].sum1%mod)%mod;
} else if(p == 2) {
ans = (ans%mod+tree[k].sum2%mod)%mod;
} else {
ans = (ans%mod+tree[k].sum3%mod)%mod;
}
return;
}
down(k);
LL mid = (tree[k].left+tree[k].right)/2;
if(mid >= l) ask(2*k, l, r, p);
if(mid < r) ask(2*k+1, l, r, p);
}
int main() {
while(scanf("%lld%lld", &n, &m) == 2 && n+m) {
build(1, 1, n);
for(LL i = 0; i < m; i++) {
LL op, l, r, c;
scanf("%lld%lld%lld%lld", &op, &l, &r, &c);
switch(op) {
case 1: {
add(1, l, r, c);
break;
}
case 2: {
mul(1, l, r, c);
break;
}
case 3: {
change(1, l, r, c);
break;
}
case 4: {
ans = 0;
ask(1, l, r, c);
printf("%lld\n", ans%mod);
break;
}
}
}
}
return 0;
}
【END】感谢观看!