// 思路 (Approach)
// Accepted code
#include <bits/stdc++.h>
using namespace std;
// Scratch input variables read by main(): vertex count and edge endpoints.
int n, u, v;
// Adjacency lists of the input graph; main() fills indices 1..n with
// both directions of every edge it reads.
std::vector<int> mp[100005];
// Disjoint-set parent array; f[x] == x marks a set representative.
int f[100005];

/// Returns the representative of x's set and points every vertex on the
/// walked path directly at it (full path compression).
///
/// Iterative on purpose: the parent chain can grow very long (an increasing
/// chain of activated vertices links each old root to the next one), and a
/// recursive find would need one stack frame per link.
inline int findf(int x) {
    int root = x;
    while (f[root] != root) root = f[root];
    // Second pass: re-point each vertex on the path straight at the root.
    while (f[x] != root) {
        const int next = f[x];
        f[x] = root;
        x = next;
    }
    return root;
}
// (value, vertex) pairs; main() sorts them ascending to fix the activation order.
std::vector<std::pair<int, int>> q;
// Children lists of the rebuilt tree: ansq[u] holds the set roots adopted by u.
std::vector<int> ansq[100005];
// Per-vertex depth in the rebuilt tree; main() seeds every entry with 1.
int ans[100005];
// rk[i] is vertex i's 1-based position in the sorted order of q.
int rk[100005];

/// Propagates depths downward from x: every vertex reachable through the
/// ansq children lists receives ans[child] = ans[node] + 1. ans[x] must
/// already be set by the caller.
///
/// @param x       start vertex (main() passes the root of the rebuilt tree)
/// @param parent  vertex to skip if it appears in x's list (main() passes -1)
///
/// Uses an explicit stack instead of recursion so a deep, chain-like tree
/// cannot exhaust the call stack. The parameter is no longer named `f`,
/// which shadowed the global DSU array.
void dfs(int x, int parent) {
    // Each entry: (vertex to expand, vertex it was reached from).
    std::vector<std::pair<int, int>> pending;
    pending.push_back(std::make_pair(x, parent));
    while (!pending.empty()) {
        const int node = pending.back().first;
        const int from = pending.back().second;
        pending.pop_back();
        for (int child : ansq[node]) {
            if (child == from) continue;
            ans[child] = ans[node] + 1;
            pending.push_back(std::make_pair(child, node));
        }
    }
}

/// Orders vertices by their rank (position in the sorted (value, vertex) list).
inline bool cmp(int a, int b) {
    return rk[a] < rk[b];
}
/// Reads T test cases. Each case gives n, then n-1 undirected edges, then one
/// value per vertex. Vertices are activated in ascending (value, index) order.
/// Each newly activated vertex adopts, as children, the DSU roots of its
/// already-active neighbours. The last activated vertex becomes the root of
/// the resulting tree. Each vertex's depth in that tree (root = 1) is
/// printed, one per line.
///
/// Malformed or out-of-range input stops the program with exit status 1
/// instead of reading uninitialized values or indexing past the arrays.
int main() {
    int t = 0;
    if (scanf("%d", &t) != 1) return 0;  // empty input: nothing to answer
    // Highest vertex index the fixed-size per-vertex arrays can hold
    // (they are declared with 100005 slots; index 0 is unused).
    const int kMaxVertex = static_cast<int>(sizeof(rk) / sizeof(rk[0])) - 1;
    while (t-- > 0) {
        if (scanf("%d", &n) != 1 || n < 1 || n > kMaxVertex) return 1;
        q.clear();
        q.reserve(n);
        for (int i = 1; i <= n; i++) {
            mp[i].clear();
            ansq[i].clear();
            ans[i] = 1;
            f[i] = i;
        }
        for (int i = 0; i < n - 1; i++) {
            int a = 0;
            int b = 0;
            if (scanf("%d%d", &a, &b) != 2) return 1;
            if (a < 1 || a > n || b < 1 || b > n) return 1;
            mp[a].push_back(b);
            mp[b].push_back(a);
        }
        for (int i = 1; i <= n; i++) {
            int value = 0;
            if (scanf("%d", &value) != 1) return 1;
            q.push_back(std::make_pair(value, i));
        }
        // Ties on value are broken by vertex index, so ranks are distinct.
        std::sort(q.begin(), q.end());
        for (int i = 0; i < n; i++) rk[q[i].second] = i + 1;
        for (int i = 0; i < n; i++) {
            const int cur = q[i].second;
            // No need to sort neighbours: assuming the n-1 edges form a tree,
            // the already-active neighbours of cur sit in distinct DSU
            // components, so visiting them in any order builds the same tree.
            for (int nb : mp[cur]) {
                if (rk[nb] > rk[cur]) continue;  // not activated yet
                const int root = findf(nb);
                ansq[cur].push_back(root);
                f[root] = cur;
            }
        }
        const int top = q[n - 1].second;  // largest (value, index): tree root
        ans[top] = 1;
        dfs(top, -1);
        for (int i = 1; i <= n; i++) printf("%d\n", ans[i]);
    }
    return 0;
}