清华集训2017 生成树计数

题意:

给定 n n 个连通块,每个连通块的大小为 a i a_i ,接下来依次连 n 1 n-1 条边,得到的树 T T 的价值定义为:
v a l ( T ) = ( i = 1 n d i m ) ( i = 1 n d i m ) val(T)=\left(\prod_{i=1}^nd_i^m\right)\left(\sum_{i=1}^nd_i^m\right)
其中, d i d_i 表示与第 i i 个连通块连接的边的条数。请求出所有不同连边方式产生的树的价值和膜 998244353 998244353 .
n 30 , 000 , m 30 n\le30,000,m\le30

前置技能:求数列 k k 次方和。

给定 k k ,对于任意的 0 t k 0\le t\le k ,求出 i = 1 n a i t \sum\limits_{i=1}^na_i^t k , n 1 0 5 k,n\le 10^5
考虑答案的生成函数 F ( x ) = t = 0 k x t i = 1 n a i t = i = 1 n 1 1 x a i F(x)=\sum\limits_{t=0}^kx^t\sum\limits_{i=1}^na_i^t=\sum\limits_{i=1}^n\frac{1}{1-xa_i} .
直接计算仍然是不行的,注意到 ln ( 1 a i x ) = a i 1 a i x = t = 0 ( a i x ) t a i \ln'(1-a_ix)=\frac{-a_i}{1-a_ix}=-\sum\limits_{t=0}^\infty(a_ix)^ta_i
因此考虑先计算 G ( x ) = t = 0 k x t i = 1 n a i t + 1 G(x)=-\sum\limits_{t=0}^kx^t\sum\limits_{i=1}^na_i^{t+1} ,则 F ( x ) = x G ( x ) + n F(x)=-xG(x)+n 。化简 G ( x ) G(x)
G ( x ) = i = 1 n ln ( 1 a i x ) = ln ( i = 1 n ( 1 a i x ) ) G(x)=\sum\limits_{i=1}^n\ln'(1-a_ix)=\ln'\left(\prod_{i=1}^n(1-a_ix)\right)
括号内的东西分治NTT即可,然后多项式求ln再求导,即可得到 F ( x ) F(x)

过程

对于每个终方案 T T ,对答案的贡献为 i = 1 n a i d i d i m i = 1 n d i m \prod\limits_{i=1}^na_i^{d_i}d_i^m\sum\limits_{i=1}^nd_i^m
由于出现了度数,我们考虑使用prufer序列化简算式。总贡献等价于:
( n 2 ) ! d i = n 2 i = 1 n a i d i + 1 d i ! ( d i + 1 ) m i = 1 n ( d i + 1 ) m = ( n 2 ) ! i = 1 n a i d i = n 2 i = 1 n a i d i d i ! ( d i + 1 ) m i = 1 n ( d i + 1 ) m (n-2)!\sum_{\sum d_i=n-2}\prod_{i=1}^n\frac{a_i^{d_i+1}}{d_i!}(d_i+1)^m\sum_{i=1}^n(d_i+1)^m \\ =(n-2)!\prod_{i=1}^na_i\sum_{\sum d_i=n-2}\prod_{i=1}^n\frac{a_i^{d_i}}{d_i!}(d_i+1)^m\sum_{i=1}^n(d_i+1)^m
前面的 ( n 2 ) ! i = 1 n a i (n-2)!\prod\limits_{i=1}^na_i 是常量,我们不需要关注。考虑后面的东西,它等价于:
i = 1 n a i d i d i ! ( d i + 1 ) 2 m j = 1 , j i n a j d j d j ! ( d j + 1 ) m \sum_{i=1}^n\frac{a_i^{d_i}}{d_i!}(d_i+1)^{2m}\prod_{j=1,j\neq i}^n\frac{a_j^{d_j}}{d_j!}(d_j+1)^m
考虑构建上式关于 d i \sum d_i 的生成函数 F ( x ) F(x) 。考虑下述的两个多项式:
A ( x ) = i x i ( i + 1 ) m i ! A(x)=\sum_i \frac{x^i(i+1)^m}{i!}
B ( x ) = i x i ( i + 1 ) 2 m i ! B(x)=\sum_i \frac{x^i(i+1)^{2m}}{i!}
则有:
F ( x ) = i B ( a i x ) j i A ( a j x ) = i B ( a i x ) A ( a i x ) j A ( a j x ) = i B ( a i x ) A ( a i x ) exp j ln A ( a j x ) F(x)=\sum_i B(a_ix)\prod_{j\neq i}A(a_jx)=\sum_i\frac{B(a_ix)}{A(a_ix)}\prod_jA(a_jx) \\ =\sum_i\frac{B(a_ix)}{A(a_ix)}\exp\sum_j\ln A(a_jx)
也就是说求出 B ( x ) A ( x ) ln A ( x ) \frac{B(x)}{A(x)}和\ln A(x) 后,需要对于每一项乘上 a i k \sum a_i^k ,这正是我们前面说过可以在 O ( n l o g 2 n ) O(nlog^2n) 的时间内求出的东西。因此最后的复杂度为 O ( n l o g 2 n ) O(nlog^2n)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int maxn = 65540, mod = 998244353;
ll ta[maxn], tb[maxn], tc[maxn];
ll modpow(ll a, int b) {
	ll res = 1;
	for (; b; b >>= 1) {
		if (b & 1) res = res * a % mod;
		a = a * a % mod;
	}
	return res;
}
void rader(ll *a, int n) {
	for (int i = 1, j = n >> 1; i < n - 1; i++) {
		if (i < j) swap(a[i], a[j]);
		int k = n >> 1;
		for (; j >= k; k >>= 1) j -= k;
		if (j < k) j += k;
	}
}
void ntt(ll *a, int n, int rev) {
	rader(a, n);
	for (int h = 2; h <= n; h <<= 1) {
		ll wn = modpow(3, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
		int hh = h >> 1;
		for (int i = 0; i < n; i += h)
		for (int j = i, w = 1; j < i + hh; j++, w = w * wn % mod) {
			const int x = a[j], y = a[j + hh] * w % mod;
			a[j] = (x + y) % mod;
			a[j + hh] = (x - y + mod) % mod;
		}
	}
	if (rev) {
		int inv = modpow(n, mod - 2);
		for (int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
	}
}
void get_inv(ll *a, ll *b, int n) {
	if (n == 1) { b[0] = modpow(a[0], mod - 2); b[1] = 0; return; }
	get_inv(a, b, n >> 1);
	int m = n << 1;
	for (int i = n; i < m; i++) ta[i] = b[i] = 0;
	for (int i = 0; i < n; i++) ta[i] = a[i];
	ntt(ta, m, 0), ntt(b, m, 0);
	for (int i = 0; i < m; i++) b[i] = (mod + 2 - ta[i] * b[i] % mod) * b[i] % mod;
	ntt(b, m, 1);
	for (int i = n; i < m; i++) b[i] = 0;
}
void get_ln(ll *a, ll *b, int n) {
	get_inv(a, tb, n);
	for (int i = 1; i < n; i++) b[i - 1] = a[i] * i % mod;
	b[n - 1] = 0;
	int m = n << 1;
	for (int i = n; i < m; i++) b[i] = 0;
	ntt(b, m, 0), ntt(tb, m, 0);
	for (int i = 0; i < m; i++) b[i] = b[i] * tb[i] % mod;
	ntt(b, m, 1);
	for (int i = n; i < m; i++) b[i] = 0;
	for (int i = n - 1; i > 0; i--) b[i] = b[i - 1] * modpow(i, mod - 2) % mod;
	b[0] = 0;
}
void get_exp(ll *a, ll *b, int n) {
	if (n == 1) { b[0] = 1, b[1] = 0; return; }
	get_exp(a, b, n >> 1);
	get_ln(b, tc, n);
	int m = n << 1;
	for (int i = n; i < m; i++) b[i] = ta[i] = 0;
	for (int i = 0; i < n; i++) ta[i] = (mod + !i + a[i] - tc[i]) % mod;
	ntt(b, m, 0), ntt(ta, m, 0);
	for (int i = 0; i < m; i++) b[i] = b[i] * ta[i] % mod;
	ntt(b, m, 1);
	for (int i = n; i < m; i++) b[i] = 0;
}
ll fac[maxn], rev[maxn], A[maxn], B[maxn], C[maxn], sum[maxn];
int sz[maxn], n, m;
vector<ll> divide(int l, int r) {
	if (l == r) { vector<ll> vec; vec.push_back(1); vec.push_back(mod - sz[l]); return vec; }
	int mid = (l + r) >> 1, len = 1;
	vector<ll> vl = divide(l, mid), vr = divide(mid + 1, r);
	while (len <= r - l) len <<= 1; len <<= 1;
	for (int i = 0; i < len; i++) {
		ta[i] = tb[i] = 0;
		if (i <= mid - l + 1) ta[i] = vl[i];
		if (i <= r - mid) tb[i] = vr[i];
	}
	ntt(ta, len, 0), ntt(tb, len, 0);
	for (int i = 0; i < len; i++) ta[i] = ta[i] * tb[i] % mod;
	ntt(ta, len, 1);
	vector<ll> res;
	for (int i = 0; i <= r - l + 1; i++) res.push_back(ta[i]);
	return res;
}
int main() {
	scanf("%d%d", &n, &m);
	if (n == 1) return puts("1") * 0;
	for (int i = fac[0] = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
	rev[n] = modpow(fac[n], mod - 2);
	for (int i = n; i > 0; i--) rev[i - 1] = rev[i] * i % mod;
	for (int i = 1; i <= n; i++) scanf("%d", sz + i);
	vector<ll> vec = divide(1, n);
	memset(tc, 0, sizeof(tc));
	for (int i = 0; i <= n; i++) tc[i] = vec[i];
	int len = 1;
	while (len <= n) len <<= 1;
	get_ln(tc, sum, len);
	for (int i = 1; i <= n; i++) sum[i] = mod - sum[i] * i % mod;
	sum[0] = n;
	for (int i = 0; i < n; i++) {
		A[i] = modpow(i + 1, m) * rev[i] % mod;
		B[i] = modpow(i + 1, 2 * m) * rev[i] % mod;
	}
	get_ln(A, tc, len);
	get_inv(A, C, len);
	ntt(C, len << 1, 0), ntt(B, len << 1, 0);
	for (int i = 0; i < len << 1; i++) B[i] = B[i] * C[i] % mod;
	ntt(B, len << 1, 1);
	memset(A, 0, sizeof(A));
	memset(C, 0, sizeof(C));
	for (int i = 0; i < n; i++) {
		B[i] = B[i] * sum[i] % mod;
		A[i] = tc[i] * sum[i] % mod;
	}
	get_exp(A, C, len);
	for (int i = n; i < len << 1; i++) B[i] = 0;
	ntt(B, len << 1, 0), ntt(C, len << 1, 0);
	for (int i = 0; i < len << 1; i++) B[i] = B[i] * C[i] % mod;
	ntt(B, len << 1, 1);
	ll res = B[n - 2] * fac[n - 2] % mod;
	for (int i = 1; i <= n; i++) res = res * sz[i] % mod;
	printf("%lld\n", res);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/WAautomaton/article/details/85017586