@noi.ac - 507@ 二分图最大权匹配


@description@

有一天你学了一个能解决二分图最大权匹配的算法,你决定将这个算法应用到NOI比赛中。

给定一张完全二分图。在这张图里,两个部分的的大小均为 n。对于第一部分的点 u 和第二部分的点 v ,连接它们的边的权值为 \(c_{uv}+k_{uv}x\),其中 x 为一个值不确定的变量。

你将被多次给定 x 的值,对于每一个 x 的值,你需要回答对应的二分完全图的最大权匹配的总权值。

input
第一行一个数n,含义如题所述。

接下来 n 行,每行 n 个整数,其中第 i 行的第 j 个数为 cij 的值。

接下来nn行,每行nn个整数,其中第 i 行的第 j 个数为 kij 的值。

接下来一个数 q,表示给定的 x 的值的数量。

接下来 q 行,每行一个整数表示给定的 x 的值。

output
输出 q 行,每行一个数,其中第 i 行表示对应于第 i 个 x 的值的答案。

sample input
3
0 0 2
0 2 0
2 0 0
0 1 0
1 1 0
0 0 1
3
0
1
3
sample output
6
7
9

explanation
对于 x=0,最大匹配为 0→2,1→1,2→0,答案为 2+2+2=6。

对于 x=1,最大匹配仍然为 0→2,1→1,2→0,由于只有 \(k_{11}=1\)而另外两条边的 k 为零,只有第二条边的边权有变化,为 3。答案为2+3+2=7。

对于 x=3,最大匹配变为 0→1,1→0,2→2,因为这三条边的 k 值均为1,边权均变成了3。答案为 3+3+3=9。

对于100%的数据,1≤n≤50,1≤q≤100000,0≤cij≤10^7,0≤kij≤1,给定的 x 的值为在 0 到 10^7 之间的整数。

@solution@

不难发现答案一定形如 K*x + C 的形式。因为匹配最多 n 条边,所以 1<=K<=n。
进一步发现,每一个 K 唯一对应一个 C。这意味着答案的变化只会存在 n 种可能性。
我们只需要找出这 n 种可能性,再对于每一个询问找该询问对应哪一种可能性即可。

直观上可以发现(同时也不难使用反证法证明),当 x 越大时,最优解对应的 K 一定随之增大。
这意味着每一种 K 对应的 x 总是一段连续的区间,于是我们可以通过二分 x 的值找到每一种 K 对应的区间。

考虑求解最大权匹配时使用 KM 算法 O(n^3) 求解,我们的预处理时间复杂度 O(n^4*log(A)),其中 A = 10^7 是一个常数。
而询问只需要遍历 n 种可能性,故询问总时间复杂度为 O(nq)。
只要你的 KM 算法写的真的是 O(n^3)(可以去 uoj#80 测一测)而不是 O(n^4) 就可以过。

其实真正用得到的 x 只有 q 次询问中的 x。我们可以将询问排序后在询问上二分,就可以将时间复杂度将至 O(n^4*log(q))
但我懒得写这个优化。

@accepted code@

#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
const int INF = 1<<30;
int c[50 + 5][50 + 5], k[50 + 5][50 + 5], K, X, C;
int f(int x, int y) {return k[x][y]*X + c[x][y];}
int lx[50 + 5], ly[50 + 5], lk[50 + 5], slk[50 + 5];
bool vx[50 + 5], vy[50 + 5];
int n, q;
bool dfs(int x) {
    vx[x] = true;
    for(int y=1;y<=n;y++) {
        if( vy[y] ) continue;
        int t = lx[x] + ly[y] - f(x, y);
        if( t == 0 ) {
            vy[y] = true;
            if( (!lk[y]) || dfs(lk[y]) ) {
                lk[y] = x;
                return true;
            }
        }
        else slk[y] = min(slk[y], t);
    }
    return false;
}
void KM(int p) {
    X = p;
    for(int i=1;i<=n;i++) {
        lx[i] = ly[i] = lk[i] = 0;
        for(int j=1;j<=n;j++)
            lx[i] = max(lx[i], f(i, j));
    }
    for(int x=1;x<=n;x++) {
        for(int i=1;i<=n;i++)
            vx[i] = vy[i] = false, slk[i] = INF;
        if( !dfs(x) ) {
            while( true ) {
                int del = INF, y = 0;
                for(int i=1;i<=n;i++)
                    if( !vy[i] ) del = min(del, slk[i]);
                for(int i=1;i<=n;i++) {
                    if( vx[i] ) lx[i] -= del;
                    if( vy[i] ) ly[i] += del;
                }
                for(int i=1;i<=n;i++)
                    if( !vy[i] ) {
                        slk[i] -= del;
                        if( slk[i] == 0 )
                            y = i;
                    }
                if( !lk[y] ) break;
                vx[lk[y]] = vy[y] = true;
                for(int i=1;i<=n;i++)
                    slk[i] = min(slk[i], lx[lk[y]] + ly[i] - f(lk[y], i));
            }
            for(int i=1;i<=n;i++)
                vx[i] = vy[i] = false;
            dfs(x);
        }
    }
    K = C = 0;
    for(int i=1;i<=n;i++)
        K += k[lk[i]][i], C += c[lk[i]][i];
}
struct node{
    int k, c, l, r;
    node(int _k=0, int _c=0, int _l=0, int _r=0):k(_k), c(_c), l(_l), r(_r){}
};
vector<node>vec;
int pk, pc;
int solve(int l, int r) {
    while( l < r ) {
        int mid = (l + r + 1) >> 1; KM(mid);
        if( pk == K ) l = mid;
        else r = mid - 1;
    }
    return l;
}
int main() {
    scanf("%d", &n);
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            scanf("%d", &c[i][j]);
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            scanf("%d", &k[i][j]);
    int le = 0, ri = int(1E7);
    while( le <= ri ) {
        KM(le); pk = K, pc = C;
        int res = solve(le, ri);
        vec.push_back(node(pk, pc, le, res));
        le = res + 1;
    }
    scanf("%d", &q);
    for(int i=1;i<=q;i++) {
        int x; scanf("%d", &x);
        for(int j=0;j<vec.size();j++)
            if( vec[j].l <= x && x <= vec[j].r )
                printf("%d\n", vec[j].k*x + vec[j].c);
    }
}

@details@

本身来说这道题难度不大,只是需要想到一开始 K*x + C 中 K 的取值只有 n 种。

另外网上百度出来的 KM 算法大多都是假的 O(n^3)(包括百度百科),如果真的想找 O(n^3) 的代码可以去搜 uoj#80 的题解。
其实我也不知道把 KM 算法换成费用流能不能过。虽然费用流的复杂度是 O(玄学),不过完全图应该跑不快。。。

猜你喜欢

转载自www.cnblogs.com/Tiw-Air-OAO/p/11112549.html