loj2304. 「NOI2017」泳池

题意

略……

题解

以下默认\(k, H = 1001\)同阶。
先差分:答案 = P(最大合法区域面积 \(\leq k\)) − P(最大合法区域面积 \(\leq k - 1\)),所以只需把"面积 \(\leq k\) 的概率"算两次再相减。
考虑 dp,设 \(dp_{i, j}\) 表示当前考虑的矩形宽度为 \(j\),且最下面 \(i\) 行(共 \(i \times j\) 个格子)都安全时,最大合法区域的面积 \(\leq k\) 的概率。
状态转移分两种:一种是第 \(i + 1\) 层全部安全;另一种是枚举第 \(i + 1\) 层的第一个不安全格子。记单个格子安全的概率为 \(q\),即
\[ dp_{i, j} = q ^ j * dp_{i + 1, j} + \sum_{t = 1} ^ j (1 - q) q ^ {t - 1} dp_{i + 1, t - 1} dp_{i, j - t} \]
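注意到求出所有 \(i \in [1, k], j \in [0, \lfloor \frac{k}{i} \rfloor]\) 的 dp 值总共只需 \(\mathcal O(k ^ 2)\) 的时间(每个状态转移是 \(\mathcal O(j)\) 的,而 \(\sum_{i} (\frac{k}{i})^2 = \mathcal O(k^2)\))。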
但我们需要的是 \(dp_{0, n}\):第 \(0\) 行的宽度不受 \(\frac{k}{i}\) 的限制,而 \(n\) 又很大,直接逐项递推出 \(dp_{0, n}\) 代价太大,怎么办?
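我们大力猜一波结论:\(dp_{0, i}\) 满足一个线性齐次递推。其实可以验证:当 \(j > k\) 时 \(dp_{1, j} = 0\),于是对 \(j \geq k + 1\),转移式的第一项消失,求和中也只剩 \(t \leq k + 1\) 的项,所以它是一个(至多)\(k + 1\) 阶线性齐次递推。BM 至少需要前 \(2(k + 1)\) 项才能唯一确定它,用这么多项大力 BM 一发,然后上 Cayley-Hamilton 即可。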
复杂度\(\mathcal O(k ^ 2 \log n)\)

#include <bits/stdc++.h>
using namespace std;
typedef std::vector <int> poly;
// H: how many leading terms dp[0][0..H] of the target sequence are tabulated
// before switching to Berlekamp-Massey. For a given k, dp[0][*] satisfies a
// linear recurrence of order at most k + 1, and BM only determines an order-L
// recurrence uniquely from at least 2L terms. With k <= 1000 (the problem's
// bound -- TODO confirm against the statement) that is 2 * 1001 = 2002 terms,
// so H + 1 must be >= 2002.
// N: array bound. Indices up to H are used in po[] and in the second dimension
// of dp[][]; indices up to k + 1 are used in the first dimension of dp[][].
// N must therefore exceed both.
// mod: prime modulus for all arithmetic.
const int H = 2 * (1000 + 1), N = 2005, mod = 998244353;
// n, k: strip width and area bound read from input.
// p, q: per-cell safe / unsafe probabilities as residues mod `mod`, after main
//       converts the input fraction.
// po[i] = p^i; qo[] is declared alongside it; dp[i][j] is the table filled by solve().
int n, k, p, q, po[N], qo[N], dp[N][N];
// In-place modular addition: x <- (x + y) mod `mod`.
// Both x and y are expected to already lie in [0, mod), so one conditional
// subtraction is enough. 2 * mod still fits in an int.
void U (int &x, int y) {
    x += y;
    if (x >= mod) x -= mod;
}
// Binary exponentiation: returns x^y mod `mod` (1 when y == 0).
// Used with y = mod - 2 to form modular inverses (Fermat).
int power (int x, int y) {
    long long base = x, acc = 1;
    while (y) {
        if (y & 1) acc = acc * base % mod;
        base = base * base % mod;
        y >>= 1;
    }
    return (int)acc;
}
namespace polymain {
    void print (const poly &a) {
        for (int i = 0; i < (int)a.size(); ++i) {
            cerr << a[i] << " ";
        }
        cerr << endl;
    }
    poly operator + (const poly &a, const poly &b) {
        poly c(max(a.size(), b.size()));
        for (int i = 0; i < (int)a.size(); ++i) {
            c[i] = a[i];
        }
        for (int i = 0; i < (int)b.size(); ++i) {
            c[i] = (c[i] + b[i]) % mod;
        }
        return c;
    }
    poly operator - (const poly &a, const poly &b) {
        poly c(max(a.size(), b.size()));
        for (int i = 0; i < (int)a.size(); ++i) {
            c[i] = a[i];
        }
        for (int i = 0; i < (int)b.size(); ++i) {
            c[i] = (c[i] - b[i] + mod) % mod;
        }
        return c;
    }
    poly operator << (const poly &a, int b) {
        poly c(a.size() + b);
        for (int i = 0; i < (int)a.size(); ++i) {
            c[i + b] = a[i];
        }
        return c;
    }
    poly operator * (const poly &a, const poly &b) {
        poly c(a.size() + b.size() - 1);
        for (int i = 0; i < (int)a.size(); ++i) {
            for (int j = 0; j < (int)b.size(); ++j) {
                c[i + j] = (c[i + j] + 1ll * a[i] * b[j]) % mod;
            }
        }
        return c;
    }
    poly operator % (const poly &a, const poly &b) {
        poly c = a;
        for (int i = c.size() - 1; i >= (int)b.size() - 1; --i) {
            int v = c[i];
            for (int j = 1; j < (int)b.size(); ++j) {
                U(c[i - j], 1ll * v * b[j] % mod);
            }
        }
        c.resize(b.size() - 1, 0);
        return c;
    }
    poly initpoly (int c) {
        poly r(1);
        return r[0] = c, r;
    }
    poly Berlekamp_Massey (int S[], int n) {
        poly Ci = initpoly(1), Cj = initpoly(1); int b = 1;
        for (int i = 0, j = -1; i < n; ++i) {
            int d = 0;
            for (int j = 0; j < (int)Ci.size(); ++j) {
                d = (1ll * Ci[j] * S[i - j] + d) % mod;
            }
            if (d) {
                poly tmp = Ci;
                Ci = Ci - ((Cj * initpoly(1ll * d * power(b, mod - 2) % mod) << (i - j)));
                if ((int)Cj.size() - j > (int)tmp.size() - i) {
                    Cj = tmp, b = d, j = i;
                }
            }
        }
        return Ci;
    }
    int Cayley_Hamilton (int *a, poly c, int k, int n) {
        poly r = initpoly(1) << 1, s = initpoly(1);
        for ( ; n; n >>= 1, r = r * r % c) {
            if (n & 1) {
                s = s * r % c;
            }
        }
        int ret = 0;
        for (int i = 0; i < k - 1; ++i) {
            U(ret, 1ll * a[i] * s[i] % mod);
        }
        return ret;
    }
}
int calc (int *a, int k, int n) {
    poly c = polymain :: Berlekamp_Massey(a, k + 1);
    for (int i = 0; i < (int)c.size(); ++i) {
        c[i] = (mod - c[i]) % mod;
    }
    return polymain :: Cayley_Hamilton(a, c, c.size(), n);
}
// Probability (mod `mod`) that, for a strip of width n, the largest legal
// region has area at most k.
// dp[row][w]: same probability for width w, given that the bottom `row` rows
// are already known to be safe.
// The bottom-row values dp[0][0..h] are tabulated directly. If n exceeds h,
// the answer is extrapolated from them with calc().
int solve (int k) {
    const int h = min(n, H);
    memset(dp, 0, sizeof dp);
    // An empty strip (width 0) always satisfies the bound.
    for (int row = 0; row <= k + 1; ++row) {
        dp[row][0] = 1;
    }
    // Fill rows top-down. States with row * w > k are never written and stay
    // 0: the forced-safe block alone is already too large.
    for (int row = k; row >= 0; --row) {
        for (int w = 1; w <= h && w * row <= k; ++w) {
            int &cur = dp[row][w];
            // Case 1: the whole next layer of width w is safe.
            U(cur, 1ll * dp[row + 1][w] * po[w] % mod);
            // Case 2: the first unsafe cell of the next layer is at position t.
            for (int t = 1; t <= w; ++t) {
                const long long weight = 1ll * dp[row + 1][t - 1] * po[t - 1] % mod * q % mod;
                U(cur, weight * dp[row][w - t] % mod);
            }
        }
    }
    if (n > h) {
        return calc(dp[0], h, n);
    }
    return dp[0][n];
}
int main () {
    // Input: n, k and the per-cell safe probability as a fraction x / y.
    // The fraction is read into p and q, then converted below.
    cin >> n >> k >> p >> q;
    // p <- x / y (safe probability) and q <- 1 - p (unsafe probability), both mod `mod`.
    p = 1ll * p * power(q, mod - 2) % mod, q = (1 + mod - p) % mod;
    // Powers of the safe probability used by the dp transitions. Powers of q
    // are never read, so no table is built for them.
    po[0] = 1;
    for (int i = 1; i < N; ++i) {
        po[i] = 1ll * po[i - 1] * p % mod;
    }
    // Answer = P(max area <= k) - P(max area <= k - 1).
    cout << (solve(k) - solve(k - 1) + mod) % mod << endl;
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/psimonw/p/12031935.html