【JSOI2018】潜入行动(树形动态规划)

题目链接

【JSOI2018】潜入行动


题目大意

n 个节点的树上大小为 k 的覆盖集个数。本题中一个被标记的点可以覆盖所有与之相邻的点,但是不能覆盖其本身。
n 100000 m 100


题解

d p [ v e r ] [ c n t ] [ 0 / 1 ] [ 0 / 1 ] 表示节点 v e r 的子树中选出 c n t 个节点,其中选了/不选节点 v e r v e r 是/否被覆盖的方案数量。
状态转移方程见代码。
但是,暴力转移会超时。
那就加一个优化:每个节点 v e r 的子树中最多选 s i z e [ v e r ] 个节点。
算一下 s i z e [ v e r ] 就可以啦。


代码

#include <cstdio>
const int maxn = 100005;
const int maxm = 105;
const int maxe = 200005;
const int mod = 1e9+7;
typedef unsigned long long ull;
int n, m, sz[maxn], dp[maxn][maxm][2][2];
int tot, ter[maxe], nxt[maxe], lnk[maxn];
ull cur[maxm][2][2];
inline void upd(int &a, int b) {
    a += b, a -= (a >= mod ? mod : 0);
}
inline int min(const int &a, const int &b) {
    return a < b ? a : b;
}
void addedge(int u, int v) {
    ter[++tot] = v;
    nxt[tot] = lnk[u];
    lnk[u] = tot;
}
void treedp(int u, int p) {
    sz[u] = 1;
    dp[u][0][0][0] = 1;
    dp[u][1][1][0] = 1;
    for (int i = lnk[u]; i; i = nxt[i]) {
        int v = ter[i];
        if (v == p) continue;
        treedp(v, u);
        int bound0 = min(m, sz[u]), bound1 = min(m, sz[v]);
        for (int x = 0; x <= bound0; x++)
            for (int x0 = 0; x0 <= 1; x0++)
                for (int x1 = 0; x1 <= 1; x1++)
                    cur[x][x0][x1] = dp[u][x][x0][x1], dp[u][x][x0][x1] = 0;
        for (int x0 = 0; x0 <= bound0; x0++)
            for (int x1 = 0; x1 <= bound1 && x0 + x1 <= m; x1++) {
                upd(dp[u][x0 + x1][0][0], cur[x0][0][0] * dp[v][x1][0][1] % mod);
                upd(dp[u][x0 + x1][0][1], (cur[x0][0][0] * dp[v][x1][1][1] + cur[x0][0][1] * (dp[v][x1][0][1] + dp[v][x1][1][1])) % mod);
                upd(dp[u][x0 + x1][1][0], cur[x0][1][0] * (dp[v][x1][0][0] + dp[v][x1][0][1]) % mod);
                upd(dp[u][x0 + x1][1][1], (cur[x0][1][0] * (dp[v][x1][1][0] + dp[v][x1][1][1]) + cur[x0][1][1] * (dp[v][x1][0][0] + dp[v][x1][1][0]) + cur[x0][1][1] * (dp[v][x1][0][1] + dp[v][x1][1][1])) % mod);
            }
        sz[u] += sz[v];
    }
    /*
    我可以用来 debug 程序!
    for (int x = 0; x <= m; x++)
        for (int x0 = 0; x0 <= 1; x0++)
            for (int x1 = 0; x1 <= 1; x1++)
                printf("dp[%d][%d][%d][%d] = %d\n", u, x, x0, x1, dp[u][x][x0][x1]);
    */
}
int main() {
    scanf("%d %d", &n, &m);
    for (int u, v, i = 1; i < n; i++) {
        scanf("%d %d", &u, &v);
        addedge(u, v), addedge(v, u);
    }
    treedp(1, 0);
    printf("%d\n", (dp[1][m][0][1] + dp[1][m][1][1]) % mod);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_42068627/article/details/80384000