Address
The First Step - 转化
简版题意:给定一棵点带权树,求树上所有大小大于
的连通块的第
大值之和。
众所周知,「第
大值」和「值
的排名」可以互相转化。
所以,答案为:
The Second Step - DP状态
如何求出「
的排名为
的连通块个数」?
如果权值互不相同,那么
的排名为
的连通块个数就是「满足
的权值有
个的连通块个数」。
如果权值有相同,那么
的排名为
,就必须保证
的权值个数至少
,但
的权值个数又要小于
。
于是
的排名为
的连通块个数为「
的权值个数
的连通块个数」
「
的权值个数
的方案数」。
设状态
表示
为根的连通块(包含空连通块),
的权值个数为
的方案数。
这样,答案就为:
The Third Step - DP转移
关于连通块计数的树形 DP 是一个经典模型。
边界:
转移时枚举
的子节点
并枚举
的子树包含的连通块大小
:
为向
转移之前的 DP 数组。
最后加上空连通块:
由于第三维的上界只有
的子树大小,所以每对点都在 LCA 处贡献
次,
复杂度
。
据说很多人用这种做法水过并跑得比标解快
The Fourth Step - DP优化
既然没人写正解,我就来水一发
发现
向子树的转移是一个卷积的形式。所以考虑生成函数。
设
为
的生成函数:
我们不妨将
分别取
个值,分别为
到
代入
,把
转成点值,那么在最外层枚举
,转移变成:
为了统计答案,我们还要记
:
也就是:
发现每次转移的第二维不变,这给我们什么启示呢?
可以用线段树合并实现转移!!!!!
对每个节点开一棵动态开点线段树,下标按照
和
第二维排列。
初始时线段树上每个叶子节点的
和
都是
。
每个节点上维护一个标记
,即子树内所有叶子节点的
改成
。
表示初始(不变)的标记为
。
标记合并的方法:
计算
和
时,先把
对应的线段树,区间
加上
,区间
加上
。
也就是
打标记
,
打标记
。
转移时
的每一个下标都要和
一一相乘,考虑线段树合并。
线段树合并时,需要边合并节点边下放标记。
注意,如果我们某一次需要合并
和
的子树,并且
没有子节点(如果
没有子节点则交换
和
),且
上的标记为
,则只需要把
子树内的
乘上
并且
加上
即可,
相当于
的标记第一个数乘
,第二个数乘
,第四个数加
。
最后将
对应线段树全局
并且
,
相当于根节点打标记
。
处理完之后,我们对
所在的线段树进行一遍 DFS ,当标记下传到叶子节点
并且标记为
时就能得出:
我们再回到答案:
如果考虑
的生成函数
,将
从
到
代入生成函数:
对于每个
求出上面生成函数的点值之后,使用拉格朗日插值法就能求出每个
的
复杂度
。
Code
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define Rof(i, a, b) for (i = a; i >= b; i--)
#define Tree(u) for (int e = adj[u], v; e; e = nxt[e]) if ((v = go[e]) != fu)
using namespace std;
inline int read()
{
int res = 0; bool bo = 0; char c;
while (((c = getchar()) < '0' || c > '9') && c != '-');
if (c == '-') bo = 1; else res = c - 48;
while ((c = getchar()) >= '0' && c <= '9')
res = (res << 3) + (res << 1) + (c - 48);
return bo ? ~res + 1 : res;
}
template <class T>
void Swap(T &a, T &b) {a ^= b; b ^= a; a ^= b;}
const int N = 1700, M = N << 1, L = 1e6 + 5, ZZQ = 64123;
int n, K, W, d[N], ecnt, nxt[M], adj[N], go[M], a[N], QAQ,
top, stk[L], b[N], pmt[N], tmp[N], ans[N], inv[ZZQ], Ans;
struct mark
{
int a, b, c, d;
friend inline mark operator + (mark x, mark y)
{
return (mark) {(int) (1u * x.a * y.a % ZZQ),
(int) ((1u * x.b * y.a + y.b) % ZZQ),
(int) ((1u * x.a * y.c + x.c) % ZZQ),
(int) ((1u * x.b * y.c + y.d + x.d) % ZZQ)};
}
};
struct node
{
mark x; int lc, rc;
} T[L];
int newnode()
{
if (top) return
T[stk[top]] = (node) {(mark) {1, 0, 0, 0}, 0, 0}, stk[top--];
return
T[++QAQ] = (node) {(mark) {1, 0, 0, 0}, 0, 0}, QAQ;
}
void add_edge(int u, int v)
{
nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
nxt[++ecnt] = adj[v]; adj[v] = ecnt; go[ecnt] = u;
}
void down(int p)
{
if (!T[p].lc) T[p].lc = newnode();
if (!T[p].rc) T[p].rc = newnode();
T[T[p].lc].x = T[T[p].lc].x + T[p].x;
T[T[p].rc].x = T[T[p].rc].x + T[p].x;
T[p].x = (mark) {1, 0, 0, 0};
}
int mer(int x, int y)
{
if (!x || !y) return x + y;
if (!T[x].lc && !T[x].rc) Swap(x, y);
if (!T[y].lc && !T[y].rc)
return T[x].x.a = (int) (1u * T[x].x.a * T[y].x.b % ZZQ),
T[x].x.b = (int) (1u * T[x].x.b * T[y].x.b % ZZQ),
T[x].x.d = (T[x].x.d + T[y].x.d) % ZZQ,
stk[++top] = y, x;
down(x); down(y);
T[x].lc = mer(T[x].lc, T[y].lc);
T[x].rc = mer(T[x].rc, T[y].rc);
return stk[++top] = y, x;
}
void change(int l, int r, int s, int e, int v, int &p)
{
if (!p) p = newnode();
if (l == s && r == e)
return (void) (T[p].x.b = (T[p].x.b + v) % ZZQ);
int mid = l + r >> 1;
down(p);
if (e <= mid) change(l, mid, s, e, v, T[p].lc);
else if (s >= mid + 1) change(mid + 1, r, s, e, v, T[p].rc);
else change(l, mid, s, mid, v, T[p].lc),
change(mid + 1, r, mid + 1, e, v, T[p].rc);
}
int dfs(int val, int u, int fu)
{
int p = 0;
change(1, W, 1, d[u], val, p);
if (d[u] < W) change(1, W, d[u] + 1, W, 1, p);
Tree(u) p = mer(p, dfs(val, v, u));
T[p].x = T[p].x + (mark) {1, 1, 1, 0};
return p;
}
void orz(int val, int l, int r, int p)
{
if (l == r) return (void) (a[val] = (a[val] + T[p].x.d) % ZZQ);
int mid = l + r >> 1;
down(p);
orz(val, l, mid, T[p].lc);
orz(val, mid + 1, r, T[p].rc);
}
void jiejuediao(int val)
{
QAQ = top = 0;
orz(val, 1, W, dfs(val, 1, 0));
}
void divi(int x)
{
int i;
For (i, 0, n + 1) pmt[i] = b[i];
Rof (i, n + 1, 1)
{
tmp[i - 1] = pmt[i];
pmt[i - 1] = (pmt[i - 1] + (int) (1u * x * pmt[i] % ZZQ)) % ZZQ;
}
}
void Lagrange()
{
int i, j;
b[0] = 1;
For (i, 1, n + 1)
{
For (j, 0, i - 1) tmp[j] = b[j] * i % ZZQ;
Rof (j, i, 1) b[j] = b[j - 1];
b[0] = 0;
For (j, 0, i) b[j] = (b[j] - tmp[j] + ZZQ) % ZZQ;
}
For (i, 1, n + 1)
{
divi(i);
int rq = 1;
For (j, 1, n + 1) if (j != i)
rq = (int) (1u * rq * ((i - j + ZZQ) % ZZQ) % ZZQ);
rq = (int) (1u * a[i] * inv[rq] % ZZQ);
For (j, 0, n)
ans[j] = (ans[j] + (int) (1u * tmp[j] * rq % ZZQ)) % ZZQ;
}
}
int main()
{
int i, x, y;
inv[1] = 1;
For (i, 2, ZZQ - 1)
inv[i] = (int) (1u * (ZZQ - ZZQ / i) * inv[ZZQ % i] % ZZQ);
n = read(); K = read(); W = read();
For (i, 1, n) d[i] = read();
For (i, 1, n - 1) x = read(), y = read(),
add_edge(x, y);
For (i, 1, n + 1) jiejuediao(i);
Lagrange();
For (i, K, n) Ans = (Ans + ans[i]) % ZZQ;
cout << Ans << endl;
return 0;
}