推荐阅读资料:算法导论第30章
本文不做证明,详细证明请看如上资料。
FFT在算法竞赛中主要用来加速多项式的乘法
普通是多项式乘法时间复杂度的是O(n2),而用FFT求多项式的乘法可以使时间复杂度达到O(nlogn)
FFT求多项式的乘法步骤主要如下图
其中求值是将系数表达转换成点值表达,带入的自变量是wn=1的复数解,称为DFT
插值是将点值表达转换成系数表达,称为DFT-1
DFT 和 DFT-1都可以用FFT加速实现
这是递归版的FFT
还有一种非递归的版本
我们发现叶子节点的下表的二进制为:000 100 010 110 001 101 110 111
与它们的本身所对应的位置的二进制:000 001 010 011 100 101 011 111
相反
所以我们可以确定叶子节点的值,从下往上进行操作
求二进制反转的代码(其中L是二进制位):
for (int i = 0; i < n; i++) { R[i] = (R[i>>1]>>1) | ((i&1) << L-1); }
假设现在R[i]的二进制是abcd,没有操作之前的R[i>>1]是0abc,操作之后的是cba0,再右移是0cba,再判断原来的d是不是1在最高位放1或0,就刚好是反转的结果
模板:
递归版(以求大数乘法为例):
#include<bits/stdc++.h> using namespace std; #define fi first #define se second #define pi acos(-1.0) #define LL long long #define mp make_pair #define pb push_back #define ls rt<<1, l, m #define rs rt<<1|1, m+1, r #define ULL unsigned LL #define pll pair<LL, LL> #define pii pair<int, int> #define piii pair<int,pii> #define mem(a, b) memset(a, b, sizeof(a)) #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0); #define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stout); //head typedef complex<double> cd; const int N = 2e5 + 5; char a[N], b[N]; cd A[N], B[N]; int tmp[N]; void fft(cd *x, int n, int type) { if(n == 1) return ; cd l[n>>1], r[n>>1]; for (int i = 0; i < n; i += 2) { l[i>>1] = x[i]; r[i>>1] = x[i+1]; } fft(l, n>>1, type); fft(r, n>>1, type); cd wn(cos(2*pi/n), sin(type*2*pi/n)), w(1, 0), t; for(int i = 0; i < n>>1; i++, w *= wn) { t = w*r[i]; x[i] = l[i] + t; x[i+(n>>1)] = l[i] - t; } } int main() { while(~scanf("%s%s", a, b)) { int n = strlen(a), m = strlen(b); mem(A, 0); mem(B, 0); mem(tmp, 0); for (int i = n - 1; i >= 0; i--) A[n-1-i] = a[i] - '0'; for (int i = m - 1; i >= 0; i--) B[m-1-i] = b[i] - '0'; m = m + n; for(n = 1; n <= m; n <<= 1); fft(A, n, 1); fft(B, n, 1); for (int i = 0; i < n; i++) A[i] *= B[i]; fft(A, n, -1); for (int i = 0; i < m; i++) { int t = (int)(A[i].real()/n + 0.5); t += tmp[i]; tmp[i] = t%10; tmp[i+1] += t/10; } int i; for (i = m; i >= 1; i--) if(tmp[i]) break; for (i; i >= 0; i--) printf("%d", tmp[i]); printf("\n"); } return 0; }
非递归版(以求大数乘法为例):
#include<bits/stdc++.h> using namespace std; #define fi first #define se second #define pi acos(-1.0) #define LL long long #define mp make_pair #define pb push_back #define ls rt<<1, l, m #define rs rt<<1|1, m+1, r #define ULL unsigned LL #define pll pair<LL, LL> #define pii pair<int, int> #define piii pair<int,pii> #define mem(a, b) memset(a, b, sizeof(a)) #define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0); #define fopen freopen("in.txt", "r", stdin);freopen("out.txt", "w", stout); //head typedef complex<double> cd; const int N = 2e5 + 5; char a[N], b[N]; cd A[N], B[N]; int tmp[N], R[N]; void fft(cd *x, int n, int type) { for (int i = 0; i < n; i++) if(i < R[i]) swap(x[i], x[R[i]]); for (int i = 1; i < n; i <<= 1) { cd wn(cos(pi/i), type*sin(pi/i)); for (int j = 0; j < n; j += i<<1) { cd w(1, 0); for (int k = 0; k < i; k++, w*=wn) { cd X = x[j+k], Y = w*x[j+k+i]; x[j+k] = X+Y; x[j+k+i] = X-Y; } } } } int main() { while(~scanf("%s%s", a, b)) { int n = strlen(a), m = strlen(b), L = 0; mem(A, 0); mem(B, 0); mem(tmp, 0); mem(R, 0); for (int i = n - 1; i >= 0; i--) A[n-1-i] = a[i] - '0'; for (int i = m - 1; i >= 0; i--) B[m-1-i] = b[i] - '0'; m = m + n; for(n = 1; n <= m; n <<= 1) L++; for (int i = 0; i < n; i++) { R[i] = (R[i>>1]>>1) | ((i&1) << L-1); } fft(A, n, 1); fft(B, n, 1); for (int i = 0; i < n; i++) A[i] *= B[i]; fft(A, n, -1); for (int i = 0; i < m; i++) { int t = (int)(A[i].real()/n + 0.5); t += tmp[i]; tmp[i] = t%10; tmp[i+1] += t/10; } int i; for (i = m; i >= 1; i--) if(tmp[i]) break; for (i; i >= 0; i--) printf("%d", tmp[i]); printf("\n"); } return 0; }
PS:手写complex类+非递归版最快