[SOJ647] 祝著节【贪心】

题意简述:给定一张\(n\)个点,\(m\)条边的无向边带权连通图,每条边可能是黑白两色中的一种。定义一棵生成树合法,当且仅当它包含至少一条黑色边和至少一条白色边。求染色方案数,使得最小合法生成树的边权和恰好为\(x\)\(n\leq 10^5, m\leq 2\times 10^5, x\leq 10^18\)


一个结论:对于任意一个合法的染色方案的任意一个最小合法生成树\(E_1\),一定存在原图中的一个最小生成树\(E_2\),使得\(E_1\)中最多只有一条边不存在\(E_2\)中。

证明:若对于任意的\(E_1, E_2\),均存在\(E_1\)中的两条边\(e_1, e_2\),满足\(e_1, e_2\not \in E_2\),则显然\(len_{e_1}, len_{e_2}\)分别大于它们替换掉的两条边的权值和。不妨设\(E_2\)中的边均为白色,则\(e_1, e_2\)中至少有一条边为黑色。因此,将其中的白边或一条黑边换回\(E_2\)中对应的边,答案一定更优。

因此可以求出原图任意一个最小生成树,记为\(T\),设其权值和为\(s\)

\(s>x\),显然无解。

\(s=x\),显然原图的任意一个最小生成树都可以是合法的。考虑枚举不在\(T\)中的每条边,求出强制其在最小生成树中最优的权值,显然只需替换掉\(T\)中链上权值最大的边即可。假设有\(eq\)条边满足替换后权值和不变,则一个染色方案不合法,当且仅当\(T\)中的\(n-1\)条边和\(eq\)条可替换的边均为同色,答案为\(2^{m}-2^{m-n-eq+2}\)

否则\(s<x\),则原图的任意权值和\(<x\)的生成树均不可能合法。仍然枚举不在\(T\)中的每条边,设有\(eq\)条边满足替换后权值和恰好为\(x\)\(le\)条边满足替换后权值和\(<x\),则显然任意合法的方案均需满足\(T\)中的\(n-1\)条边和\(le\)条替换后权值和\(<x\)的边均同色。在此前提下,一个染色方案不合法当且仅当\(eq\)条可替换的边均与上述边同色,答案为\(2^{m-n-le+2}-2^{m-n-le-eq+2}\)

然后我倍增写挂了调了半天

#include <cstdio>
#include <cctype>
#include <cstring>
#include <cassert>
#include <iostream>
#include <algorithm>
#define R register
#define ll long long
using namespace std;
const int N = 110000, M = 210000, mod = 1e9 + 7;

int t, n, m, q, hd[N], nxt[M], to[M], val[M], noedg, st[N][17], maxV[N][17], dep[N];
ll lim, sum;
struct node {
    int x, y, w, used;
    inline bool operator < (const node &a) const {
        return w < a.w;
    }
}edg[M];

namespace dsu {
///
int fa[N];

inline void init(int n) {
    for (R int i = 1; i <= n; ++i)
        fa[i] = i;
    return;
}

int find(int x) {
    return fa[x] == x ? x : (fa[x] = find(fa[x]));
}

inline void unite(int x, int y) {
    fa[find(x)] = find(y);
    return;
}
///
}

inline void init() {
    dsu::init(n), noedg = 1, sum = 0;
    for (R int i = 1; i <= n; ++i)
        hd[i] = 0;
    return;
}

template <class T> inline void read(T &x) {
    x = 0;
    char ch = getchar(), w = 0;
    while (!isdigit(ch)) w = (ch == '-'), ch = getchar();
    while (isdigit(ch)) x = (x << 1) + (x << 3) + (ch ^ 48), ch = getchar();
    x = w ? -x : x;
    return;
}

inline void addEdg(int x, int y, int w) {
    nxt[++noedg] = hd[x], hd[x] = noedg, to[noedg] = y, val[noedg] = w;
    nxt[++noedg] = hd[y], hd[y] = noedg, to[noedg] = x, val[noedg] = w;
    return;
}

void dfs1(int now) {
    dep[now] = dep[st[now][0]] + 1;
    for (R int i = 1; i <= 16; ++i) {
        maxV[now][i] = max(maxV[now][i - 1], maxV[st[now][i - 1]][i - 1]);
        st[now][i] = st[st[now][i - 1]][i - 1];
    }
    for (R int i = hd[now], v; i; i = nxt[i]) {
        if ((v = to[i]) == st[now][0]) continue;
        maxV[v][0] = val[i], st[v][0] = now, dfs1(v);
    }
    return;
}

inline int findMax(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    int ret = 0;
    for (R int i = 16; ~i; --i)
        if (dep[st[x][i]] >= dep[y])
            ret = max(ret, maxV[x][i]), x = st[x][i];
    if (x == y) return ret;
    for (R int i = 16; ~i; --i)
        if (st[x][i] != st[y][i])
            ret = max(ret, max(maxV[x][i], maxV[y][i])), x = st[x][i], y = st[y][i];
    return max(ret, max(maxV[x][0], maxV[y][0]));
}

inline ll quickpow(ll base, ll pw) {
    ll ret = 1;
    while (pw) {
        if (pw & 1) ret = ret * base % mod;
        base = base * base % mod, pw >>= 1;
    }
    return ret;
}

int main() {
    read(t);
    int x, y;
    while (t--) {
        read(n), read(m), read(lim);
        init();
        for (R int i = 1; i <= m; ++i)  
            read(edg[i].x), read(edg[i].y), read(edg[i].w), edg[i].used = 0;
        sort(edg + 1, edg + 1 + m);
        for (R int i = 1; i <= m; ++i) {
            x = edg[i].x, y = edg[i].y;
            if (dsu::find(x) != dsu::find(y))
                dsu::unite(x, y), edg[i].used = 1, sum += edg[i].w, addEdg(x, y, edg[i].w);
        }
        if (sum > lim) {
            printf("0\n");
            continue;
        }
        dfs1(1);
        int numL = 0, numE = 0;
        for (R int i = 1; i <= m; ++i) {
            if (edg[i].used) continue;
            int v = findMax(edg[i].x, edg[i].y);
            if (sum - v + edg[i].w < lim)
                ++numL;
            else if (sum - v + edg[i].w == lim)
                ++numE;
        }
        if (lim == sum)
            printf("%lld\n", (quickpow(2, m) + mod - quickpow(2, m - n + 2 - numE)) % mod);
        else
            printf("%lld\n", (quickpow(2, m - n + 2 - numL) - quickpow(2, m - n + 2 - numL - numE) + mod) % mod);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/suwakow/p/11720873.html