1036: [ZJOI2008]树的统计Count
Time Limit: 10 Sec Memory Limit: 162 MBSubmit: 19818 Solved: 8066
[ Submit][ Status][ Discuss]
Description
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身
Input
输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。
Output
对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。
Sample Input
4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4
Sample Output
4
1
2
2
10
6
5
6
5
16
1
2
2
10
6
5
6
5
16
HINT
Source
【思路】
朴素算法对每个操作都执行一遍深搜,不可取,需要一种能够记住点对点路径或者部分路径的方式。所以对整棵树进行轻重链剖分,把同一条链的节点映射到一个连续区间,再对其使用线段树维护。思想是:每个节点都属于某一条链,每条链都有一个唯一顶端,那么如果两个点具有同一个顶端,则位于同一条链上,那么其间的路径信息便可较快获取。
【代码】
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; const int MAXN = 30005, INF = 0x3f3f3f3f; struct edge { int to, next; }; struct segment { int left, right, mid; int sum, mx; }; int n, q, cnt, tot; int head[MAXN], p[MAXN], in[MAXN], id[MAXN],fa[MAXN], top[MAXN], sz[MAXN], max_son[MAXN], deep[MAXN]; edge e[MAXN << 1]; segment tree[MAXN << 2]; void addedge(int from, int to) { ++cnt; e[cnt].to = to; e[cnt].next = head[from]; head[from] = cnt; } void dfs_1(int u, int father, int depth) { deep[u] = depth; fa[u] = father; max_son[u] = 0; sz[u] = 1; for (int i = head[u]; i != 0; i = e[i].next) { int v = e[i].to; if (v == fa[u]) continue; dfs_1(v, u, depth + 1); sz[u] += sz[v]; if (sz[max_son[u]] < sz[v]) max_son[u] = v; } } void dfs_2(int u, int tp) { in[u] = ++tot; id[tot] = u; top[u] = tp; if (max_son[u] != 0) dfs_2(max_son[u], tp); for (int i = head[u]; i != 0; i = e[i].next) { int v = e[i].to; if (v == fa[u] || v == max_son[u]) continue; dfs_2(v, v); } } void build(int left, int right, int root) { tree[root].left = left; tree[root].right = right; tree[root].mid = (left + right) >> 1; if (left == right) { tree[root].mx = tree[root].sum = p[id[left]]; return; } build(left, tree[root].mid, root << 1); build(tree[root].mid + 1, right, root << 1 | 1); tree[root].sum = tree[root << 1].sum + tree[root << 1 | 1].sum; tree[root].mx = max(tree[root << 1].mx, tree[root << 1 | 1].mx); } void modify(int index, int num, int root) { if (tree[root].left == tree[root].right) { tree[root].mx = tree[root].sum = num; return; } if (index <= tree[root].mid) modify(index, num, root << 1); if (index >= tree[root].mid + 1) modify(index, num, root << 1 | 1); tree[root].mx = max(tree[root << 1].mx, tree[root << 1 | 1].mx); tree[root].sum = tree[root << 1].sum + tree[root << 1 | 1].sum; } int sum_query(int l, int r, int root) { if (l <= tree[root].left && tree[root].right <= r) return tree[root].sum; int ans = 0; if (l <= tree[root].mid) ans += sum_query(l, r, root << 1); if (r >= tree[root].mid + 1) ans += sum_query(l, r, root << 1 | 1); return ans; } int max_query(int l, int r, int root) { if (l <= tree[root].left && tree[root].right <= r) return tree[root].mx; int ans = -INF; if (l <= tree[root].mid) ans = max(ans, max_query(l, r, root << 1)); if (r >= tree[root].mid + 1) ans = max(ans, max_query(l, r, root << 1 | 1)); return ans; } int main() { cnt = 0; memset(head, 0, sizeof(head)); scanf("%d", &n); for (int i = 1; i <= n - 1; i++) { int a, b; scanf("%d %d", &a, &b); addedge(a, b); addedge(b, a); } for (int i = 1; i <= n; i++) scanf("%d", &p[i]); dfs_1(1, 0, 1); tot = 0; dfs_2(1, 1); build(1, n, 1); scanf("%d", &q); while (q--) { char mes[7]; int u, v; scanf("%s %d %d", mes, &u, &v); if (mes[0] == 'C') modify(in[u], v, 1); if (mes[1] == 'M') { int ans = -INF; while (top[u] != top[v]) { if (deep[top[u]] < deep[top[v]]) swap(u, v); ans = max(ans, max_query(in[top[u]], in[u], 1)); u = fa[top[u]]; } if (deep[u] > deep[v]) swap(u, v); ans = max(ans, max_query(in[u], in[v], 1)); printf("%d\n", ans); } if (mes[1] == 'S') { int ans = 0; while (top[u] != top[v]) { if (deep[top[u]] < deep[top[v]]) swap(u, v); ans += sum_query(in[top[u]], in[u], 1); u = fa[top[u]]; } if (deep[u] > deep[v]) swap(u, v); ans += sum_query(in[u], in[v], 1); printf("%d\n", ans); } } return 0; }