【树上点分治4】 Garden of Eden HDU 5977

原题链接

在这里插入图片描述
题意:给定N个点,K个种类,每个点都有自己的种类,问有多少点对可以覆盖全部的种类

看到k的范围很小,不妨利用状压的思想,将每个种类的点转换为1<<a[i],即寻找当前的和是否达到1<<k-1,如果达到,说明k种齐全。

这里有一个特殊的处理技巧,可以枚举子集的所有状态。

for (int sub = S; sub; sub = (sub - 1) & S) {
    
    
	// sub 为 S 的子集
}

我们同样可以利用hash的思想,因为a | b = c ,我们可以转化为 c ^ b = a。

记得k=1时特判处理。

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <queue>
#include <stack>
#include <cmath>
#include <bitset>
#include <map>
using namespace std;
//#define ACM_LOCAL
typedef long long ll;
typedef long double ld;
typedef pair<int, int> PII;
const int N = 1e5 + 5;
const int INF = 0x3f3f3f3f;
const int MOD = 1e6 + 3;
int n, m, cnt, h[N], rt, sz[N], mx[N], vis[N], sum, k, a[N];
int d[N], dep[N];
int path[1<<10];
ll ans;

struct edge{
    
    
    int to, next;
}e[N<<1];

void add(int u, int v) {
    
    
    e[cnt].to = v;
    e[cnt].next = h[u];
    h[u] = cnt++;
}

void getroot(int x, int fa) {
    
    
    sz[x] = 1, mx[x] = 0;
    for (int i = h[x]; ~i; i = e[i].next) {
    
    
        int y = e[i].to;
        if (y == fa || vis[y]) continue;
        getroot(y, x);
        sz[x] += sz[y];
        mx[x] = max(mx[x], sz[y]);
    }
    mx[x] = max(mx[x], sum - sz[x]);
    if (mx[x] < mx[rt]) rt = x;
}

void getd(int x, int fa, int now) {
    
    
    d[++d[0]] = now;
    for (int i = h[x]; ~i; i = e[i].next) {
    
    
        int y = e[i].to;
        if (y == fa || vis[y]) continue;
        getd(y, x, now | (1<<a[y]));
    }
}

ll cal(int x, int now) {
    
    
    ll res = 0;
    d[0] = 0;
    memset(path, 0, sizeof path);
    getd(x, -1, now);
    for (int i = 1; i <= d[0]; i++) path[d[i]]++;
    for (int i = 1; i <= d[0]; i++) {
    
    
        path[d[i]]--;
        res += path[(1<<k)-1];
        for (int j = d[i]; j; j = (j-1) & d[i]) {
    
    
            res += path[((1<<k)-1)^j];
        }
        path[d[i]]++;
    }
    return res;
}

void work(int x) {
    
    
    vis[x] = 1;
    ans += cal(x, 1<<a[x]);
    for (int i = h[x]; ~i; i = e[i].next) {
    
    
        int y = e[i].to;
        if (vis[y]) continue;
        ans -= cal(y, (1<<a[x]) | (1<<a[y]));
        sum = sz[y], rt = 0;
        getroot(y, -1);
        work(rt);
    }
}

void solve () {
    
    
    while (~scanf("%d %d", &n, &k)) {
    
    
        memset(h, -1, sizeof h);
        memset(vis, 0, sizeof vis);
        cnt = 0;
        for (int i = 1; i <= n; i++) scanf("%d", &a[i]), a[i]--;
        for (int i = 1; i < n; i++) {
    
    
            int x, y;
            scanf("%d %d", &x, &y);
            add(x, y);
            add(y, x);
        }
        mx[0] = INF, rt = 0, sum = n, ans = 0;
        getroot(1, -1);
        work(rt);
        if (k == 1) printf("%lld\n", 1ll*n*n);
        else printf("%lld\n", ans);
    }
}

int main() {
    
    
    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
#ifdef ACM_LOCAL
    freopen("input", "r", stdin);
    freopen("output", "w", stdout);
#endif
    solve();
}

猜你喜欢

转载自blog.csdn.net/kaka03200/article/details/109406820