模板链接:https://www.luogu.org/problemnew/show/P3803
【模板】多项式乘法(FFT)
题目背景
这是一道FFT模板题
注意:虽然本题开到3s,但是建议程序在1s内可以跑完,本题需要一定程度的常数优化。
题目描述
给定一个
n
次多项式
F(x)
,和一个
m
次多项式
G(x)
。
请求出
F(x)
和
G(x)
的卷积。
输入输出格式
输入格式:
第一行
2
个正整数
n,m
。
接下来一行
n+1
个数字,从低到高表示
F(x)
的系数。
接下来一行
m+1
个数字,从低到高表示
G(x)
的系数。
输出格式:
一行
n+m+1
个数字,从低到高表示
F(x)∗G(x)
的系数。
输入输出样例
输入样例#1:
1 2
1 2
1 2 1
输出样例#1:
1 4 5 2
说明
保证输入中的系数大于等于
0
且小于等于
9
。
对于
100%
的数据:
n,m≤106
, 共计20个数据点,2s。
数据有一定梯度。
空间限制:256MB
题解
FFT
学习,从入门到入土。。。
边写博客边加深一下理解吧。
1.复数
复数是实数和虚数的结合体,众所周知,实数可以表示在横向的数轴上,而虚数可以纵轴上,这样实数与虚数的两个坐标轴就构成了一个平面,一个复数的代数表示就是
a+bi
,
a
为实部,
bi
为虚部,放在坐标轴中就是平面上的一个点了。
上图即为复数
3+2i
在复平面上的表示,显然,这样的表示是唯一对应的。
复数的运算法则:
加法:
(a+bi)+(c+di)=(a+c)+(b+d)i
乘法:
(a+bi)×(c+di)=(ac−bd)+(ad+bc)i
对于加法,复数的运算与向量类似,但向量的运算法则不能很好的用在复数的乘法上,所以我们考虑用极坐标将一个复数表示为
(r,θ)
,在这种表示下我们再来看复数乘法:
(r1cosθ1+r1sinθ1i)×(r2cosθ2+r2sinθ2i)=(r1r2cosθ1cosθ2−r1r2sinθ1sinθ2)+(r1r2cosθ1sinθ2+r1r2sinθ1cosθ2)i=(r1r2cos(θ1+θ2))+(r1r2sin(θ1+θ2))
新得到的复数的坐标就为
(r1r2,θ1+θ2)
,可以说是非常简洁而优美了。
2.单位根
单位根就是方程
xn=1
在复数域的解。
单位根有什么特点呢?
假设一个单位根为
(1,θ)
,它的
n
次方就是
(1,n×θ)
,
1
在坐标轴上的坐标为
(1,0)
,那么就有
n×θ=2kπ(k∈Z)
,解得
θ=2kπn(k∈Z)
,所以
1
的
n
次平方根共有
n
个,表示在复平面上就是
n
个将单位圆等分为
n
份的点,且其中一个为
1
。
我们将满足
ωn=1
的
ω
称为
1
的
n
次单位根,记做
ωn
,并且用
0∼n−1
标号,记做
ω0n∼ωn−1n
,下图为
n=8
时的单位根图像
注意到
ωkn=ωk+nn,ωkn=ωkxnx
。
3.多项式的表示法
一个
n−1
次多项式的基本形式为:
f(x)=∑i=0naixi
我们称这种表示法为系数表示法,有了一个多项式的系数表达式,我们就能快速计算出代入的自变量
x
对应的函数值。
但是借助系数表示法无法快速求出两个多项式相乘后的多项式,考虑多项式相乘
F(x)=f(x)×g(x)
的本质事实上是对于自变量
x
的每个可能取值
xi
,都有
F(xi)=f(xi)×g(xi)
,这样看来,如果我们知道了两个多项式对于同一个横坐标
xi
的点值,我们便能快速计算出它们相乘之后得到的多项式的点值。
由
Lagrange
插值法可知,
n
个点可以唯一确定一个
n−1
阶多项式,那么我们就可以用这
n
个点来表示一个多项式,我们称这种表示法为点值表示法。
借助系数表示法,我们可以快速求出函数值;而通过点值表示法,我们可以快速完成多项式相乘的运算。
我们要计算两个多项式的乘积,就有了一个简单的思路:
(1)将两个多项式由系数表示转换为点值表示。
(2)将两个多项式的点值相乘,得到乘积多项式的点值。
(3)将新多项式由点值表示转换为系数表示。
4.快速傅里叶变换
在上述计算多项式相乘的思路中,复杂度瓶颈在于两种表示法之间的“转换”,朴素的实现是将自变量一一代入多项式求出对应的点值,复杂度为
O(n2)
,并不优秀。
快速傅里叶变换(
FFT
)便是突破瓶颈的关键,
FFT
实现了两种快速表示法之间
O(nlog2n)
的快速转换。
设
n
项多项式
f(x)
,其中
n=2k,α=2πn,ωn=cosα+sinαi
,考虑用这
n
个单位根上的点值来表示多项式:
f(ωkn)=∑i=0n−1aiωkin
我们可以分奇偶项将上面的式子拆成两个:
∑i=0n−1aiωkin=∑i=0n2−1a2iω2kin+∑i=0n2−1a2i+1ω(2i+1)kn=∑i=0n2−1a2iω2kin+ωkn(∑i=0n2−1a2i+1ω2kin)=∑i=0n2−1a2iωkin2+ωkn(∑i=0n2−1a2i+1ωkin2)
这样我们就得到了两个形式与
f(x)
相似的多项式,显然可以递归处理下去。同时,为了能够顺利分治到最底层,我们需要将多项式的项数补全到
2k
。
让我们以
8
项的
7
次多项式为例,再仔细研究一下这个过程,设:
F(x)=a0+a1x+a2x2+⋯+a7x7f(x)=a0+a2x2+⋯+a6x6g(x)=a1+a3x2+⋯+a7x6
不难得到:
F(x)=f(x2)+x×g(x2)
将
ωkn(k<n2)
代入这个式子:
F(ωkn)=f(ω2kn)+ωkn×g(ω2kn)
再代入
ωk+n2n
:
F(ωk+n2n)=f(ω2k+nn)+ωk+n2n×g(ω2k+nn)=f(ω2kn)−ωkn×g(ω2kn)
震惊!
F(ωk+n2n)
与
F(ωkn)
只有一字之差!我们在枚举
ωkn(k<n2)
的时候就可以
O(1)
得到
ωk+n2n
的值。
在此基础上,我们又将问题缩小了一半,至此我便可以通过与线段树类似分治思想在
O(nlog2n)
的时间复杂度内将一个多项式从系数表达式转换为点值表达式。
不仅如此,系数的位置变化也有不可告人的秘密:
(a0,a1,a2,a3,a4,a5,a6,a7)↓(a0,a2,a4,a6),(a1,a3,a5,a7)↓(a0,a4),(a2,a6),(a1,a5),(a3,a7)↓(a0),(a4),(a2),(a6),(a1),(a5),(a3),(a7)
似乎没有什么规律?让我们将收尾的二进制拿出来比较:
000,001,010,011,100,101,110,111000,100,010,110,001,101,011,111
规律的气息:我们发现最后的序列的二进制等于初始序列二进制的翻转!
这个规律的本质便是我们每次都将系数按照奇偶分开,我们先将最后一位为
0
的放在前面,再将倒数第二位为
0
的数放在前面。。。以此类推,最后数列的比较方式实际上是从低位到高位比较,与初始序列相反,最后就得到了这样一个优美的性质。
这样,我们可以预先知道哪个位置上是哪个数,从底层开始向上合并,通过访问连续的内存使算法大大加快!
考虑如何处理出翻转后的数,我们可以借用前面的数处理出的信息,对于一个在二进制数有
L
位的数
x
,它的前
L−1
位翻转后与
x>>1
相同,所以我们只需要保留
x
的最后一位,将其放到第一位,然后把
x>>1
的翻转结果拿过来就好了:
for(mx=1;mx<=n+m;mx<<=1,++len);
for(int i=0;i<mx;++i)rev[i]=(rev[i>>1]>>1|((i&1)<<(len-1)));
5.快速傅里叶反变换
考虑将代入自变量计算点值的过程表示成矩阵乘法:
先贴一发矩阵乘法运算法则:
设矩阵
A
为
m×p
的矩阵,
B
为
p×n
的矩阵,
A×B=C
,那么有:
Ci,j=∑k=1pAi,k×Bk,j
于是,我们可以将上述过程表示为如下的矩阵乘法,设
n
阶矩阵
W
满足
Wi,j=ωijn
,以及长度为
n
的向量
A
表示
f(x)
的各个系数
ai
,就有
W×A={f(ωin)|i∈[0,n−1]}
:
⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢1111⋮11w1nw2nw3n⋮wn−1n1w2nw4nw6n⋮w2(n−1)n1w3nw6nw9n⋮w3(n−1)n⋯⋯⋯⋯⋱⋯1wn−1nw2(n−1)nw3(n−1)n⋮w(n−1)2n⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢a0a1a2a3⋮an−1⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥=⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢f(ω0n)f(ω1n)f(ω2n)f(ω3n)⋮f(ωn−1n)⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥
那我们实际上要做的就是构造一个
W
的“倒数”即逆矩阵来抵消
W
,考虑
n
阶矩阵
W′
,满足
W′i,j=ω−ijn
,就有
WW′=n×In
,其中
In
为
n
阶单位矩阵,那么
W
的逆矩阵实际上就是
1nW′
,所以我们只需要在正向变换上稍作改动,最后算出来的系数再除以项数就能完成从点值表示到系数表示的转换。
代码
其实
STL
有自带的复数类
(complex)
,运行效率你们懂得。。。
#include<bits/stdc++.h>
#define db double
using namespace std;
const int M=4e6+5;
const db pi=acos(-1.0);
struct cpx{db x,y;}f[M],g[M];
cpx operator +(cpx a,cpx b){return (cpx){a.x+b.x,a.y+b.y};}
cpx operator -(cpx a,cpx b){return (cpx){a.x-b.x,a.y-b.y};}
cpx operator *(cpx a,cpx b){return (cpx){a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};}
int n,m,mx,len,rev[M];
void in(){scanf("%d%d",&n,&m);for(int i=0;i<=n;++i)scanf("%lf",&f[i].x);for(int i=0;i<=m;++i)scanf("%lf",&g[i].x);}
void fft(cpx *f,int typ)
{
cpx wn,w,x,y;int i,mid,j,k;
for(i=0;i<mx;++i)if(i<rev[i])swap(f[i],f[rev[i]]);
for(mid=1;mid<mx;mid<<=1)for(j=0,wn=(cpx){cos(pi/mid),typ*sin(pi/mid)};j<mx;j+=mid<<1)
for(k=0,w=(cpx){1,0};k<mid;++k,w=w*wn)x=f[j+k],y=w*f[j+mid+k],f[j+k]=x+y,f[j+mid+k]=x-y;
}
void ac()
{
for(mx=1;mx<=n+m;mx<<=1,++len);
for(int i=0;i<mx;++i)rev[i]=(rev[i>>1]>>1|((i&1)<<(len-1)));
fft(f,1);fft(g,1);
for(int i=0;i<=mx;++i)f[i]=f[i]*g[i];
fft(f,-1);
for(int i=0;i<=n+m;++i)printf("%d ",int(f[i].x/mx+0.5));
}
int main(){in();ac();}