[bzoj5093][Lydsy1711月赛]图的价值【FFT~NTT】【stirling数】【二项式反演】

【题目链接】
  https://www.lydsy.com/JudgeOnline/problem.php?id=5093
【题解】
  首先每个点都是独立的,可以求出一个点的贡献再把它乘以 n ,枚举这个点连了多少条边,可以列出式子:
   a n s = n 2 ( 2 n 1 ) i = 0 n 1 ( i n 1 ) i k
  考虑第二类斯特林数的一个恒等式:
   X n = i = 0 X ( i X ) i ! S n , i
  将它代入:
   a n s = n 2 ( 2 n 1 ) i = 0 n 1 ( i n 1 ) j = 0 i ( j i ) j ! S k , j
  现在考虑求: i = 0 n 1 ( i n 1 ) j = 0 i ( j i ) j ! S k , j
  先改变求和顺序: j = 0 n 1 j ! S k , j i = j n 1 ( i n 1 ) ( j i )
  考虑后面的和式的意义:先从 n 1 个数中取出 j 个数,再从 j 个数中取出 i 个数。
  因此可以化为:从 n 1 个数中选出 j 个数,其他的数是否选取随意。于是可以转换为:
   j = 0 n 1 j ! S k , j ( j n 1 ) 2 n 1 j
  由于 S k , n ( n > k ) = 0 所以只要求第 k 行的前 k + 1 个斯特林数即可。
  对于恒等式做二项式反演:
   i ! S k , i = j = 0 i ( 1 ) i j ( j i ) i k
   S k , i = j = 0 i ( 1 ) i j ( i j ) ! i k i !
  NTT即可。
  时间复杂度 O ( K l o g K )
【代码】

/* - - - - - - - - - - - - - - -
    User :      VanishD
    problem :   [bzoj5093]
    Points :    NTT + stirling
- - - - - - - - - - - - - - - */
# include <bits/stdc++.h>
# define    ll      long long
# define    N       1000100
using namespace std;
const int inf = 0x3f3f3f3f, INF = 0x7fffffff, P = 998244353, G = 3; 
const ll  infll = 0x3f3f3f3f3f3f3f3fll, INFll = 0x7fffffffffffffffll;
int read(){
    int tmp = 0, fh = 1; char ch = getchar();
    while (ch < '0' || ch > '9'){ if (ch == '-') fh = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9'){ tmp = tmp * 10 + ch - '0'; ch = getchar(); }
    return tmp * fh;
}
int power(int x, ll y){
    int i = x; x = 1;
    while (y > 0){
        if (y % 2 == 1) x = 1ll * i * x % P;
        i = 1ll * i * i % P;
        y /= 2;
    }
    return x;
}
void NTT(int *a, int l, int tag){
    for (int i = 0, j = 0; i < l; i++){
        if (i < j) swap(a[i], a[j]);
        for (int k = (l >> 1); (j ^= k) < k; k >>= 1);
    }
    for (int i = 1; i < l; i *= 2){
        int wn = power(G, (P - 1) / (i * 2));
        if (tag == -1) wn = power(wn, P - 2);
        for (int j = 0; j < l; j += i * 2)
            for (int k = 0, w = 1; k < i; k++, w = 1ll * w * wn % P){
                int x = a[k + j], y = 1ll * w * a[k + i + j] % P;
                a[k + j] = (x + y) % P; a[k + i + j] = (x - y) % P;
            }
    }
    if (tag == -1){
        int inv = power(l, P - 2);
        for (int i = 0; i < l; i++) a[i] = 1ll * a[i] * inv % P;
    }
} 
int powk[N], mul[N], a[N], b[N], C[N], n, k, l, S[N];
void getS(int n){
    powk[0] = 1, mul[0] = 1;
    for (int i = 1; i <= n; i++) mul[i] = 1ll * mul[i - 1] * i % P;
    for (int i = 0; i <= n; i++){
        a[i] = power(-1, i) * power(mul[i], P - 2);
        b[i] = 1ll * power(i, n) * power(mul[i], P - 2) % P;
    }
    l = 1;
    while (l <= n * 2) l <<= 1;
    NTT(a, l, 1), NTT(b, l, 1);
    for (int i = 0; i < l; i++) a[i] = 1ll * a[i] * b[i] % P;
    NTT(a, l, -1);
    for (int i = 0; i <= n; i++) S[i] = (a[i] + P) % P; 
}
int main(){
//  freopen(".in", "r", stdin);
//  freopen(".out", "w", stdout);
    n = read(), k = read();
    getS(k); 
    int lim = min(n - 1, k), ans = 0;
    C[0] = 1; for (int i = 1; i <= k; i++) C[i] = 1ll * C[i - 1] * (n - i) % P * power(i, P - 2) % P;
    for (int i = 0; i <= lim; i++)
        ans = (ans + 1ll * S[i] * mul[i] % P * C[i] % P * power(2, n - 1 - i)) % P; 
    ans = (ans + P) % P;
    ll tmp = 1ll * (n - 1)* (n - 2) / 2;
    ans = 1ll * ans * n % P *power(2, tmp) % P;
    printf("%d\n", ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/d_vanisher/article/details/80819424