【bzoj3622】已经没有什么好害怕的了 【容斥原理】

题目传送门
题解
题目有一个条件:2n个数两两不同,所以不用考虑相等的情况。首先我们设有x对a比b大,y对b比a大。
可以得到

{ x = y + k x + y = n

解得 x = n + k 2
如果x不是整数直接输出0就行了。
否则我们考虑dp+容斥。
注意,下文中的配对都是指a比b大的配对。
首先把a和b排个序。
我们让 f [ i ] [ j ] 表示a中的前i个,有j个配对了的方案总数。
有状态转移方程
f [ i ] [ j ] = f [ i 1 ] [ j ] + f [ i 1 ] [ j 1 ] × ( r i g h t [ i ] ( j 1 ) )

r i t g h [ i ] 表示最大的 j ,使得 b [ j ] < a [ i ]
这样的话,至少 i 个配对的方案总数就是
f [ n ] [ i ] ( n i ) !

相当于把没有配对的部分做一次全排列。
我们再设 g [ i ] 表示恰好 i 对配对的方案总数。
f [ n ] [ i ] ( n i ) ! 会算入许多多于 i 对配对的方案总数,我们要考虑把他们去掉。
考虑恰好 j ( i < j n ) 种配对,即 g [ j ] 会在 f [ n ] [ i ] ( n i ) ! 被算入多少次。
如果被算入,那一定是在前 i 的配对中配出了 g 的一部分,剩下的部分被全排列到了,全排列到的次数有且仅有一次。
i 个配对被 j 个配对完全包含只会出现 C j i 次,因此每个 g [ j ] 的贡献都被计算了 C j i
因此有容斥方程
g [ i ] = f [ n ] [ i ] × ( n i ) ! j = i + 1 n g [ j ] C j i

我们只需要在求出 f 数组后倒着求一次 g 就行了。
代码

#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
const int N=2005,mod=1000000009;
int n,k,a[N],b[N],jc[N],ijc[N],right[N],f[N][N],g[N];
int fastpow(int a,int x){
    int res=1;
    while(x){
        if(x&1){
            res=1LL*res*a%mod;
        }
        x>>=1;
        a=1LL*a*a%mod;
    }
    return res;
}
int C(int n,int m){
    return 1LL*jc[n]*ijc[m]%mod*ijc[n-m]%mod;
}
int main(){
    scanf("%d%d",&n,&k);
    if((n+k)&1){
        puts("0");
        return 0;
    }
    k=(n+k)/2;
    jc[0]=ijc[0]=1;
    for(int i=1;i<=n;i++){
        jc[i]=1LL*jc[i-1]*i%mod;
        ijc[i]=fastpow(jc[i],mod-2);
    }
    for(int i=1;i<=n;i++){
        scanf("%d",&a[i]);
    }
    for(int i=1;i<=n;i++){
        scanf("%d",&b[i]);
    }
    sort(a+1,a+n+1);
    sort(b+1,b+n+1);
    for(int i=1,j=0;i<=n;i++){
        while(j<n&&b[j+1]<a[i]){
            j++;
        }
        right[i]=j;
    }
    f[0][0]=1;
    for(int i=1;i<=n;i++){
        for(int j=0;j<=i;j++){
            f[i][j]=f[i-1][j];
            if(j){
                f[i][j]=(f[i][j]+1LL*f[i-1][j-1]*(right[i]-(j-1))%mod)%mod;
            }
        }
    }
    for(int i=n;i>=k;i--){
        g[i]=1LL*f[n][i]*jc[n-i]%mod;
        for(int j=i+1;j<=n;j++){
            g[i]=(g[i]-1LL*g[j]*C(j,i)%mod+mod)%mod;
        }
    }
    printf("%d\n",g[k]);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/ez_2016gdgzoi471/article/details/81384153