题目大意
略
题解
显然应该尽量先把攻击牌放完。
然后分情况讨论大力 DP 即可。
在写这题的时候被两个问题坑了好久:
- \(k=1\)。我的计数方式没法处理这个简单问题。。。如果是在考场上的话应该多想想这样的边界情况。
- 数组越界。开数组比较吝啬,导致访问一些下标爆范围但是没用的数据时会越界算错答案。
实现
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int rint() {
int n, c;
while ((c = getchar()) < '0');
n = c - '0';
while ((c = getchar()) >= '0') n = 10 * n + c - '0';
return n;
}
template <int MD>
struct ModInt {
typedef ModInt M;
int v;
ModInt() : v(0) {}
ModInt(int _v) : v(_v) {}
M& operator += (const M &r) {
if ((v += r.v) >= MD) v -= MD;
return *this;
}
M& operator *= (const M &r) {
v = ll(v) * r.v % MD;
return *this;
}
M operator + (const M &r) const { return M(*this) += r; }
M operator * (const M &r) const { return M(*this) *= r; }
};
typedef ModInt<998244353> Mint;
const int N = 3010;
int n, m, k;
int a[N], b[N];
Mint C[2 * N][2 * N];
Mint dp[N][N], f[N][N], g[N][N];
Mint sum[N];
void first() {
for (int i = 0; i < 2 * N; i++) {
for (int j = 0; j <= i; j++) {
if (j == 0) C[i][j] = Mint(1);
else C[i][j] = C[i - 1][j] + C[i - 1][j - 1];
}
}
}
void solve() {
scanf("%d %d %d", &n, &m, &k);
for (int i = 0; i < n; i++) a[i] = rint();
for (int i = 0; i < n; i++) b[i] = rint();
sort(a, a + n, greater<int>());
sort(b, b + n, greater<int>());
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= n; j++) {
dp[i][j] = Mint(0);
f[i][j] = Mint(0);
}
}
dp[0][0] = Mint(1);
for (int i = 0; i < n; i++) {
for (int j = 0; j <= i; j++) {
dp[i + 1][j] += dp[i][j];
Mint x = Mint(a[i]) * dp[i][j];
dp[i + 1][j + 1] += x;
f[i + 1][j + 1] += x;
}
}
for (int i = 0; i <= n; i++) {
for (int j = 0; j <= n; j++) {
dp[i][j] = Mint(0);
g[i][j] = Mint(0);
}
}
for (int i = 0; i < n; i++) {
for (int j = 0; j <= i; j++) {
dp[i + 1][j] += dp[i][j];
Mint x = C[i][j] * Mint(b[i]) + dp[i][j];
dp[i + 1][j + 1] += x;
g[i + 1][j + 1] += x;
}
}
if (k == 1) {
Mint sm(0);
for (int i = 1; i <= n; i++) {
sm += Mint(b[i - 1]) * C[2 * n - i][m - k];
}
printf("%d\n", sm.v);
return;
}
fill(sum, sum + n + 1, Mint(0));
for (int i = 1; i <= n; i++) {
for (int x = 1; x <= n; x++) {
sum[x] += g[i][x] * C[n - i][m - k];
}
}
Mint sm = sum[k];
// 0 < t < k - 1
for (int i = 1; i <= n; i++) {
for (int t = 1; t < k - 1; t++) {
sm += f[i][t] * sum[k - t];
}
}
// t >= k - 1
for (int i = 1; i <= n; i++) {
Mint d(0);
for (int j = 1; j <= n; j++) {
d += Mint(b[j - 1]) * C[2 * n - i - j][m - k];
}
sm += d * f[i][k - 1];
}
printf("%d\n", sm.v);
}
int main() {
first();
int tc;
scanf("%d", &tc);
while (tc--) solve();
return 0;
}