刚学的FFT。。证明好玄乎啊
根据mjs大佬的原话,FFT这种东西不需要理解,背了模板就好
先贴题
题意:从n个数中选出1,2或3个数求和,询问组成每个和的方案数。
思路:生成函数+FFT+容斥原理
假设可选的数表为 ,那么构造多项式 。怎么解释呢?将 表示为若干个单项式 之和,其中 为原本数表中的数, 为 在表中出现的次数。
那答案不就是 各项的系数与次数吗?
并不是。。将 展开,会发现某项会自乘,从而对答案产生贡献。
这时候我们就要请出 客厅原理 容斥原理了
再构造多项式
同
的原理,但
为原本数表中数的两倍,
中为三倍。
这样
与
即可表示将某个数限制选两次(三次)的方案数了
于是由容斥原理我们得到
为最终的答案
于是我们就可以用FFT在 的时间内求出多项式 啦
上代码
#include<iostream>
#include<cstdio>
#include<cmath>
using namespace std;
struct CComplexNumber{
double r,i;
CComplexNumber():r(0.),i(0.){}
CComplexNumber(const double& _r,const double& _i):r(_r),i(_i){}
friend CComplexNumber operator+(const CComplexNumber& c1,const CComplexNumber& c2){
return CComplexNumber(c1.r+c2.r,c1.i+c2.i);
}
friend CComplexNumber operator-(const CComplexNumber& c1,const CComplexNumber& c2){
return CComplexNumber(c1.r-c2.r,c1.i-c2.i);
}
friend CComplexNumber operator*(const CComplexNumber& c1,const CComplexNumber& c2){
return CComplexNumber(c1.r*c2.r-c1.i*c2.i,c1.r*c2.i+c1.i*c2.r);
}
friend CComplexNumber operator*(const double& x,const CComplexNumber& c){
return CComplexNumber(x*c.r,x*c.i);
}
friend CComplexNumber operator/(const CComplexNumber& c,const double& x){
return CComplexNumber(c.r/x,c.i/x);
}
CComplexNumber operator*=(const CComplexNumber& c){
CComplexNumber c0(r*c.r-i*c.i,r*c.i+i*c.r);
r=c0.r;
i=c0.i;
return *this;
}
CComplexNumber operator/=(const double& x){
r/=x;
i/=x;
return *this;
}
};
const double Pi=acos(-1);
CComplexNumber x[360001],y[360001],z[360001];
int len=1,N,n,rev[360001],a,bit;
void fft(CComplexNumber *f,int check){
for(int i=0;i<len;++i)if(i<rev[i])swap(f[i],f[rev[i]]);
for(int i=1;i<len;i<<=1){
CComplexNumber wn(cos(Pi/i),check*sin(Pi/i));
for(int j=0;j<len;j+=(i<<1)){
CComplexNumber w(1.,0.);
for(int k=0;k<i;++k){
CComplexNumber x=f[j+k],y=w*f[i+j+k];
f[j+k]=x+y;
f[i+j+k]=x-y;
w*=wn;
}
}
}
if(check==-1)for(int i=0;i<len;++i)f[i]/=(double)len;
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;++i){
scanf("%d",&a);
++x[a].r;
++y[2*a].r;
++z[3*a].r;
N=max(N,a);
}
while(len<=6*N+1){
len*=2;
++bit;
}
for(int i=0;i<len;++i)rev[i]=((rev[i>>1]>>1)|((1&i)<<bit-1));
fft(x,1);
fft(y,1);
fft(z,1);
for(int i=0;i<len;++i)
x[i]=(x[i]*x[i]*x[i]-3*x[i]*y[i]+2*z[i])/6.+(x[i]*x[i]-y[i])/2.+x[i];
fft(x,-1);
for(int i=0;i<len;++i)if(round(x[i].r))printf("%d %d\n",i,(int)round(x[i].r));
}