题目链接:loj3119
看到“恰好\(k\)个”就几乎能猜到是容斥相关了
记\(f_i\)为至少有\(i\)个极大的数的方案数,\(Ans_i\)表示恰好有\(i\)个的方案数,则有
\[ f_k=\sum_{i=k}^{min(n,m,l)} \dbinom{k}{i}Ans_i \]
由二项式反演得
\[ Ans_k=\sum_{i=k}^{min(n,m,l)}(-1)^{i-k}\dbinom{k}{i}f_i \]
考虑计算\(f_i\),对于\(i\)个极大的数,它们能影响到的位置也就是与它们中的一个至少有一维坐标相同的位置,剩下的位置可以随意排列,于是我们可以写出\(f_i\)的计算式
\[ f_i=\dbinom{nml}{g_i}b_ih_i(nml-g_i)! \]
\(g_i\)表示与\(i\)个极大的数至少有一维坐标相同的点的个数
\(b_i\)表示选出\(i\)个极大的数的方案数
\(h_i\)表示将\(g_i\)个格子填上\(g_i\)个数的方案数
考虑如何求出上面的三个值
\(g_i\)的计算比较简单,考虑所有坐标均不与极大的数相同的数即可,最后\(g_i=nml-(n-i)(m-i)(l-i)\)
\(b_i和h_i\)d的计算稍显复杂,在这里我们记\(i\)个极大的数分别是\(a_1,a_2,\cdots,a_i\),且\(a_1<a_2<\cdots<a_k\)
首先是\(b_i\),当我们已经选定\(a_1-a_{j-1}\)时,再来考虑\(a_j\)的话,它不能和前\(j-1\)个数有一维坐标是相同的,于是有\(b_i=\prod_{j=0}^{i-1}(n-i)(m-i)(l-i)\)
接下来是\(h_i\),我们依然按照下标顺序来考虑填数,假设我们已经将\(a_1-a_{j-1}\)中的数和与它们至少有一维坐标相同的点填入了立方体中,现在考虑\(a_j\),注意到我们无需考虑与\(a_j\)中至少有一维坐标相同的且已经被填上数的位置,因为由定义已经保证了合法性
所以我们只需考虑那些新增的与\(a_j\)至少有一维坐标相同的点(且除了自身,因为它一定当前所有填入的数的最大值),是\(g_i-g_{i-1}+1\),所以\(h_i=\frac{(g_i-1)!}{(g_{i-1}!)}*h_{i-1}=\prod_{j=1}^i\frac{(g_j-1)!}{(g_{j-1})!}\)
将它们带回原式
\[ \begin{aligned} f_i=&\frac{(nml)!}{(nml-g_i)!g_i!}b_i\prod_{j=1}^i\frac{(g_j-1)!}{g_{j-1}!}(nml-g_i)!\\ =&\frac{nml!}{g_i!}b_i\prod_{j=1}^i(g_j-1)!\prod_{j=0}^{i-1}\frac{1}{g_j!}\\ =&(nml!)b_i\prod_{j=1}^i\frac{1}{g_j} \end{aligned} \]
最后由于求的是概率,直接约掉\(nml!\)
剩下的就是一通乱算了,注意在处理\(g_j\)的逆元时不能暴力快速幂,而应采取类似阶乘求逆元的方法进行求解,具体细节见代码
#include<iostream>
#include<string.h>
#include<string>
#include<stdio.h>
#include<algorithm>
#include<math.h>
#include<vector>
#include<queue>
#include<map>
#include<set>
using namespace std;
#define lowbit(x) (x)&(-x)
#define sqr(x) (x)*(x)
#define fir first
#define sec second
#define rep(i,a,b) for (register int i=a;i<=b;i++)
#define per(i,a,b) for (register int i=a;i>=b;i--)
#define maxd 998244353
#define eps 1e-6
typedef long long ll;
const int N=5000000;
const double pi=acos(-1.0);
int n,m,l,k;
ll fac[5005000],invfac[5005000],f[5005000],g[5005000],invg[5005000];
int read()
{
int x=0,f=1;char ch=getchar();
while ((ch<'0') || (ch>'9')) {if (ch=='-') f=-1;ch=getchar();}
while ((ch>='0') && (ch<='9')) {x=x*10+(ch-'0');ch=getchar();}
return x*f;
}
ll qpow(ll x,ll y)
{
ll ans=1;
while (y)
{
if (y&1) ans=ans*x%maxd;
x=x*x%maxd;y>>=1;
}
return ans;
}
ll C(int n,int m)
{
if (n<m) return 0;
return fac[n]*invfac[m]%maxd*invfac[n-m]%maxd;
}
int main()
{
fac[0]=1;invfac[0]=1;
rep(i,1,N) fac[i]=fac[i-1]*i%maxd;
invfac[N]=qpow(fac[N],maxd-2);
per(i,N-1,1) invfac[i]=invfac[i+1]*(i+1)%maxd;
int T=read();
while (T--)
{
n=read();m=read();l=read();k=read();
int p=min(min(n,m),l);
if (p<k) {puts("0");continue;}
ll tot=1ll*n*m%maxd*l%maxd;
f[0]=1;
rep(i,0,p-1) f[i+1]=f[i]*(n-i)%maxd*(m-i)%maxd*(l-i)%maxd;
ll prog=1;
rep(i,1,p)
{
g[i]=(tot-1ll*(n-i)*(m-i)%maxd*(l-i)%maxd+maxd)%maxd;
prog=prog*g[i]%maxd;
}
invg[p]=qpow(prog,maxd-2);
per(i,p-1,0) invg[i]=invg[i+1]*g[i+1]%maxd;
ll ans=0;
rep(i,k,p)
{
if ((i-k)&1) ans=(ans+maxd-C(i,k)*f[i]%maxd*invg[i]%maxd)%maxd;
else ans=(ans+C(i,k)*f[i]%maxd*invg[i]%maxd)%maxd;
}
printf("%lld\n",ans);
}
return 0;
}