You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation:
- u v: ask how many distinct integers appear as node weights on the path from u to v.
Input
In the first line there are two integers N and M (N <= 40000, M <= 100000).
In the second line there are N integers. The i-th integer denotes the weight of the i-th node.
In the next N-1 lines, each line contains two integers u v, which describe an edge (u, v).
In the next M lines, each line contains two integers u v, which denote an operation asking how many distinct integers appear as node weights on the path from u to v.
Output
For each operation, print its result.
题意:给你一棵树,树上每个节点有一个权值,然后给你若干个询问,每次询问让你求出一条路径上有多少个不同的权值.
解题思路:如果这个问题是在一个序列上做,那么我们可以用莫队做,但是这个是在树上,所以我们先对树进行分块,然后同样也可以用莫队,但是树上莫队的转移就没有序列上莫队的转移那么容易了,这就需要用到LCA,然后用集合的对称差转移(记 S(u,v) 为 u 到 v 路径上除去 LCA 的点集,则 S(u',v') = S(u,v) ⊕ S(u,u') ⊕ S(v,v'),查询时再单独考虑 LCA),还是很巧妙的。
// Count distinct node weights on tree paths (one answer per query, printed
// in input order).
//
// Approach: Mo's algorithm on a tree.
//   * Nodes are grouped into blocks by a stack-based blocking DFS.
//   * Queries are sorted by (block of u, first Euler-tour position of v).
//   * The maintained node set is moved between queries by toggling nodes
//     along tree paths.
//
// For endpoints (a, b), the maintained set is S(a, b) = path(a, b) minus
// lca(a, b). The LCA is added separately when a query is answered.
//
// Both DFS passes use explicit stacks, so a path-shaped tree (depth up to N)
// cannot overflow the call stack.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <utility>
#include <vector>

const int maxn = 40000 + 10;
const int maxm = 100000 + 10;
const int maxe = 2 * maxn;  // Euler tour stores 2N - 1 entries, 1-based
const int LOG_LEVELS = 20;  // sparse-table levels; 2^19 exceeds maxe

int N, M;
int block;                 // target block size, (int)sqrt(N)
int nowblock;              // index of the block currently being created
int sta[maxn];             // nodes not yet assigned to a block; live range sta[1..top]
int top;                   // stack pointer for sta
std::vector<int> g[maxn];  // adjacency lists of the tree
int pos[maxn];             // block index of each node
int depth[maxn];           // depth of each node (root has depth 0)
int value1[maxn];          // original weight of each node
int value2[maxn];          // compressed weight of each node, in [1, N]
int H[maxn];               // scratch array for coordinate compression
int num[maxn];             // num[w]: nodes in the current set with compressed weight w
int appear[maxn];          // appear[x]: 1 if node x is in the current set, else 0
int father[maxn];          // parent of each node (-1 for the root)
bool visit[maxn];          // DFS visited flags
int Index[maxe];           // Index[k]: node at Euler-tour position k
int dp[maxe][LOG_LEVELS];  // dp[i][j]: shallowest node among Index[i .. i + 2^j - 1]
int First[maxn];           // first Euler-tour position of each node
int Log[maxe];             // Log[i] = floor(log2(i)); Log[0] = -1
int res;                   // next free Euler-tour position

struct query {
    int l, r, u, v, id;
    // Order by block of u, then by Euler position of v.
    bool operator<(const query &other) const
    {
        if (l == other.l) return r < other.r;
        return l < other.l;
    }
} Query[maxm];

// One frame of the iterative Euler-tour DFS: the node and the next
// adjacency index to examine.
struct TourFrame {
    int node;
    std::size_t next;
};

// One frame of the iterative blocking DFS: the node, the next adjacency index
// to examine, and how many nodes from already-finished children are still
// waiting on `sta` (the `sum` local of the recursive formulation).
struct BlockFrame {
    int node;
    std::size_t next;
    int sum;
};

// Resets the per-run state. Must be called after N is known.
void init()
{
    res = 1;
    top = 0;
    father[1] = -1;
    std::memset(visit, false, sizeof(visit));
    std::memset(appear, 0, sizeof(appear));
    std::memset(num, 0, sizeof(num));
    block = static_cast<int>(std::sqrt(static_cast<double>(N)));
    nowblock = 0;
    for (int i = 1; i <= N; i++) g[i].clear();
}

// Builds the sparse table over Index[1 .. res-1], keyed by node depth.
void initRmq()
{
    Log[0] = -1;
    for (int i = 1; i < res; i++) {
        Log[i] = (i & (i - 1)) == 0 ? Log[i - 1] + 1 : Log[i - 1];
    }
    for (int i = 1; i < res; i++) {
        dp[i][0] = Index[i];
    }
    for (int j = 1; j < LOG_LEVELS; j++) {
        for (int i = 1; i < res && (i + (1 << j) - 1) < res; i++) {
            int a = dp[i][j - 1];
            int b = dp[i + (1 << (j - 1))][j - 1];
            dp[i][j] = (depth[a] < depth[b]) ? a : b;
        }
    }
}

// Returns the shallowest node among Euler positions [l, r], with l <= r.
int Rmq(int l, int r)
{
    int j = Log[r - l + 1];
    int a = dp[l][j];
    int b = dp[r - (1 << j) + 1][j];
    return (depth[a] < depth[b]) ? a : b;
}

// Euler-tour DFS from `root` (at depth d). Fills First, depth, father and Index.
// A node is appended on entry and again each time the walk returns to it from
// a child. Neighbours are visited in adjacency-list order.
void LCA(int root, int d)
{
    std::vector<TourFrame> frames;
    First[root] = res;
    depth[root] = d;
    Index[res++] = root;
    visit[root] = true;
    TourFrame start = {root, 0};
    frames.push_back(start);
    while (!frames.empty()) {
        TourFrame &cur = frames.back();
        int u = cur.node;
        if (cur.next < g[u].size()) {
            int v = g[u][cur.next++];
            if (!visit[v]) {
                father[v] = u;
                First[v] = res;
                depth[v] = depth[u] + 1;
                Index[res++] = v;
                visit[v] = true;
                TourFrame child = {v, 0};
                frames.push_back(child);  // may invalidate `cur`; it is not used again
            }
        } else {
            frames.pop_back();
            // Back in the parent after finishing u: record the parent again.
            if (!frames.empty()) Index[res++] = frames.back().node;
        }
    }
}

// Lowest common ancestor of u and v via RMQ over the Euler tour.
int getLca(int u, int v)
{
    int f1 = First[u];
    int f2 = First[v];
    if (f1 > f2) std::swap(f1, f2);
    return Rmq(f1, f2);
}

// Stack-based tree blocking.
//
// After each child subtree finishes, its leftover nodes are added to the
// parent's pending count. When that count reaches `block`, those pending
// nodes are popped from `sta` into block `nowblock` and a new block starts.
// Each finished node is then pushed onto `sta`.
//
// Returns how many nodes of u's subtree remain on `sta`; main() assigns those
// leftovers afterwards.
int dfs_block(int u)
{
    std::vector<BlockFrame> frames;
    visit[u] = true;
    BlockFrame start = {u, 0, 0};
    frames.push_back(start);
    for (;;) {
        BlockFrame &cur = frames.back();
        if (cur.next < g[cur.node].size()) {
            int v = g[cur.node][cur.next++];
            if (!visit[v]) {
                visit[v] = true;
                BlockFrame child = {v, 0, 0};
                frames.push_back(child);  // may invalidate `cur`; it is not used again
            }
            continue;
        }
        // All children done: the node waits on the stack with its leftovers.
        sta[++top] = cur.node;
        int pending = cur.sum + 1;
        frames.pop_back();
        if (frames.empty()) return pending;
        BlockFrame &parent = frames.back();
        parent.sum += pending;
        if (parent.sum >= block) {
            while (parent.sum--) pos[sta[top--]] = nowblock;
            parent.sum = 0;
            nowblock++;
        }
    }
}

// Coordinate-compresses value1[1..N] into value2[1..N], with values in [1, N].
void initHash()
{
    int tot = 0;
    for (int i = 1; i <= N; i++) {
        H[tot++] = value1[i];
    }
    std::sort(H, H + tot);
    tot = static_cast<int>(std::unique(H, H + tot) - H);
    for (int i = 1; i <= N; i++) {
        value2[i] = static_cast<int>(std::lower_bound(H, H + tot, value1[i]) - H) + 1;
    }
}

int L, R, ans;  // current Mo endpoints and number of distinct weights in the set

// Toggles node v in or out of the current set, updating num and ans, then
// moves v to its parent.
void work(int &v)
{
    if (appear[v]) {
        if (--num[value2[v]] == 0) ans--;
    } else if (++num[value2[v]] == 1) {
        ans++;
    }
    appear[v] ^= 1;
    v = father[v];
}

int result[maxm];

int main()
{
    if (std::scanf("%d%d", &N, &M) != 2) return 0;
    for (int i = 1; i <= N; i++) {
        std::scanf("%d", &value1[i]);
    }
    init();
    for (int i = 1; i < N; i++) {
        int u, v;
        std::scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    LCA(1, 0);
    initRmq();
    initHash();
    std::memset(visit, false, sizeof(visit));
    dfs_block(1);
    while (top) pos[sta[top--]] = nowblock;  // leftovers go into one final block

    for (int i = 1; i <= M; i++) {
        int u, v;
        std::scanf("%d%d", &u, &v);
        if (pos[u] > pos[v]) std::swap(u, v);
        Query[i].l = pos[u];
        Query[i].r = First[v];
        Query[i].u = u;
        Query[i].v = v;
        Query[i].id = i;
    }
    std::sort(Query + 1, Query + M + 1);

    // Start with L = R = 1: S(1, 1) is empty, matching the all-zero appear/num.
    L = 1;
    R = 1;
    ans = 0;
    for (int i = 1; i <= M; i++) {
        int u = Query[i].u;
        int v = Query[i].v;
        int id = Query[i].id;
        int lca = getLca(u, v);
        int lca1 = getLca(L, u);
        int lca2 = getLca(R, v);
        // S(u, v) = S(L, R) xor S(L, u) xor S(R, v); toggle the last two paths
        // (each excluding its own LCA). `u` and `v` are consumed by work().
        while (L != lca1) work(L);
        while (u != lca1) work(u);
        while (R != lca2) work(R);
        while (v != lca2) work(v);
        // The set excludes lca(u, v); count its weight if it is not present yet.
        result[id] = (num[value2[lca]] == 0) ? ans + 1 : ans;
        L = Query[i].u;
        R = Query[i].v;
    }
    for (int i = 1; i <= M; i++) {
        std::printf("%d\n", result[i]);
    }
    return 0;
}