KM算法讲解(含C++代码)

假设有3个女的要嫁给三个男的,各有各的期望值。
如何让期望值之和最大?
此时我们就要用到传说中的km算法了。
这个算法本质上是贪心算法,怎么算呢?
举个例子吧
在这里插入图片描述
首先看女1,女1与男1间的边权值+男1期望值=3,而3不等于女一的期望值,所以配对失败。接着女1与男3间的边权值+男3的期望值=4,刚好4与女1的期望值相等,配对成功!
在这里插入图片描述
接着让女2找对象,匹配的过程就省略了,最后发现跟男3可以配对,而男3被女1占了,女2对女1说:“你能不能降低一下期望值啊?”于是女1同意了,但是我们还是得让女1跟男3可以配对,于是就将女1的期望值降低1,男3的期望值上升1,这样他们还是能够配对(男3挑剔了起来)。但是这样女2又不能配对了,就将女2的期望值也降低。这是发现女1可以和男1配对,就将他们配起来。
在这里插入图片描述
接着帮女3找对象,发现女3无法跟任何人配对,就只好将她的期望值降1,让她可以和男3配对。
在这里插入图片描述

此时女3发现男3被女2占了,于是就勒索让女2降低期望值,于是女2降低1期望值,为了以后还能找男3,所以男3的期望值上升1,女3期望值也得随之降1。这时女1的期望值也得降低1,因为女1也要保持随时可以与男3配对。 接着女2找上了男2,于是就跟男2配对。



此时三男三女都有了自己的对象了。(好开心,终于打完字了)


你以为这就结束了?
不可能 ,还有例题呢:

Description

小W在八中开了一个兼职中心。现在他手下有N个工人。每个工人有N个工作可以选择,于是每个人做每个工作的效率是不一样的。做为CEO的小W的任务就是给每个人分配一个工作,保证所有人效率之和是最大的。N<=200

Input

第一行给出数字N
接下来N行N列,代表每个人工作的效率。

Output

一个数字,代表最大效率之和

Sample Input

4
62 41 86 94
73 58 11 12
69 93 89 88
81 40 69 13

Sample Output

329

HINT



这就是km算法的模板题(虽然跟我举的例子有点出入,但本质上还是相同的)。
代码有点长,慢慢看啊
版本1:

#include<bits/stdc++.h>
using namespace std;
const int N=205;
int w[N][N];
int la[N],lb[N];
bool va[N],vb[N];
int match[N];
int delta,n;
void read() {
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            scanf("%d",&w[i][j]);
}
bool dfs(int x) {
    va[x]=1;
    for(int y=1;y<=n;y++)
        if(!vb[y])
            if(la[x]+lb[y]-w[x][y]==0) {
                vb[y]=1;
                if(!match[y]||dfs(match[y])) {
                    match[y]=x;
                    return true;
                }
            }
            else
                delta=min(delta,la[x]+lb[y]-w[x][y]);
    return false;
}
int KM() {
    for(int i=1;i<=n;i++) {
        la[i]=-(1<<30);
        lb[i]=0;
        for(int j=1;j<=n;j++)
            la[i]=max(la[i],w[i][j]);
    }
    for(int i=1;i<=n;i++)
        while(true) {
            memset(va,0,sizeof(va));
            memset(vb,0,sizeof(vb));
            delta=1<<30;
            if(dfs(i))
                break;
            for(int j=1;j<=n;j++) {
                if(va[j])
                    la[j]-=delta;
                if(vb[j])
                    lb[j]+=delta;
            }
        }
    int ans=0;
    for(int i=1;i<=n;i++)
        ans+=w[match[i]][i];
    return ans;
}
void write() {
    printf("%d\n",KM());
}
int main() {
    read();
    write();
}

版本2:

#include <bits/stdc++.h>
using namespace std;
int n;
int w[305][305];
int lx[305],ly[305];
int matched[305];
int slack[305];
bool s[305],t[305];
bool match(int i) {
    s[i]=1;
    for(int j=1; j<=n; j++) {
        int cnt=lx[i]+ly[j]-w[i][j];
        if(cnt==0&&!t[j]) {
            t[j]=1;
            if(!matched[j]||match(matched[j])) {
                matched[j]=i;
                return 1;
            }
        } else {
            slack[j]=min(slack[j],cnt);
        }
    }
    return 0;
}
void update() {
    int a=0x3f3f3f3f;
    for(int i=1; i<=n; i++) {
        if(!t[i])
            a=min(a,slack[i]);
    }
    for(int i=1; i<=n; i++) {
        if(s[i])lx[i]-=a;
        if(t[i])ly[i]+=a;
    }
 
}
void km() {
    memset(matched,0,sizeof(matched));
    memset(lx,0,sizeof(lx));
    memset(ly,0,sizeof(ly));
    for(int i=1; i<=n; i++) {
        for(int j=1; j<=n; j++) {
            lx[i]=max(lx[i],w[i][j]);
        }
    }
    for(int i=1; i<=n; i++) {
        memset(slack,0x3f,sizeof(slack));
        while(1) {
            memset(s,0,sizeof(s));
            memset(t,0,sizeof(t));
            if(match(i))
                break;
            else
                update();
        }
    }
}
int main() {
    scanf("%d",&n);
    for(int i=1; i<=n; i++) {
        for(int j=1; j<=n; j++) {
            scanf("%d",&w[i][j]);
        }
    }
    km();
    int ans=0;
    for(int i=1; i<=n; i++) {
        ans+=lx[i];
        ans+=ly[i];
    }
    printf("%d\n",ans);
 
    return 0;
}

如果有什么地方没讲好或者是说讲错了,可以在评论区告诉我,我会看到后立马修改。
如果觉得有什么想法也可以在评论区说。
感谢观看。

猜你喜欢

转载自blog.csdn.net/liuzich/article/details/105933547