题目描述
M公司是一个非常庞大的跨国公司,在许多国家都设有它的下属分支机构或部门。为了让分布在世界各地的N个部门之间协同工作,公司搭建了一个连接整个公司的通信网络。该网络的结构由N个路由器和N-1条高速光缆组成。每个部门都有一个专属的路由器,部门局域网内的所有机器都联向这个路由器,然后再通过这个通信子网与其他部门进行通信联络。该网络结构保证网络中的任意两个路由器之间都存在一条直接或间接路径以进行通信。 高速光缆的数据传输速度非常快,以至于利用光缆传输的延迟时间可以忽略。但是由于路由器老化,在这些路由器上进行数据交换会带来很大的延迟。而两个路由器之间的通信延迟时间则与这两个路由器通信路径上所有路由器中最大的交换延迟时间有关。作为M公司网络部门的一名实习员工,现在要求你编写一个简单的程序来监视公司的网络状况。该程序能够随时更新网络状况的变化信息(路由器数据交换延迟时间的变化),并且根据询问给出两个路由器通信路径上延迟第k大的路由器的延迟时间。
【任务】 你的程序从输入文件中读入N个路由器和N-1条光缆的连接信息,每个路由器初始的数据交换延迟时间Ti,以及Q条询问(或状态改变)的信息。并依次处理这Q条询问信息,它们可能是:
由于更新了设备,或者设备出现新的故障,使得某个路由器的数据交换延迟时间发生了变化。
查询某两个路由器a和b之间的路径上延迟第k大的路由器的延迟时间。
输入输出格式
输入格式:第一行为两个整数N和Q,分别表示路由器总数和询问的总数。
第二行有N个整数,第i个数表示编号为i的路由器初始的数据延迟时间Ti。
紧接着N-1行,每行包含两个整数x和y。表示有一条光缆连接路由器x和路由器y。
紧接着是Q行,每行三个整数k、a、b。
如果k=0,则表示路由器a的状态发生了变化,它的数据交换延迟时间由Ta变为b。
如果k>0,则表示询问a到b的路径上所经过的所有路由器(包括a和b)中延迟第k大的路由器的延迟时间。注意a可以等于b,此时路径上只有一个路由器。
对于每一个第二种询问(k>0),输出一行。包含一个整数为相应的延迟时间。如果路径上的路由器不足k个,则输出信息“invalid request!”(全部小写不包含引号,两个单词之间有一个空格)。
输入输出样例
说明
测试数据满足N,Q<=80000,任意一个路由器在任何时刻都满足延迟时间小于10^8。对于所有询问满足0<=K<=N 。
Solution
简要题意:树上支持点修改的路径第k大(并不是升序第k名)查询。
分开考虑这两个问题。对于修改操作,我们要修改以其为根的整棵子树的那些权值线段树。在序列操作时,我们常常会用BIT来套,那么考虑把树转化为dfs序列即可,于是就变成了一个序列进行区间操作,常规差分。
对于查询操作,若不带修改时,我们当然可以用 v[x] + v[y] - v[lca(x, y)] - v[f[lca(x, y)]] 作为权值跑主席树。现在只不过每一个点都被差分了,于是前缀和即为所求。
比较考验代码能力。
后续我会更新这道题的暴力树链剖分+线段树套平衡树解法!
Code
#include <cstdio> #include <cstring> #include <algorithm> #define N 80010 #define M 8000010 using namespace std; inline char gc() { static char now[1<<16], *S, *T; if(S == T) {T = (S = now) + fread(now, 1, 1<<16, stdin); if(S == T) return EOF;} return *S++; } inline int read() { int x = 0, f = 1; char c = gc(); while(c < '0' || c > '9') {if(c == '-') f = -1; c = gc();} while(c >= '0' && c <= '9') {x = x * 10 + c - 48; c = gc();} return x * f; } struct edge {int to, next;}e[N<<1]; int a[N], head[N], f[N][17], b[N<<1], A[N], B[N], C[N], st[N], ed[N], root[N], dep[N], bit[N]; int L[M], R[M], v[M]; int n, q, cnt, m, dfn, len; inline void ins(int x, int y) {e[++cnt].to = y; e[cnt].next = head[x]; head[x] = cnt;} void add(int &now, int pre, int l, int r, int x, int k) { if(!now) now = ++len; v[now] = v[pre] + k; if(l == r) return ; int mid = (l + r)>>1; if(x <= mid) R[now] = R[pre], add(L[now], L[pre], l, mid, x, k); else L[now] = L[pre], add(R[now], R[pre], mid + 1, r, x, k); } void dfs(int x, int fa) { f[x][0] = fa; add(root[x], root[fa], 1, m, a[x], 1); st[x] = ++dfn; for(int i = head[x]; i; i = e[i].next) if(e[i].to != fa) dep[e[i].to] = dep[x] + 1, dfs(e[i].to, x); ed[x] = dfn; } inline void buildlca() { for(int j = 1; j < 17; ++j) for(int i = 1; i <= n; ++i) f[i][j] = f[f[i][j - 1]][j - 1]; } inline int getlca(int x, int y) { if(dep[x] < dep[y]) swap(x, y); for(int i = 16; i >= 0; --i) if(dep[f[x][i]] >= dep[y]) x = f[x][i]; if(x == y) return x; for(int i = 16; i >= 0; --i) if(f[x][i] != f[y][i]) x = f[x][i], y = f[y][i]; return f[x][0]; } int temp[N], c[N], d[N]; inline void bitup(int x, int pos, int k) {for(int i = x; i <= n; i+= i & -i) temp[i] = 0, add(temp[i], bit[i], 1, m, pos, k), bit[i] = temp[i];} int numc, numd; inline void get(int x, int p) { if(p == 0) { d[++numd] = root[x]; x = st[x]; for(; x; x-= x & -x) d[++numd] = bit[x]; }else { c[++numc] = root[x]; x = st[x]; for(; x; x-= x & -x) c[++numc] = bit[x]; } } inline int query(int k) { int l = 1, r = m, t, tt; while(l < r) { t = tt = 0; for(int i = 1; i <= numc; ++i) t+= v[R[c[i]]], tt+= v[c[i]]; for(int i = 1; i <= numd; ++i) t-= v[R[d[i]]], tt-= v[d[i]]; if(tt < k) return -1; if(k <= t) { for(int i = 1; i <= numc; ++i) c[i] = R[c[i]]; for(int i = 1; i <= numd; ++i) d[i] = R[d[i]]; l = ((l + r)>>1) + 1; }else { for(int i = 1; i <= numc; ++i) c[i] = L[c[i]]; for(int i = 1; i <= numd; ++i) d[i] = L[d[i]]; r = (l + r)>>1; k-= t; } } return l; } int main() { n = read(); q = read(); m = 0; for(int i = 1; i <= n; ++i) a[i] = read(), b[++m] = a[i]; memset(head, 0, sizeof(head)); cnt = 1; for(int i = 1; i < n; ++i) {int x = read(), y = read(); ins(x, y); ins(y, x);} for(int i = 1; i <= q; ++i) { A[i] = read(); B[i] = read(); C[i] = read(); if(!A[i]) b[++m] = C[i]; } sort(b+1, b+m+1); m = unique(b+1, b+m+1) - b - 1; for(int i = 1; i <= n; ++i) a[i] = lower_bound(b+1, b+m+1, a[i]) - b; dfn = len = 0; dep[1] = 1; dfs(1, 0); buildlca(); for(int i = 1; i <= n; ++i) bit[i] = root[0]; for(int i = 1; i <= m; ++i) { int x = B[i], y = C[i]; if(!A[i]) { y = lower_bound(b+1, b+m+1, y) - b; bitup(st[x], a[x], -1); bitup(ed[x] + 1, a[x], 1); bitup(st[x], y, 1); bitup(ed[x] + 1, y, -1); a[x] = y; }else { int z = getlca(x, y); numc = numd = 0; get(x, 1); get(y, 1); get(z, 0); get(f[z][0], 0); int ans = query(A[i]); if(ans == -1) puts("invalid request!"); else printf("%d\n", b[ans]); } } return 0; }