3.20 模拟赛

T1 A

题目大意:

T2 B

题目大意:

给定$n$个点的树,每个点初始时点权均为0

每次等概率选择一个没有被选择过的点,将它所在的联通块的所有点的点权$+1$,再将与这个点连接的边全部删除

最终会只剩下$n$个有点权的点,求这些点的点权和的期望值

思路:

考虑点$x$对点$y$的影响

当且仅当$x$为$x \sim y$这条路径上第一个被选的点时才会造成影响,而这对答案期望的贡献为$\frac{1}{dis(x,y)}$

所以答案为$\sum\limits_{i=1}^n \sum\limits_{j=1}^n \frac{1}{dis(x,y)}$

由于是统计路径信息,很容易想到点分治再用多项式合并

但是如果子树合并必须要按照$maxdep$排序,否则如果$maxdep$较大的子树过早出现会使复杂度出现问题

因此采用容斥的写法即可

(学到了欧神关于$NTT$预处理部分的写法

 1 #include<iostream>
 2 #include<cstdio>
 3 #include<cstring>
 4 #include<cstdlib>
 5 #include<cmath>
 6 #include<algorithm>
 7 #include<queue>
 8 #include<vector>
 9 #include<map>
10 #include<set>
11 #define ll long long
12 #define db double
13 #define inf 2139062143
14 #define MAXN 100100
15 #define MOD 998244353
16 #define rep(i,s,t) for(register int i=(s),i##__end=(t);i<=i##__end;++i)
17 #define dwn(i,s,t) for(register int i=(s),i##__end=(t);i>=i##__end;--i)
18 #define ren for(register int i=fst[x];i;i=nxt[i])
19 #define pb(i,x) vec[i].push_back(x)
20 #define pls(a,b) (a+b)%MOD
21 #define mns(a,b) (a-b+MOD)%MOD
22 #define mul(a,b) (1LL*(a)*(b))%MOD
23 using namespace std;
24 inline int read()
25 {
26     int x=0,f=1;char ch=getchar();
27     while(!isdigit(ch)) {if(ch=='-') f=-1;ch=getchar();}
28     while(isdigit(ch)) {x=x*10+ch-'0';ch=getchar();}
29     return x*f;
30 }
31 int n,fst[MAXN],nxt[MAXN<<1],to[MAXN<<1],cnt;
32 int mx[MAXN],sz[MAXN],Mx,Sum,rt,vis[MAXN],num[MAXN],mxd;
33 int pw[30],ipw[30],l2[MAXN<<2],rev[MAXN<<2],f[MAXN<<2],res[MAXN];
34 void add(int u,int v) {nxt[++cnt]=fst[u],fst[u]=cnt,to[cnt]=v;}
35 int q_pow(int bas,int t,int res=1)
36 {
37     for(;t;t>>=1,bas=mul(bas,bas))
38         if(t&1) res=mul(res,bas);return res;
39 }
40 void ntt(int *a,int n,int f)
41 {
42     rep(i,0,n-1) if(i<rev[i]) swap(a[i],a[rev[i]]);
43     for(int i=1;i<n;i<<=1)
44     {
45         int wn= f==1?pw[l2[i]+1]:ipw[l2[i]+1];
46         for(int j=0;j<n;j+=(i<<1))
47         {
48             int w=1,x,y;
49             for(int k=0;k<i;k++,w=mul(w,wn))
50                 x=a[j+k],y=mul(a[j+k+i],w),a[j+k]=pls(x,y),a[j+k+i]=mns(x,y);
51         }
52     }
53     if(f==1) return ;int nv=q_pow(n,MOD-2);
54     rep(i,0,n-1) a[i]=mul(a[i],nv);
55 }
56 void getrt(int x,int pa)
57 {
58     mx[x]=0,sz[x]=1;ren if(to[i]^pa&&!vis[to[i]])
59         {getrt(to[i],x);sz[x]+=sz[to[i]],mx[x]=max(mx[x],sz[to[i]]);}
60     mx[x]=max(mx[x],Sum-sz[x]);if(mx[x]<Mx) Mx=mx[x],rt=x;
61 }
62 void Get(int x,int pa,int dep)
63 {
64     num[dep]++,mxd=max(mxd,dep);ren if(!vis[to[i]]&&to[i]^pa) Get(to[i],x,dep+1);
65 }
66 void calc(int x,int v)
67 {
68     int t=l2[mxd<<1]+1,lmt=1<<t;
69     rep(i,0,lmt-1) f[i]=0,rev[i]=(rev[i>>1]>>1)|((i&1)<<t-1);
70     rep(i,0,mxd) f[i]=num[i],num[i]=0;
71     ntt(f,lmt,1);rep(i,0,lmt-1) f[i]=mul(f[i],f[i]);
72     ntt(f,lmt,-1);rep(i,0,mxd<<1) res[i+1]= v==1?pls(res[i+1],f[i]):mns(res[i+1],f[i]);
73 }
74 void div(int x)
75 {
76     vis[x]=1,mxd=0;Get(x,0,0);calc(x,1);ren if(!vis[to[i]])
77         {mxd=0;Get(to[i],x,1);calc(to[i],-1);Sum=sz[to[i]],Mx=inf;getrt(to[i],x);div(rt);}
78 }
79 int main()
80 {
81     freopen("B.in","r",stdin);
82     freopen("B.out","w",stdout);
83     n=read();int a,b;rep(i,2,n) a=read(),b=read(),add(a,b),add(b,a);
84     rep(i,2,n<<1)
85     {
86         l2[i]=l2[i>>1]+1;
87         if(!pw[l2[i]]) pw[l2[i]]=q_pow(3,(MOD-1)/i),ipw[l2[i]]=q_pow(pw[l2[i]],MOD-2);
88     }
89     Sum=n,Mx=inf;getrt(1,0);div(rt);int ans=0;
90     rep(i,1,n) ans=pls(ans,mul(res[i],q_pow(i,MOD-2)));
91     printf("%d\n",ans);
92 }
View Code

T3 C

猜你喜欢

转载自www.cnblogs.com/yyc-jack-0920/p/10569337.html