Description
有一棵n个点的树,有m个点要被确定为关键点,一种确定的方案合法当且仅当存在一个点,该点到每个关键点的距离不超过k,求方案数。
Solution
考虑以一个点
显然对于每个点这样算的话会算重,于是我们对于每个点再统计一个与它父亲(确定一个根)距离都小于等于
Code
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define fo(i,j,k) for(int i=j;i<=k;++i)
#define fd(i,j,k) for(int i=j;i>=k;--i)
#define rep(i,x) for(int i=ls[x];i;i=nx[i])
#define ll long long
#define mem(a) memset(a,0,sizeof(a))
using namespace std;
const int N=1e5+10,M=2e5+10,mo=998244353;
int to[M],nx[M],ls[N],vl[M],num=0;
void link(int u,int v,int w){
to[++num]=v,nx[num]=ls[u],ls[u]=num;
vl[num]=w;
}
ll jc[N],ny[N];
int fa[N],f[N],F[N],g[N],fat[N],sz[N];
ll pow(ll x,int y){
ll s=1;
while(y){
if(y&1) s=s*x%mo;
y>>=1,x=x*x%mo;
}
return s;
}
ll C(int m,int n){
if(m<n) return 0;
return jc[m]*ny[n]%mo*ny[m-n]%mo;
}
ll K,ans=0;
int n,m;
bool vis[N];
int mn,rt,cn=0;
void getf(int x){
rep(i,x) if(to[i]!=fa[x]) fa[to[i]]=x,getf(to[i]);
}
void getsz(int x,int fr){
sz[x]=1;
rep(i,x){
int v=to[i];
if(v==fr || vis[v]) continue;
getsz(v,x),sz[x]+=sz[v];
}
}
void getrt(int x,int fr,int o){
int mx=0;
rep(i,x){
int v=to[i];
if(v==fr || vis[v]) continue;
getrt(v,x,o),mx=max(mx,sz[v]);
}
mx=max(mx,sz[o]-sz[x]);
if(mn>mx) mn=mx,rt=x;
}
void findrt(int x,int fr){
getsz(x,fr);
mn=sz[x]+1,getrt(x,fr,x);
}
struct node{
int x,f;
ll t;
}d[N];
int tot=0;
void get(int x,int fr,ll t,int ft,int p){
if(t<=K) g[p]++;
d[++tot].x=x,d[tot].t=t,d[tot].f=ft,fat[x]=fr;
rep(i,x){
int v=to[i];
if(v==fr || vis[v]) continue;
get(v,x,t+vl[i],ft,p);
}
}
bool cmpt(node x,node y){
return x.t<y.t;
}
bool cmpf(node x,node y){
return x.f<y.f || (x.f==y.f && x.t<y.t);
}
void calc(int l,int r,int z){
int rr=r;
fo(i,l,rr){
while(r>=l && d[i].t+d[r].t>K) r--;
F[d[i].x]+=z*(r-l+1);
}
}
void dfs(int x,int fr){
tot=0,vis[x]=1;
rep(i,x){
int v=to[i];
if(v==fr || vis[v]) continue;
get(v,x,vl[i],v,fa[v]==x?v:x);
}
F[x]++;
fo(i,1,tot) if(d[i].t<=K) F[x]++,F[d[i].x]++;
sort(d+1,d+tot+1,cmpt);
calc(1,tot,1);
sort(d+1,d+tot+1,cmpf);
int p=1;
fo(i,2,tot)
if(d[i].f!=d[i-1].f) calc(p,i-1,-1),p=i;
calc(p,tot,-1);
if(fa[x]!=fr) g[x]+=F[fa[x]];
fo(i,1,tot){
p=d[i].x;
if(fa[p]!=fr) g[p]+=fat[p]==fa[p]?F[p]:F[fa[p]];
}
fo(i,1,tot) p=d[i].x,f[p]+=F[p],F[p]=0;
f[x]+=F[x],F[x]=0;
rep(i,x){
int v=to[i];
if(v==fr || vis[v]) continue;
findrt(v,x);
dfs(rt,x);
}
}
int main()
{
scanf("%d %d %lld",&n,&m,&K);
fo(i,2,n){
int u,v,w;
scanf("%d %d %d",&u,&v,&w);
link(u,v,w),link(v,u,w);
}
getf(1),findrt(1,0),dfs(rt,0);
jc[0]=1;
fo(i,1,n) jc[i]=jc[i-1]*i%mo;
ny[n]=pow(jc[n],mo-2);
fd(i,n-1,0) ny[i]=ny[i+1]*(i+1)%mo;
fo(i,1,n){
ans=(ans+C(f[i],m))%mo;
if(i>1) ans=(ans-C(g[i],m)+mo)%mo;
}
printf("%lld",ans*jc[m]%mo);
}