Luogu5860 「SWTR-03」Counting Trees

Link

Solution

一棵\(n\)个点的树应满足度数和\(\sum a_i=2n-2\)。而且只要度数和满足条件就一定能根据构造出一棵这样的树。

转化成求有多少种方案满足\(\sum\limits_{i=1}^ma_i=2m-2\)\(m\)任意。移下项就是:\(\sum\limits_{i=1}^m(a_i-2)=-2\)

\(n^2\)的01背包做法就显然了,设\(f_{i,j}\)表示前\(i\)个点权值为\(j\)的方案,转移\(f_{i,j}=f_{i-1,j}+f_{i,j-v_i}\)

然后考虑优化转移速度,有一个转到OGF上的妙招:

如果有n个物品,第\(i\)个物品有价值\(a_i\),问总价值为\(S\)的选择方案。

\(F(x)=\prod\limits_{i=1}^n(1+x^{a_i})\)。那么答案就是\([x^S]F(x)\)

\(F(x)\)\(ln\)\(ln\ F(x)=\sum\limits_{i=1}^nln(1+x^{a_i})\)

泰勒展开\(ln(1+x)=0+\sum\limits_{i=1}^\infty (-1)^{i-1}\frac{x^i}{i}\)

代入\(x=x^{a_i}\)\(ln(1+x^{a_i})=\sum\limits_{j=1}^{\lfloor \frac{n}{a_i}\rfloor}(-1)^{j-1}\frac{x^{a_i\times j}}{j}\)

所以在\(a_i\)的所有倍数位置加上贡献,最后\(exp\)回去即可。

实现的时候好呆啊!一个一个枚举\(a_i\),以为是\(\text O(nlogn)\)的,但是只要出题人不用脚造数据,只要出一堆度数小的,分分钟卡成\(\text O(n^2)\)!所以正确的枚举姿势应该是:开桶计数。枚举度数,度数相同的一起加贡献。这样才是严格\(\text O(nlogn)\)的。

Code

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
inline int read(){
    register int x=0,f=1;register char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=0;ch=getchar();}
    while(isdigit(ch)){x=x*10+(ch^'0');ch=getchar();}
    return f?x:-x;
}

const int N=1.1e6,mod=998244353;
inline int power(int base,int n){int ans=1;for(;n;n>>=1,base=1ll*base*base%mod)if(n&1)ans=1ll*ans*base%mod;return ans;}
int n,a[N],f[N],inv[N],cnt[N];

namespace Poly{
    int trans[N];
    inline void NTT(int *a,int n,int type){
        static unsigned long long f[N];
        for(int i=0;i<n;++i)f[i]=a[i];
        for(int i=0;i<n;++i)if(i<trans[i])swap(f[i],f[trans[i]]);
        for(int len=2;len<=n;len<<=1){
            int e=power(3,(mod-1)/len),d=len>>1;if(type)e=power(e,mod-2);
            for(int p=0;p<n;p+=len)
                for(int i=p,nw=1;i<p+d;++i,nw=1ll*nw*e%mod){
                    int t=1ull*f[i+d]*nw%mod;
                    f[i+d]=f[i]+mod-t;f[i]+=t;
                }
        }
        for(int i=0;i<n;++i)a[i]=f[i]%mod;
    }
    inline void Times(int *f,int *a,int m1,int m2,int lim){
        static int g[N];
        int n=1;for(;n<m1+m2;n<<=1);
        for(int i=0;i<n;++i)trans[i]=(trans[i>>1]>>1)|(i&1?(n>>1):0);
        for(int i=0;i<m2;++i)g[i]=a[i];
        for(int i=m2;i<n;++i)g[i]=0;
        NTT(f,n,0),NTT(g,n,0);
        for(int i=0;i<n;++i)f[i]=1ll*f[i]*g[i]%mod;
        NTT(f,n,1);int inv=power(n,mod-2);
        for(int i=0;i<lim;++i)f[i]=1ll*f[i]*inv%mod;
        for(int i=lim;i<n;++i)f[i]=0;
    }
    inline void Inv(int *f,int m){
        static int g[N],p[N];
        int n=1;for(;n<m;n<<=1);
        g[0]=power(f[0],mod-2);
        for(int len=2;len<=n;len<<=1){
            for(int i=0;i<(len>>1);++i){p[i]=g[i]<<1;if(p[i]>=mod)p[i]-=mod;}
            Times(g,g,len>>1,len>>1,len),Times(g,f,len,len,len);
            for(int i=0;i<len;++i){g[i]=p[i]-g[i];if(g[i]<0)g[i]+=mod;}
        }
        for(int i=0;i<m;++i)f[i]=g[i];
        for(int i=0;i<n;++i)g[i]=p[i]=0;
    }
    inline void Drv(int *f,int m){
        for(int i=1;i<=m;++i)f[i-1]=1ll*f[i]*i%mod;
        f[m]=0;
    }
    inline void Itg(int *f,int m){
        for(int i=m;i;--i)f[i]=1ll*f[i-1]*power(i,mod-2)%mod;
        f[0]=0;
    }
    inline void Ln(int *f,int m){
        static int g[N];
        for(int i=0;i<m;++i)g[i]=f[i];
        Drv(g,m);Inv(f,m);
        Times(f,g,m,m,m);Itg(f,m);
        for(int i=0;i<m;++i)g[i]=0;
    }
    inline void Exp(int *f,int m){
        static int g[N],p[N];
        int n=1;for(;n<m;n<<=1);
        g[0]=1;
        for(int len=2;len<=n;len<<=1){
            for(int i=0;i<(len>>1);++i)p[i]=g[i];
            for(int i=(len>>1);i<len;++i)p[i]=0;
            Ln(p,len);
            for(int i=0;i<len;++i){p[i]=f[i]-p[i];if(p[i]<0)p[i]+=mod;}
            p[0]+=1;if(p[0]>=mod)p[0]-=mod;
            Times(g,p,len>>1,len,len);
        }
        for(int i=0;i<m;++i)f[i]=g[i];
        for(int i=0;i<n;++i)g[i]=p[i]=0;
    }
}using namespace Poly;

int main(){
    n=read();
    for(int i=1;i<=n;++i)++cnt[read()];
    inv[1]=1;
    for(int i=2;i<=cnt[1];++i)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
    for(int i=3;i<=n;++i){
        int a=i-2;
        for(int j=1;a*j<=cnt[1];++j){
            int res=j&1?1:mod-1;res=1ll*res*inv[j]%mod*cnt[i]%mod;
            f[a*j]+=res;if(f[a*j]>=mod)f[a*j]-=mod;
        }
    }
    Exp(f,cnt[1]);
    int ans=0,res=1ll*cnt[1]*(cnt[1]-1)%mod*inv[2]%mod*power(2,cnt[2])%mod;
    for(int i=2;i<=cnt[1];++i){
        ans+=1ll*res*f[i-2]%mod;if(ans>=mod)ans-=mod;
        res=1ll*res*(cnt[1]-i)%mod*inv[i+1]%mod;
    }
    printf("%d\n",ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/fruitea/p/12117861.html