题目
有一棵点数为 N 的树,以点 1 为根,树上每个点都有一个点权。然后有 M 个操作,
分为三种:操作 1 :把某个节点 x 的点权增加 a 。操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
indata:
5 5
1 2 3 4 5
1 2
1 4
2 3
2 5
3 3
1 2 1
3 5
2 1 2
3 3
outdata:
6
9
13
解法:树链剖分(重链剖分)+ 带懒标记的线段树
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
using namespace std;
// 64-bit type used for segment-tree sums, lazy tags and path-sum results.
using LL = long long;
// N: number of nodes, M: number of operations (both read in main).
int N, M;
// Adjacency list stored as singly linked edge chains; each undirected edge is
// inserted twice (once per direction) by main.
struct edge {
int next, to; // next: index of the next edge leaving the same node (0 ends the chain); to: target node
} e[200005];
int head[100005], cnt_adj; // head[u]: first edge index leaving u (0 = none); cnt_adj: last edge index used (index 0 stays unused)
void add_edge(int, int);
int val[100005]; // initial weight of each node, indexed by node id
// Heavy-light decomposition data, filled by dfs1/dfs2:
//   siz[u]: subtree size, son[u]: heavy child (0 if u has no children),
//   dep[u]: depth (root has depth 1), fa[u]: parent (0 for the root),
//   top[u]: first node of the heavy chain containing u,
//   tid[u]: DFS position of u, rnk[p]: node at DFS position p (rnk[tid[u]] == u).
int siz[100005], son[100005], dep[100005], fa[100005], top[100005], rnk[100005], tid[100005];
void dfs1(int u, int father, int depth);
void dfs2(int u, int start);
// Segment-tree node covering DFS positions [left, right]; sum is the total
// weight of the nodes mapped to those positions.
struct ST {
int left, right;
LL sum;
} st[400005];
void build(int, int, int);
void update(int, int, int, LL);  // add a value to every position in [l, r]
LL query(int, int, int);         // sum over positions [l, r]
void update_single(int, LL);     // operation 1
void update_subtree(int, LL);    // operation 2
LL query_path(int, int);         // operation 3 (called with y = 1)
// Reads the tree and node weights, builds the heavy-light decomposition and the
// segment tree, then answers M operations:
//   1 x a : add a to node x
//   2 x a : add a to every node in the subtree of x
//   3 x   : print the weight sum on the path from x to the root (node 1)
int main(){
    std::ios::sync_with_stdio(false);
    // Untie cin from cout so reading the next operation does not force a flush
    // of the answers printed so far.
    std::cin.tie(nullptr);

    std::cin >> N >> M;
    for(int i = 1; i <= N; i++) std::cin >> val[i];
    for(int i = 1; i < N; i++){
        int x, y;
        std::cin >> x >> y;
        add_edge(x, y);
        add_edge(y, x);
    }

    dfs1(1, 0, 1);
    dfs2(1, 1);
    build(1, 1, N);

    for(int q = 1; q <= M; q++){
        int op, x;
        std::cin >> op >> x;
        if(op == 1 || op == 2){
            // Read the increment at the same width the update functions accept.
            long long key;
            std::cin >> key;
            if(op == 1) update_single(x, key);
            else update_subtree(x, key);
        }else if(op == 3){
            // '\n' instead of endl: the stream is flushed once at exit, not per answer.
            std::cout << query_path(x, 1) << '\n';
        }
    }
    return 0;
}
// Returns the sum of node weights on the tree path between x and y.
// Repeatedly lifts whichever endpoint sits on the chain with the deeper head,
// adding that chain segment (a contiguous DFS range), until both endpoints lie
// on the same heavy chain; then adds the remaining segment between them.
LL query_path(int x, int y){
    LL total = 0;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) std::swap(x, y);
        total += query(1, tid[top[x]], tid[x]);
        x = fa[top[x]];
    }
    // Same chain: order the endpoints by DFS position before querying.
    if(tid[y] < tid[x]) std::swap(x, y);
    total += query(1, tid[x], tid[y]);
    return total;
}
// Operation 1: add key to the weight of node u only.
// A single node maps to exactly one DFS position, so this is a one-point range update.
void update_single(int u, LL key){
    const int pos = tid[u];
    update(1, pos, pos, key);
}
// Operation 2: add key to every node in the subtree rooted at u.
// dfs2 numbers each subtree contiguously, so the subtree occupies DFS
// positions [tid[u], tid[u] + siz[u] - 1].
void update_subtree(int u, LL key){
    const int first = tid[u];
    const int last = first + siz[u] - 1;
    update(1, first, last, key);
}
// Builds the segment-tree node `pos` over DFS positions [l, r].
// Leaves take the initial weight of the node mapped to that position
// (via rnk); internal nodes store the sum of their two children.
void build(int pos, int l, int r){
    st[pos].left = l;
    st[pos].right = r;
    if(l == r){
        st[pos].sum = val[rnk[l]];
        return;
    }
    const int mid = l + (r - l) / 2;
    const int left_child = pos * 2;
    const int right_child = pos * 2 + 1;
    build(left_child, l, mid);
    build(right_child, mid + 1, r);
    st[pos].sum = st[left_child].sum + st[right_child].sum;
}
// Pending lazy addition for each segment-tree node: update() has already added
// it into st[pos].sum, and pushdown() later applies it to the two children.
LL tag[400005];
void pushdown(int);
void update(int pos, int l, int r, LL key){
if(l <= st[pos].left && st[pos].right <= r){
st[pos].sum += (st[pos].right - st[pos].left + 1) * key;
tag[pos] += key;
}else{
if(tag[pos]) pushdown(pos);
int mid = (st[pos].left + st[pos].right) >> 1;
int lc = pos << 1, rc = lc + 1;
if(l <= mid) update(lc, l, r, key);
if(mid < r) update(rc, l, r, key);
st[pos].sum = st[lc].sum + st[rc].sum;
}
}
// Returns the sum of weights over DFS positions [l, r] within segment node `pos`.
// Pending lazy tags are pushed to the children before descending.
LL query(int pos, int l, int r){
    const ST &node = st[pos];
    if(l <= node.left && node.right <= r) return node.sum;
    if(tag[pos]) pushdown(pos);
    const int mid = (node.left + node.right) >> 1;
    LL acc = 0;
    if(l <= mid) acc += query(pos << 1, l, r);
    if(r > mid) acc += query(pos << 1 | 1, l, r);
    return acc;
}
void pushdown(int pos){
int lc = pos << 1, rc = lc + 1;
st[lc].sum += (st[lc].right - st[lc].left + 1) * tag[pos];
st[rc].sum += (st[rc].right - st[rc].left + 1) * tag[pos];
tag[lc] += tag[pos], tag[rc] += tag[pos];
tag[pos] = 0;
}
// First HLD pass: records the parent, depth and subtree size of every node
// below u. It also picks son[u], the first child (in adjacency order) whose
// subtree is strictly the largest.
// Recursive, so the recursion depth equals the height of the tree.
void dfs1(int u, int father, int depth){
    fa[u] = father;
    dep[u] = depth;
    siz[u] = 1;
    for(int k = head[u]; k != 0; k = e[k].next){
        const int child = e[k].to;
        if(child == father) continue;
        dfs1(child, u, depth + 1);
        siz[u] += siz[child];
        // son[u] starts at 0 and siz[0] == 0, so the first child always wins initially.
        if(siz[son[u]] < siz[child]) son[u] = child;
    }
}
int cnt_tid; // last DFS position handed out by dfs2
// Second HLD pass: assigns DFS positions so that every heavy chain and every
// subtree occupies a contiguous range. The heavy child is visited first and
// continues the current chain (same top); each light child starts a new chain
// headed by itself.
void dfs2(int u, int start){
    top[u] = start;
    ++cnt_tid;
    tid[u] = cnt_tid;
    rnk[cnt_tid] = u;
    const int heavy = son[u];
    // dfs1 sets son[u] for any node with at least one child, so heavy == 0
    // means u has no children left to visit.
    if(heavy == 0) return;
    dfs2(heavy, start);
    for(int k = head[u]; k; k = e[k].next){
        const int v = e[k].to;
        if(v == fa[u] || v == heavy) continue;
        dfs2(v, v);
    }
}
// Appends a directed edge x -> y at the front of x's edge chain.
// Edge indices start at 1; index 0 marks the end of a chain.
void add_edge(int x, int y){
    const int id = ++cnt_adj;
    e[id].next = head[x];
    e[id].to = y;
    head[x] = id;
}