// 版权声明:欢迎大家转载,转载请注明出处 https://blog.csdn.net/hao_zong_yin/article/details/83148121
// 思路:一开始按 DP 的思路没有头绪——状态数 5e4*(1<<10) 偏大,于是转向"暴力"。树上路径计数的暴力一般考虑点分治:稍加思考可以发现,在每个重心处收集各点到重心路径上的颜色集合(位掩码),再枚举子集统计并集覆盖全部颜色的点对即可,总复杂度约为 O(n·log n·2^K)。注意:root 是全局变量,会在递归中被修改,必须先用局部变量保存——我就是在这里卡了很久。
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
// ---- Limits -----------------------------------------------------------------
constexpr int maxn = 50000 + 10;    // node capacity; nodes are numbered from 1
constexpr int INF = 0x3f3f3f3f;     // sentinel used as dp[0] so any real node beats it
using LL = long long;

// ---- Per-test input and result ---------------------------------------------
// all    : mask with every colour bit set, (1 << K) - 1
// a[x]   : colour of node x, stored 0-based so it can be used as a bit index
// vis[x] : non-zero once x has been used as a centroid (removed from the tree)
int N, K, all, a[maxn], vis[maxn];
LL ans;
std::vector<int> G[maxn];           // adjacency lists of the tree

// ---- Centroid search --------------------------------------------------------
// sz[x] : subtree size in the current DFS
// dp[x] : largest piece left if x were removed
// root  : best centroid candidate found so far
// SZ    : size of the component being searched
int sz[maxn], dp[maxn], root, SZ;

// ---- Path-mask collection ---------------------------------------------------
std::vector<int> sta;               // colour masks gathered from one component
LL num[1500];                       // num[m] = how many gathered masks equal m
// Post-order DFS over the not-yet-removed component containing u (f is the
// DFS parent). Sets sz[u] to the subtree size and dp[u] to the largest piece
// that would remain if u were deleted, treating SZ as the component size.
// The global `root` is replaced by u whenever u's dp is strictly smaller.
void getroot(int f, int u) {
    sz[u] = 1;
    int heaviest = 0;  // largest child subtree below u
    for (const int child : G[u]) {
        if (child != f && !vis[child]) {
            getroot(u, child);
            sz[u] += sz[child];
            heaviest = max(heaviest, sz[child]);
        }
    }
    // The part "above" u (everything outside its subtree) is also a piece.
    dp[u] = max(heaviest, SZ - sz[u]);
    if (dp[u] < dp[root]) {
        root = u;
    }
}
// Pre-order DFS from u (f is the DFS parent) over nodes not yet removed.
// `s` is the colour mask accumulated on the way to u. It is appended to
// `sta`, and each child continues with s plus the child's own colour bit.
void getsta(int f, int u, int s) {
    sta.push_back(s);
    for (const int nxt : G[u]) {
        if (nxt != f && !vis[nxt]) {
            getsta(u, nxt, s | (1 << a[nxt]));
        }
    }
}
// Gathers the colour masks of every node reachable from u through
// non-removed nodes. Each mask is `s` OR'ed with the colours on the path
// from u to that node. Returns the number of ordered pairs (x, y), x != y,
// whose masks together cover every colour: mask(x) | mask(y) == all.
//
// y is a valid partner of a node with mask m exactly when mask(y) == all ^ s0
// for some subset s0 of m. Masks are first tallied in num[], and each distinct
// mask is processed once. The subset enumeration therefore costs
// O(#distinct masks * 2^popcount) instead of one enumeration per node.
LL getans(int u, int s) {
    sta.clear();
    getsta(0, u, s);

    for (int m = 0; m <= all; m++) num[m] = 0;
    for (const int m : sta) num[m]++;

    LL res = 0;
    for (int m = 0; m <= all; m++) {
        if (num[m] == 0) continue;
        // s0 == 0 -> partner mask must be exactly `all`.
        LL partners = num[all];
        for (int s0 = m; s0; s0 = (s0 - 1) & m) {
            partners += num[all ^ s0];
        }
        res += num[m] * partners;
    }
    // A node pairs with itself above only if its own mask is `all`
    // (all ^ m is a subset of m only when m == all). Drop those x == y pairs.
    res -= num[all];
    return res;
}
void solve(int u) {
dp[0] = INF, root = 0;
getroot(0, u);
int s = (1<<a[root]);
ans += getans(root, s);
vis[root] = 1;
int rt = root;
for (int i = 0; i < G[rt].size(); i++) {
int v = G[rt][i];
if (vis[v]) continue;
ans -= getans(v, (s|(1<<a[v])));
SZ = sz[v];
solve(v);
}
}
// Reads test cases until input is exhausted. Each case has N, K, the N node
// colours, and N - 1 undirected edges. Prints the number of ordered node pairs
// (u, v) whose path contains all K colours.
int main() {
    // Require both integers to be read. `~scanf(...)` stops only on EOF
    // and would loop forever if scanf returned 0 on malformed input.
    while (scanf("%d%d", &N, &K) == 2) {
        // Global arrays are reused across test cases; reset this case's range.
        for (int i = 0; i <= N; i++) {
            G[i].clear();
            vis[i] = 0;
        }
        all = (1 << K) - 1;  // every colour bit set
        for (int i = 1; i <= N; i++) {
            scanf("%d", &a[i]);
            a[i] -= 1;  // shift colours down by one so they serve as bit indices
        }
        for (int i = 1; i < N; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        // With a single colour, every ordered pair (u, v), including u == v,
        // qualifies. The centroid routine only counts pairs with u != v,
        // so answer this case directly.
        if (K == 1) {
            printf("%lld\n", 1LL * N * N);
            continue;
        }
        ans = 0;
        SZ = N;  // size of the whole tree for the first centroid search
        solve(1);
        printf("%lld\n", ans);
    }
    return 0;
}