思路：树链剖分。先在 dfs1 中求出每个结点的深度、父结点、子树大小以及重儿子（子树最大的儿子）；再在 dfs2 中优先沿重儿子编号，得到每个结点的 dfs 序和所在重链的链头，使每条重链在线段树上占据一段连续区间。
BZOJ 传送门：BZOJ 1036「[ZJOI2008] 树的统计 Count」 https://www.lydsy.com/JudgeOnline/problem.php?id=1036
// Heavy-light decomposition + segment tree answering path queries on a tree.
// Input (stdin): n, then n-1 undirected edges "u v", then n node weights,
// then q, then q commands of the form
//   CHANGE u t : set the weight of node u to t
//   QMAX u v   : print the maximum node weight on the path u..v
//   QSUM u v   : print the sum of node weights on the path u..v
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <cctype>

const int inf = 0x7f7f7f7f, maxn = 30007;

// Segment-tree child indices (heap layout, root at index 1).
inline int ls(int p) { return p << 1; }
inline int rs(int p) { return p << 1 | 1; }

// Adjacency list stored as a linked list of edges; edge id 0 terminates a list.
struct edge {
    int v, nxt;
} e[maxn << 1];
int head[maxn], eid = 0;

// siz: subtree size   dep: node depth   fa: parent   top: head of heavy chain
// l: position in dfs order   son: heavy child (0 = none)   val: node weight
int siz[maxn], dep[maxn], fa[maxn], top[maxn], l[maxn], son[maxn], val[maxn], tot;
// Segment tree over dfs-order positions: range sum and range max.
int sum[maxn << 2], maxx[maxn << 2];
int n, q;
char cmd[10];

// Add the undirected edge u-v as two directed adjacency entries.
void insert(int u, int v) {
    e[++eid].v = v; e[eid].nxt = head[u]; head[u] = eid;
    e[++eid].v = u; e[eid].nxt = head[v]; head[v] = eid;
}

// First DFS: fill siz, dep, fa and pick son[u], the child with the largest
// subtree. siz[] doubles as the visited mark (the parent already has
// siz > 0, so it is skipped). Recursion depth can reach n on a path graph.
void dfs1(int u) {
    siz[u] = 1;
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        if (siz[v]) continue;
        dep[v] = dep[u] + 1;
        fa[v] = u;
        dfs1(v);
        siz[u] += siz[v];
        if (siz[v] > siz[son[u]]) son[u] = v;  // siz[0] == 0, so any child beats "none"
    }
}

// Second DFS: assign dfs-order positions, visiting the heavy child first so
// that every heavy chain occupies a contiguous range of positions.
// t is the head of the chain u belongs to.
void dfs2(int u, int t) {
    l[u] = ++tot;
    top[u] = t;
    if (son[u]) dfs2(son[u], t);
    for (int i = head[u]; i; i = e[i].nxt) {
        int v = e[i].v;
        // In a tree the only neighbour that is not deeper than u is its parent.
        if (dep[v] > dep[u] && v != son[u]) dfs2(v, v);  // light child starts a new chain
    }
}

// Recompute node p's aggregates from its two children.
void pushup(int p) {
    sum[p] = sum[ls(p)] + sum[rs(p)];
    maxx[p] = std::max(maxx[ls(p)], maxx[rs(p)]);
}

// Point assignment: set position x (within node p's range [lo, hi]) to c.
void modify(int p, int lo, int hi, int x, int c) {
    if (lo == hi) {
        sum[p] = maxx[p] = c;
        return;
    }
    int mid = (lo + hi) >> 1;
    if (x <= mid) modify(ls(p), lo, mid, x, c);
    else modify(rs(p), mid + 1, hi, x, c);
    pushup(p);
}

// Sum over positions [x, y] intersected with node p's range [lo, hi].
int querysum(int p, int lo, int hi, int x, int y) {
    if (x <= lo && hi <= y) return sum[p];
    int mid = (lo + hi) >> 1, res = 0;
    if (x <= mid) res += querysum(ls(p), lo, mid, x, y);
    if (y > mid) res += querysum(rs(p), mid + 1, hi, x, y);
    return res;
}

// Maximum over positions [x, y] intersected with node p's range [lo, hi].
int querymax(int p, int lo, int hi, int x, int y) {
    if (x <= lo && hi <= y) return maxx[p];
    // Node weights lie in [-30000, 30000], so -inf is a safe identity for max.
    int mid = (lo + hi) >> 1, res = -inf;
    if (x <= mid) res = std::max(res, querymax(ls(p), lo, mid, x, y));
    if (y > mid) res = std::max(res, querymax(rs(p), mid + 1, hi, x, y));
    return res;
}

// Sum of node weights on the tree path x..y. Repeatedly lift whichever
// endpoint has the deeper chain head, consuming one whole chain segment at a time.
int getsum(int x, int y) {
    int res = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        res += querysum(1, 1, n, l[top[x]], l[x]);
        x = fa[top[x]];
    }
    // Both endpoints are now on the same chain; positions are contiguous.
    if (l[x] > l[y]) std::swap(x, y);
    res += querysum(1, 1, n, l[x], l[y]);
    return res;
}

// Maximum node weight on the tree path x..y (same chain-lifting as getsum).
int getmax(int x, int y) {
    int res = -inf;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) std::swap(x, y);
        res = std::max(res, querymax(1, 1, n, l[top[x]], l[x]));
        x = fa[top[x]];
    }
    if (l[x] > l[y]) std::swap(x, y);
    res = std::max(res, querymax(1, 1, n, l[x], l[y]));
    return res;
}

// Fast signed-integer reader over stdin. Any '-' seen before the digits makes
// the result negative. Returns 0 at EOF instead of looping forever.
inline int read() {
    int s = 0, f = 1;
    int c = getchar();  // int, not char, so EOF stays distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        s = s * 10 + (c - '0');
        c = getchar();
    }
    return s * f;
}

int main() {
    n = read();
    for (int i = 1; i < n; i++) {
        int u = read();
        int v = read();
        insert(u, v);
    }
    for (int i = 1; i <= n; i++) val[i] = read();

    dfs1(1);
    dfs2(1, 1);
    for (int i = 1; i <= n; i++) modify(1, 1, n, l[i], val[i]);

    q = read();
    for (int i = 1; i <= q; i++) {
        // Width-limited read protects cmd[10]; stop cleanly on truncated input.
        if (scanf("%9s", cmd) != 1) break;
        int u = read();
        int v = read();
        // The second letter distinguishes CHANGE / QMAX / QSUM.
        switch (cmd[1]) {
        case 'H':  // CHANGE u t
            val[u] = v;
            modify(1, 1, n, l[u], v);
            break;
        case 'M':  // QMAX u v
            printf("%d\n", getmax(u, v));
            break;
        default:   // QSUM u v
            printf("%d\n", getsum(u, v));
        }
    }
    return 0;
}