题意:
给定
个连通块,每个连通块的大小为
,接下来依次连
条边,得到的树
的价值定义为:
其中,
表示与第
个连通块连接的边的条数。请求出所有不同连边方式产生的树的价值和膜
.
前置技能:求数列 次方和。
给定
,对于任意的
,求出
。
考虑答案的生成函数
.
直接计算仍然是不行的,注意到
因此考虑先计算
,则
。化简
:
括号内的东西分治NTT即可,然后多项式求ln再求导,即可得到
。
过程
对于每个终方案
,对答案的贡献为
。
由于出现了度数,我们考虑使用prufer序列化简算式。总贡献等价于:
前面的
是常量,我们不需要关注。考虑后面的东西,它等价于:
考虑构建上式关于
的生成函数
。考虑下述的两个多项式:
则有:
也就是说求出
后,需要对于每一项乘上
,这正是我们前面说过可以在
的时间内求出的东西。因此最后的复杂度为
。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 65540, mod = 998244353;
ll ta[maxn], tb[maxn], tc[maxn];
ll modpow(ll a, int b) {
ll res = 1;
for (; b; b >>= 1) {
if (b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
void rader(ll *a, int n) {
for (int i = 1, j = n >> 1; i < n - 1; i++) {
if (i < j) swap(a[i], a[j]);
int k = n >> 1;
for (; j >= k; k >>= 1) j -= k;
if (j < k) j += k;
}
}
void ntt(ll *a, int n, int rev) {
rader(a, n);
for (int h = 2; h <= n; h <<= 1) {
ll wn = modpow(3, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
int hh = h >> 1;
for (int i = 0; i < n; i += h)
for (int j = i, w = 1; j < i + hh; j++, w = w * wn % mod) {
const int x = a[j], y = a[j + hh] * w % mod;
a[j] = (x + y) % mod;
a[j + hh] = (x - y + mod) % mod;
}
}
if (rev) {
int inv = modpow(n, mod - 2);
for (int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
}
}
void get_inv(ll *a, ll *b, int n) {
if (n == 1) { b[0] = modpow(a[0], mod - 2); b[1] = 0; return; }
get_inv(a, b, n >> 1);
int m = n << 1;
for (int i = n; i < m; i++) ta[i] = b[i] = 0;
for (int i = 0; i < n; i++) ta[i] = a[i];
ntt(ta, m, 0), ntt(b, m, 0);
for (int i = 0; i < m; i++) b[i] = (mod + 2 - ta[i] * b[i] % mod) * b[i] % mod;
ntt(b, m, 1);
for (int i = n; i < m; i++) b[i] = 0;
}
void get_ln(ll *a, ll *b, int n) {
get_inv(a, tb, n);
for (int i = 1; i < n; i++) b[i - 1] = a[i] * i % mod;
b[n - 1] = 0;
int m = n << 1;
for (int i = n; i < m; i++) b[i] = 0;
ntt(b, m, 0), ntt(tb, m, 0);
for (int i = 0; i < m; i++) b[i] = b[i] * tb[i] % mod;
ntt(b, m, 1);
for (int i = n; i < m; i++) b[i] = 0;
for (int i = n - 1; i > 0; i--) b[i] = b[i - 1] * modpow(i, mod - 2) % mod;
b[0] = 0;
}
void get_exp(ll *a, ll *b, int n) {
if (n == 1) { b[0] = 1, b[1] = 0; return; }
get_exp(a, b, n >> 1);
get_ln(b, tc, n);
int m = n << 1;
for (int i = n; i < m; i++) b[i] = ta[i] = 0;
for (int i = 0; i < n; i++) ta[i] = (mod + !i + a[i] - tc[i]) % mod;
ntt(b, m, 0), ntt(ta, m, 0);
for (int i = 0; i < m; i++) b[i] = b[i] * ta[i] % mod;
ntt(b, m, 1);
for (int i = n; i < m; i++) b[i] = 0;
}
ll fac[maxn], rev[maxn], A[maxn], B[maxn], C[maxn], sum[maxn];
int sz[maxn], n, m;
vector<ll> divide(int l, int r) {
if (l == r) { vector<ll> vec; vec.push_back(1); vec.push_back(mod - sz[l]); return vec; }
int mid = (l + r) >> 1, len = 1;
vector<ll> vl = divide(l, mid), vr = divide(mid + 1, r);
while (len <= r - l) len <<= 1; len <<= 1;
for (int i = 0; i < len; i++) {
ta[i] = tb[i] = 0;
if (i <= mid - l + 1) ta[i] = vl[i];
if (i <= r - mid) tb[i] = vr[i];
}
ntt(ta, len, 0), ntt(tb, len, 0);
for (int i = 0; i < len; i++) ta[i] = ta[i] * tb[i] % mod;
ntt(ta, len, 1);
vector<ll> res;
for (int i = 0; i <= r - l + 1; i++) res.push_back(ta[i]);
return res;
}
int main() {
scanf("%d%d", &n, &m);
if (n == 1) return puts("1") * 0;
for (int i = fac[0] = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
rev[n] = modpow(fac[n], mod - 2);
for (int i = n; i > 0; i--) rev[i - 1] = rev[i] * i % mod;
for (int i = 1; i <= n; i++) scanf("%d", sz + i);
vector<ll> vec = divide(1, n);
memset(tc, 0, sizeof(tc));
for (int i = 0; i <= n; i++) tc[i] = vec[i];
int len = 1;
while (len <= n) len <<= 1;
get_ln(tc, sum, len);
for (int i = 1; i <= n; i++) sum[i] = mod - sum[i] * i % mod;
sum[0] = n;
for (int i = 0; i < n; i++) {
A[i] = modpow(i + 1, m) * rev[i] % mod;
B[i] = modpow(i + 1, 2 * m) * rev[i] % mod;
}
get_ln(A, tc, len);
get_inv(A, C, len);
ntt(C, len << 1, 0), ntt(B, len << 1, 0);
for (int i = 0; i < len << 1; i++) B[i] = B[i] * C[i] % mod;
ntt(B, len << 1, 1);
memset(A, 0, sizeof(A));
memset(C, 0, sizeof(C));
for (int i = 0; i < n; i++) {
B[i] = B[i] * sum[i] % mod;
A[i] = tc[i] * sum[i] % mod;
}
get_exp(A, C, len);
for (int i = n; i < len << 1; i++) B[i] = 0;
ntt(B, len << 1, 0), ntt(C, len << 1, 0);
for (int i = 0; i < len << 1; i++) B[i] = B[i] * C[i] % mod;
ntt(B, len << 1, 1);
ll res = B[n - 2] * fac[n - 2] % mod;
for (int i = 1; i <= n; i++) res = res * sz[i] % mod;
printf("%lld\n", res);
return 0;
}