选数 - 容斥 - 分块 - FWT

题目大意:
有n个数字,要求选出k个不同的数字使得异或和是s,对所有选择方案求gcd并求和。数据范围：$n\le10^6$，$a_i,s\le m\le50000$。
题解:
首先，由 $x=\sum_{d\mid x}\varphi(d)$，gcd之和可以转化为 $\sum_{i=1}^{m} f(i)\varphi(i)$，其中 $f(i)$ 表示选k个不同的数字使得它们的gcd是i的倍数（即每个数都是i的倍数）的方案数。
然后发现 $f(i)$ 可以通过容斥转化为允许重复选同一个数字的计数问题。然后设阈值 $S$：若 $i<S$，则直接FWT；否则跑 $O\left(\left(\frac{m}{i}\right)^2\right)$ 的暴力。对两部分的复杂度求和（近似为积分）可以发现取 $S=\sqrt{m}$ 最优。

#include<bits/stdc++.h>
// Inclusive loop: i runs from a up to b. The argument b is pasted verbatim into
// `i<=b`, so a bound such as `c&&p[j]<=n/i` acts as an extra loop condition;
// arithmetic arguments are not parenthesized by the macro.
#define rep(i,a,b) for(int i=a;i<=b;i++)
// Index loop over every position of container v.
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
// Prime modulus; modular inverses in this file are computed as x^(mod-2) (Fermat).
#define mod 998244353
// Short aliases (several are unused in the code shown here).
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
// Debug output helpers: debug(x) writes "x=<value>" to stderr; sp / ln are
// stream fragments usable as `cerr<<a sp<<b ln;`.
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
namespace INPUT_SPACE{
    // Size of the stdin read buffer (16 MiB plus a small margin).
    const int BS=(1<<24)+5;char Buffer[BS],*HD,*TL;
    // Returns the next input byte as a value in [0,255], or EOF once stdin is exhausted.
    // The unsigned char cast keeps a 0xFF byte from being mistaken for EOF.
    inline int gc()
    {
        if(HD==TL) TL=(HD=Buffer)+fread(Buffer,1,BS,stdin);
        return (HD==TL)?EOF:(unsigned char)*HD++;
    }
    // Reads the next non-negative decimal integer, skipping every non-digit byte
    // (so a leading '-' is ignored). Returns 0 if the input ends before any digit
    // is found, instead of spinning forever on EOF.
    inline int inn()
    {
        int ch;
        while((ch=gc())<'0'||ch>'9') if(ch==EOF) return 0;
        int x=ch^'0';
        while((ch=gc())>='0'&&ch<='9') x=(x<<1)+(x<<3)+(ch^'0');
        return x;
    }
}using INPUT_SPACE::inn;
// Capacity of the value-indexed tables. Values are at most 50000 < 2^16, so every
// value and every XOR of two values stays below 65536 < N.
const int N=66000;
// cnt[v]: how many input elements equal v; bsz[i]: parity of popcount(i).
int cnt[N],bsz[N];
// Work arrays: a is the FWT buffer; lst/vis/L/R are scratch tables for pair counting.
int a[N],lst[N],vis[N],L[N],R[N];
// sn: FWT / brute-force threshold; inv6, inv24: modular inverses of 6 and 24.
int sn,inv6,inv24;
// Zeroes the first n ints of a. Always returns 0 so calls can sit in comma expressions.
inline int clr(int *a,int n)
{
    memset(a,0,n*sizeof(int));
    return 0;
}
// Greatest common divisor by iterative Euclid; gcd(0,b) == b.
inline int gcd(int a,int b)
{
    while(a)
    {
        const int rem=b%a;
        b=a;
        a=rem;
    }
    return b;
}
// Binary exponentiation: returns ans * x^k modulo `mod` for k >= 0 (ans defaults to 1).
// With k = mod-2 this yields the modular inverse of x, since mod is prime.
inline int fast_pow(int x,int k,int ans=1)
{
    while(k)
    {
        if(k&1) ans=(lint)ans*x%mod;
        x=(lint)x*x%mod;
        k>>=1;
    }
    return ans;
}
inline int solvek1(int s) { return !printf("%lld\n",(lint)cnt[s]*s%mod); }
// k == 2: sum of gcd(u, u^s) over unordered pairs of distinct input elements whose
// values XOR to s. Iterating over every value i counts each pair once from each
// endpoint, hence the final multiplication by the inverse of 2.
// When s == 0 the partner value i^s is i itself, so an element must not be paired
// with itself: it has cnt[i]-1 partners rather than cnt[i].
// Prints the answer and returns 0 (used as main's exit code).
inline int solvek2(int n,int s)
{
    int ans=0;
    rep(i,1,n)
    {
        const int partners=cnt[i^s]-(s==0);
        ans=(ans+(lint)cnt[i]*partners%mod*gcd(i,i^s))%mod;
    }
    return !printf("%lld\n",(lint)ans*fast_pow(2,mod-2)%mod);
}
int p[N],phi[N],np[N];
inline int prelude(int n)
{
    phi[1]=1,sn=(int)sqrt(n/7+0.5);
    for(int i=2,c=0;i<=n;i++)
    {
        if(!np[i]) p[++c]=i,phi[i]=i-1;
        rep(j,1,c&&p[j]<=n/i)
        {
            int x=p[j]*i;np[x]=1;
            if(i%p[j]) phi[x]=phi[i]*(p[j]-1);
            else { phi[x]=phi[i]*p[j];break; }
        }
    }
    int m=1;while(m<=n) m<<=1;rep(i,1,m-1) bsz[i]=bsz[i>>1]^(i&1);
    return 0;
}
inline int fwt(int *a,int n)
{
    for(int i=2;i<=n;i<<=1) for(int j=0,t=i>>1,x,y;j<n;j+=i) rep(k,0,t-1)
        x=a[j+k],y=a[j+k+t],a[j+k]=(x+y>=mod?x+y-mod:x+y),a[j+k+t]=(x-y<0?x-y+mod:x-y);
    return 0;
}
inline int ufwt(int *a,int n,int s)
{
    lint ans=0;rep(i,0,n-1) if(bsz[i&s]) ans-=a[i];else ans+=a[i];
    return ans%=mod,ans+=mod,ans%=mod,int(ans*fast_pow(n,mod-2)%mod);
}
// Counts ordered k-tuples of input elements (the same element may be reused) whose
// values are all positive multiples of x (at most n) and XOR to s; result mod `mod`.
// Called from solve() with k = 3 or 4. Two strategies, selected by the threshold sn:
//  * x <= sn: FWT of the multiples-of-x histogram, raise every transformed entry to
//    the k-th power, then read coefficient s back with ufwt.
//  * x >  sn: brute force over all pairs of multiples of x to build the distribution
//    of pairwise XORs, then combine it with a second table R.
//    NOTE: this branch only distinguishes k == 3; any other k is treated as 4.
inline int F(int x,int n,int k,int s)
{
    // m: FWT length, the smallest power of two strictly greater than n.
    int m=1;while(m<=n) m<<=1;
    if(x<=sn)
    {
        clr(a,m);rep(i,1,n/x) a[i*x]=cnt[i*x];fwt(a,m);
        // Pointwise k-th power in the transformed domain = k-fold XOR convolution.
        rep(i,0,m-1) a[i]=fast_pow(a[i],k);return ufwt(a,m,s);
    }
    // c: number of distinct pair-XOR values recorded so far in lst[1..c].
    int c=0,ans=0;
    for(int i=x;i<=n;i+=x) for(int j=x;j<=n;j+=x)
    {
        // Record each XOR value the first time it appears so it can be visited and reset later.
        if(!vis[i^j]) vis[lst[++c]=i^j]=1;
        // L[v] (copied into R[v]): number of ordered element pairs whose XOR is v.
        L[i^j]=R[i^j]=(L[i^j]+(lint)cnt[i]*cnt[j])%mod;
    }
    // For k == 3 the right-hand factor is a single element, so R is replaced by the
    // multiples-of-x histogram; otherwise R stays equal to the pair distribution L.
    if(k==3) { rep(i,1,c) R[lst[i]]=0;for(int i=x;i<=n;i+=x) R[i]=cnt[i]; }
    // Combine: sum over recorded pair-XOR values v of L[v] * R[v^s].
    rep(i,1,c) if(R[lst[i]^s]) ans=(ans+(lint)L[lst[i]]*R[lst[i]^s])%mod;
    // Reset the shared scratch tables (including the multiples of x that the k == 3
    // branch wrote into R) so the next call starts from zeros.
    rep(i,1,c) vis[lst[i]]=L[lst[i]]=R[lst[i]]=0;
    rep(i,0,n/x) L[i*x]=R[i*x]=vis[i*x]=0;return ans;
}
inline int solve(int x,int n,int k,int s)
{
    int m=0;rep(i,1,n/x) m+=cnt[i*x];
    if(k==1) return s%x==0?cnt[s]:0;
    else if(k==2) { int ans=0;for(int i=x;i<=n;i+=x) if((s^i)%x==0) ans=(ans+(lint)cnt[i]*cnt[i^s])%mod;return ans; }
    else if(k==3) { int ans=(F(x,n,3,s)-(3ll*m-3+1)*solve(x,n,1,s))%mod;if(ans<0) ans+=mod;return (lint)ans*inv6%mod; }
    int ans=(F(x,n,4,s)-(4+6ll*(m-2))*solve(x,n,2,s))%mod;if(ans<0) ans+=mod;return (lint)ans*inv24%mod;
}
int main()
{
//  freopen("data.in","r",stdin);
    inv6=fast_pow(6,mod-2),inv24=fast_pow(24,mod-2);
    int m=inn(),k=inn(),s=inn(),n=s,x,ans=0;
    rep(i,1,m) cnt[x=inn()]++,n=max(n,x);prelude(n);
    if(k==1) return solvek1(s);if(k==2) return solvek2(n,s);
    rep(i,1,n) ans=(ans+(lint)phi[i]*solve(i,n,k,s))%mod;
    return !printf("%d\n",ans);
}

猜你喜欢

转载自blog.csdn.net/Mys_C_K/article/details/89335002
FWT