「算法笔记」Fast Fourier Transformation

推荐博客:algocode 算法博客

个人见解: F a s t   F o u r i e r   T r a n s f o r m a t i o n ( F F T ) 是一种能在 O ( n log n ) 时间内求多项式的乘积的算法,代码复杂度较小,缺点是常数巨大。

算法思路:我们发现点值表示法的多项式乘法是 O ( n ) 的,于是我们可以绕过


而是使用
  ( D F T )


  ( I D F T )

来求解多项式乘法。

现在的关键问题就是如何快速进行 D F T I D F T 操作。

Created with Raphaël 2.1.2 多项式乘法 使用暴力? 系数表示法的多项式乘法 O(n^2) Time Limit Exceeded 乘法完成 DFT(用 FFT 优化到 O(n log n)) 点值表示法的多项式乘法 O(n) IDFT(用 FFT 优化到 O(n log n)) yes no

算法流程:见 algocode 算法博客。

// Template
#include <cmath>
#include <cstdio>
#include <complex>
using namespace std;
typedef complex<double> cd;
const int maxn = 262145;
const double pi = acos(-1.);
cd a[maxn], b[maxn];
int n, m, s, bit, rev[maxn];
void prework() {
    for (bit = 1, s = 2; 1 << bit <= n + m; bit++) {
        s <<= 1;
    }
    for (int i = 0; i < 1 << bit; i++) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    }
}
void dft(cd *a, int n, int op) {
    for (int i = 0; i < n; i++) {
        if (i < rev[i]) {
            swap(a[i], a[rev[i]]);
        }
    }
    for (int step = 1; step < n; step <<= 1) {
        cd wn0 = exp(cd(0, op * pi / step));
        for (int j = 0; j < n; j += step << 1) {
            cd wnk(1, 0);
            for (int k = j; k < j + step; k++) {
                cd x = a[k];
                cd y = wnk * a[k + step];
                a[k] = x + y;
                a[k + step] = x - y;
                wnk *= wn0;
            }
        }
    }
    if (op == -1) {
        for (int i = 0; i < n; i++) {
            a[i] /= n;
        }
    }
}
int main() {
    scanf("%d %d", &n, &m);
    prework();
    for (int x, i = 0; i <= n; i++) {
        scanf("%d", &x);
        a[i] = cd(x, 0);
    }
    for (int x, i = 0; i <= m; i++) {
        scanf("%d", &x);
        b[i] = cd(x, 0);
    }
    dft(a, s, 1), dft(b, s, 1);
    for (int i = 0; i < s; i++) {
        a[i] *= b[i];
    }
    dft(a, s, -1);  for (int i = 0; i <= n + m; i++) {
        printf("%d%c", int(a[i].real() + 0.5), " \n"[i == n + m]);
    }
    return 0;
}

例题一:【BZOJ 2194】快速傅立叶之二

这题几乎是一道模版题。只需将 b 数组反向计算两个多项式的乘法即可。

#include <cmath>
#include <cstdio>
#include <complex>
#include <algorithm>
using namespace std;
const int maxn = 1 << 18 | 5;
const double pi = acos(-1.);
typedef complex<double> cd;
int n, m, l, r[maxn];
cd a[maxn], b[maxn];
void fft(cd *a, int op) {
    for (int i = 0; i < m; i++) {
        if (i < r[i]) {
            swap(a[i], a[r[i]]);
        }
    }
    for (int k = 1; k < m; k <<= 1) {
        cd wn0(cos(pi / k), op * sin(pi / k));
        for (int i = 0; i < m; i += k << 1) {
            cd wnk = 1;
            for (int j = i; j < i + k; j++, wnk *= wn0) {
                cd x = a[j], y = wnk * a[j + k];
                a[j] = x + y, a[j + k] = x - y;
            }
        }
    }
    if (op == -1) {
        for (int i = 0; i < m; i++) {
            a[i] /= m;
        }
    }
}
int main() {
    scanf("%d", &n);
    for (int x, y, i = 0; i < n; i++) {
        scanf("%d %d", &x, &y);
        a[i] = x, b[i] = y;
    }
    for (m = 1; m < n << 1; m <<= 1) {
        l++;
    }
    for (int i = 0; i < m; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    }
    reverse(a, a + n);
    fft(a, 1);
    fft(b, 1);
    for (int i = 0; i < m; i++) {
        a[i] *= b[i];
    }
    fft(a, -1);
    for (int i = n - 1; ~i; i--) {
        printf("%d\n", int(a[i].real() + 0.5));
    }
    return 0;
}

例题二:【BZ3527】力

推得 E i = j = 1 i 1 q j ( j i ) 2 j = i + 1 n q j ( j i ) 2

f [ i ] = q [ i ] , g [ i ] = 1 i 2 ,则 E i = j = 1 i 1 f [ i ] g [ i j ] j = i + 1 n f [ i ] g [ j i ] 。用 F F T 计算即可。

#include <cmath>
#include <cstdio>
#include <complex>
#include <iostream>
#include <algorithm>
using namespace std;
typedef double db;
typedef complex<db> cd;
const int maxn = 1 << 18 | 5;
const db pi = acos(-1.);
int n, m, l, r[maxn];
db f[maxn], g[maxn], a[maxn], b[maxn];
void dft(cd *a, int op) {
    for (int i = 0; i < m; i++) {
        if (i < r[i]) {
            swap(a[i], a[r[i]]);
        }
    }
    for (int k = 1; k < m; k <<= 1) {
        cd wn0(cos(pi / k), op * sin(pi / k));
        for (int i = 0; i < m; i += k << 1) {
            cd wnk(1, 0);
            for (int j = i; j < i + k; j++, wnk *= wn0) {
                cd x = a[j], y = wnk * a[j + k];
                a[j] = x + y, a[j + k] = x - y;
            }
        }
    }
}
void fft(db *a, db *b, db *c) {
    static cd p[maxn], q[maxn];
    for (int i = 0; i < m; i++) {
        p[i] = a[i], q[i] = b[i];
    }
    dft(p, 1), dft(q, 1);
    for (int i = 0; i < m; i++) {
        p[i] *= q[i];
    }
    dft(p, -1);
    for (int i = 1; i <= n; i++) {
        c[i] = p[i].real() / m;
    }
}
int main() {
    scanf("%d", &n);
    for (m = 1; m <= n << 1; m <<= 1) {
        l++;
    }
    for (int i = 0; i < m; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    }
    for (int i = 1; i <= n; i++) {
        scanf("%lf", f + i);
        g[i] = 1. / i / i;
    }
    fft(f, g, a);
    reverse(f + 1, f + n + 1);
    fft(f, g, b);
    for (int i = 1; i <= n; i++) {
        printf("%.3lf\n", a[i] - b[n - i + 1]);
    }
    return 0;
}

例题三:BZOJ 4827 礼物

一开始的思路是枚举加上的值,然后直接用 F F T 计算。

// Time Limit Exceeded
// Time Used: 30 Sec
#include <cmath>
#include <cstdio>
#include <iostream>
using namespace std;
const int maxn = 50005;
const int maxm = 1 << 17 | 5;
const double pi = acos(-1.);
int n, m, bit, s, x[maxn], y[maxn], r[maxm];
struct cd {
    double r, i;
    cd() {}
    cd(double real, double imag) {
        r = real, i = imag;
    }
    cd operator+(cd x) {
        return cd(r + x.r, i + x.i);
    }
    cd operator-(cd x) {
        return cd(r - x.r, i - x.i);
    }
    cd operator*(cd x) {
        return cd(r * x.r - i * x.i, r * x.i + i * x.r);
    }
} a[maxm], b[maxm];
int max(int a, int b) {
    return a > b ? a : b;
}
void prework() {
    for (s = 1; s < n * 2; s <<= 1) bit++;
    for (int i = 0; i < s; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    }   
}
void fft(cd *a, int dft) {
    for (int i = 0; i < s; i++) {
        if (i < r[i]) {
            swap(a[i], a[r[i]]);
        }
    }
    for (int k = 1; k < s; k <<= 1) {
        cd wn0(cos(pi / k), dft * sin(pi / k));
        for (int i = 0; i < s; i += k << 1) {
            cd wnk(1, 0);
            for (int j = i; j < i + k; j++, wnk = wnk * wn0) {
                cd x = a[j], y = wnk * a[j + k];
                a[j] = x + y, a[j + k] = x - y; 
            }
        }
    }
}
int solve(int *x, int *y) {
    for (int i = 0; i < s; i++) {
        a[i] = b[i] = cd(0, 0);
    }
    for (int i = 0; i < n * 2 - 1; i++) {
        a[i] = cd(x[i % n], 0);
    }
    for (int i = 0; i < n; i++) {
        b[i] = cd(y[n - 1 - i], 0);
    }
    fft(a, 1);
    fft(b, 1);
    for (int i = 0; i < s; i++) {
        a[i] = a[i] * b[i];
    }
    fft(a, -1);
    int sum = 0;
    for (int i = 0; i < n; i++) {
        sum += x[i] * x[i] + y[i] * y[i];
    }
    int res = int(a[n - 1].r / s + 0.5);
    for (int i = n; i < n * 2 - 1; i++) {
        int x = int(a[i].r / s + 0.5);
        res = max(res, x);
    }
    return sum - 2 * res;
}
int main() {
    scanf("%d %d", &n, &m);
    prework();
    for (int i = 0; i < n; i++) {
        scanf("%d", x + i);
    }
    for (int i = 0; i < n; i++) {
        scanf("%d", y + i);
    }
    int ans = solve(x, y);
    for (int i = 1; i <= m; i++) {
        for (int j = 0; j < n; j++) {
            x[j] += i;
        }
        ans = min(ans, solve(x, y));
        for (int j = 0; j < n; j++) {
            x[j] -= i;
        }
    }
    for (int i = 1; i <= m; i++) {
        for (int j = 0; j < n; j++) {
            y[j] += i;
        }
        ans = min(ans, solve(x, y));
        for (int j = 0; j < n; j++) {
            y[j] -= i;
        }
    }
    printf("%d\n", ans);
    return 0;
}

后来发现其实加上的值和之前的差异值是独立的。只需一次 F F T 即可。

#include <cmath>
#include <cstdio>
#include <iostream>
using namespace std;
const int maxn = 50005;
const int maxm = 1 << 17 | 5;
const int inf = 1000000000;
const double pi = acos(-1.);
int n, m, bit, s, x[maxn], y[maxn], r[maxm];
struct cd {
    double r, i;
    cd() {}
    cd(double real, double imag) {
        r = real, i = imag;
    }
    cd operator+(cd x) {
        return cd(r + x.r, i + x.i);
    }
    cd operator-(cd x) {
        return cd(r - x.r, i - x.i);
    }
    cd operator*(cd x) {
        return cd(r * x.r - i * x.i, r * x.i + i * x.r);
    }
} a[maxm], b[maxm];
int max(int a, int b) {
    return a > b ? a : b;
}
int min(int a, int b) {
    return a < b ? a : b;
}
void prework() {
    for (s = 1; s < n * 2; s <<= 1) bit++;
    for (int i = 0; i < s; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    }   
}
void fft(cd *a, int dft) {
    for (int i = 0; i < s; i++) {
        if (i < r[i]) {
            swap(a[i], a[r[i]]);
        }
    }
    for (int k = 1; k < s; k <<= 1) {
        cd wn0(cos(pi / k), dft * sin(pi / k));
        for (int i = 0; i < s; i += k << 1) {
            cd wnk(1, 0);
            for (int j = i; j < i + k; j++, wnk = wnk * wn0) {
                cd x = a[j], y = wnk * a[j + k];
                a[j] = x + y, a[j + k] = x - y; 
            }
        }
    }
}
int solve(int *x, int *y) {
    for (int i = 0; i < s; i++) {
        a[i] = b[i] = cd(0, 0);
    }
    for (int i = 0; i < n * 2 - 1; i++) {
        a[i] = cd(x[i % n], 0);
    }
    for (int i = 0; i < n; i++) {
        b[i] = cd(y[n - 1 - i], 0);
    }
    fft(a, 1);
    fft(b, 1);
    for (int i = 0; i < s; i++) {
        a[i] = a[i] * b[i];
    }
    fft(a, -1);
    int sum = 0;
    for (int i = 0; i < n; i++) {
        sum += x[i] * x[i] + y[i] * y[i];
    }
    int res = int(a[n - 1].r / s + 0.5);
    for (int i = n; i < n * 2 - 1; i++) {
        int x = int(a[i].r / s + 0.5);
        res = max(res, x);
    }
    return sum - 2 * res;
}
int main() {
    scanf("%d %d", &n, &m);
    prework();
    for (int i = 0; i < n; i++) {
        scanf("%d", x + i);
    }
    for (int i = 0; i < n; i++) {
        scanf("%d", y + i);
    }
    int sum = 0, res = inf, ans = solve(x, y);
    for (int i = 0; i < n; i++) {
        sum += x[i] - y[i];
    }
    for (int i = -m; i <= m; i++) {
        res = min(res, n * i * i + 2 * sum * i);
    }
    printf("%d\n", ans + res);
    return 0;
}

例题四:BZOJ 3513 idiots

考虑算出不合法的方案数。记 c [ i ] 为长度为 i 的木棒的个数, s [ i ] 表示两根木棒长度加起来长度 i 的组数。 s [ i ] 可以用 F F T 计算。则方案数为: i = 1 m c [ i ] s [ i ] 。注意细节。

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
typedef long long ll;
const int maxn = 1 << 18 | 5;
const double pi = acos(-1.);
int tasks, m, n, k, l, r[maxn], c[maxn];
struct cd {
    double r, i;
    cd() {
        r = 0, i = 0;
    }
    cd(double real, double imag) {
        r = real, i = imag;
    }
    cd operator+(cd x) {
        return cd(r + x.r, i + x.i);
    } 
    cd operator-(cd x) {
        return cd(r - x.r, i - x.i);
    } 
    cd operator*(cd x) {
        return cd(r * x.r - i * x.i, r * x.i + i * x.r);
    } 
} a[maxn];
void fft(cd *a, int op) {
    for (int i = 0; i < m; i++) {
        if (i < r[i]) {
            swap(a[i], a[r[i]]);
        }
    }
    for (int k = 1; k < m; k <<= 1) {
        cd wn0(cos(pi / k), op * sin(pi / k));
        for (int i = 0; i < m; i += k << 1) {
            cd wnk = cd(1, 0);
            for (int j = i; j < i + k; j++, wnk = wnk * wn0) {
                cd x = a[j], y = wnk * a[j + k];
                a[j] = x + y, a[j + k] = x - y;
            }
        }
    }
    if (op == -1) {
        for (int i = 0; i < m; i++) {
            a[i].r /= m;
        }
    }
}
int main() {
    for (scanf("%d", &tasks); tasks--; ) {
        memset(c, 0, sizeof(c));
        n = 0, scanf("%d", &k);
        for (int x, i = 1; i <= k; i++) {
            scanf("%d", &x);
            n = max(n, x), c[x]++;
        }
        for (m = 1, l = 0; m <= n << 1; m <<= 1)    l++;
        for (int i = 0; i < m; i++) {
            a[i] = cd(c[i], 0);
            r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
        }
        fft(a, 1);
        for (int i = 0; i < m; i++) {
            a[i] = a[i] * a[i];
        }
        fft(a, -1);
        ll ans = 0, sum = 0; 
        for (int i = 1; i <= n; i++) {
            sum += ll(a[i].r + 0.5);
            if (!(i & 1)) {
                sum -= c[i >> 1];
            }
            if (c[i]) {
                ans += sum * c[i];
            }
        }
        printf("%.7lf\n", 1. - ans * 3. / k / (k - 1) / (k - 2));
    }
    return 0;
}

例题五:【BZOJ 3771】Triple

我们讲一个悲伤的故事。
从前有一个贫穷的樵夫在河边砍柴。
这时候河里出现了一个水神,夺过了他的斧头,说:
“这把斧头,是不是你的?”
樵夫一看:“是啊是啊!”
水神把斧头扔在一边,又拿起一个东西问:
“这把斧头,是不是你的?”
樵夫看不清楚,但又怕真的是自己的斧头,只好又答:“是啊是啊!”
水神又把手上的东西扔在一边,拿起第三个东西问:
“这把斧头,是不是你的?”
樵夫还是看不清楚,但是他觉得再这样下去他就没法砍柴了。
于是他又一次答:“是啊是啊!真的是!”
水神看着他,哈哈大笑道:
“你看看你现在的样子,真是丑陋!”
之后就消失了。
樵夫觉得很坑爹,他今天不仅没有砍到柴,还丢了一把斧头给那个水神。

——选自题目描述

记价值为 i 的斧头个数为 A [ i ] = B [ 2 i ] = C [ 3 i ]
于是答案等于 A + A 2 B 2 + A 3 3 A B + 2 C 6

#include <cmath>
#include <cstdio>
#include <complex>
using namespace std;
typedef complex<double> cd;
const int maxn = 1 << 18 | 5;
const double pi = acos(-1.);
int n, m, l, r[maxn];
cd a[maxn], b[maxn], c[maxn], ans[maxn];
int max(int a, int b) {
    return a > b ? a : b;
}
void fft(cd *a, int op) {
    for (int i = 0; i < m; i++) {
        if (i < r[i]) {
            swap(a[i], a[r[i]]);
        }
    }
    for (int k = 1; k < m; k <<= 1) {
        cd wn0(cos(pi / k), op * sin(pi / k));
        for (int i = 0; i < m; i += k << 1) {
            cd wnk = 1;
            for (int j = i; j < i + k; j++, wnk *= wn0) {
                cd x = a[j], y = wnk * a[j + k];
                a[j] = x + y, a[j + k] = x - y;
            }
        }
    }
    if (op == -1) {
        for (int i = 0; i < m; i++) {
            a[i] /= m;
        }
    }
}
int main() {
    int q, x;
    for (scanf("%d", &q); q--; ) {
        scanf("%d", &x);
        a[x] += 1;
        b[x * 2] += 1;
        c[x * 3] += 1;
        n = max(n, x * 3);
    }
    for (m = 1; m <= n; m <<= 1) {
        l++;
    }
    for (int i = 0; i < m; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    }
    fft(a, 1);
    fft(b, 1);
    fft(c, 1);
    for (int i = 0; i < m; i++) {
        ans[i] = a[i] + (a[i] * a[i] - b[i]) / 2. + (a[i] * a[i] * a[i] - 3. * a[i] * b[i] + 2. * c[i]) / 6.;
    }
    fft(ans, -1);
    for (int i = 0; i < m; i++) {
        int x(ans[i].real() + 0.5);
        if (x) {
            printf("%d %d\n", i, x);
        }
    }
    return 0;
}

例题六:BZOJ 4259 残缺的字符串

如果模式串的第 i 为是 *,则 a [ i ] = 0 ,否则 a [ i ] =   i     i   。同理我们可以通过被匹配串得到 b [ ]

发现两个长度为 n 的串能够匹配的充要条件是:

i = 1 n a [ i ] b [ i ] ( a [ i ] b [ i ] ) = 0

a 串反转,用 F F T 计算卷积即可。

#include <cmath>
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 1 << 20 | 5;
const double pi = acos(-1.);
int n, m, k, l, r[maxn], p[maxn], q[maxn], o, w[maxn];
char s[maxn];
struct cd {
    double r, i;
    cd() {
        r = 0, i = 0;
    }
    cd(double real, double imag) {
        r = real, i = imag;
    }
    cd operator+(const cd &x) {
        return cd(r + x.r, i + x.i);
    }
    cd operator-(const cd &x) {
        return cd(r - x.r, i - x.i);
    }
    cd operator*(const cd &x) {
        return cd(r * x.r - i * x.i, r * x.i + i * x.r);
    }
} a[maxn], b[maxn], c[maxn];
void fft(cd *a, int dft) {
    for (int i = 0; i < k; i++) {
        if (i < r[i]) {
            swap(a[i], a[r[i]]);
        }
    }
    for (int step = 1; step < k; step <<= 1) {
        cd wn0(cos(pi / step), dft * sin(pi / step));
        for (int i = 0; i < k; i += step << 1) {
            cd wnk(1, 0);
            for (int j = i; j < i + step; j++, wnk = wnk * wn0) {
                cd x = a[j], y = wnk * a[j + step];
                a[j] = x + y, a[j + step] = x - y;
            }
        }
    }
}
int main() {
    scanf("%d %d", &n, &m);
    scanf("%s", s);
    reverse(s, s + n);
    for (int i = 0; i < n; i++) {
        p[i] = s[i] == '*' ? 0 : s[i] - 'a' + 1;
    }
    scanf("%s", s);
    for (int i = 0; i < m; i++) {
        q[i] = s[i] == '*' ? 0 : s[i] - 'a' + 1;
    }
    for (k = 1; k < n + m; k <<= 1) l++;
    for (int i = 0; i < k; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    }
    for (int i = 0; i < k; i++) {
        a[i] = cd(p[i] * p[i] * p[i], 0);
        b[i] = cd(q[i], 0);
    }
    fft(a, 1), fft(b, 1);
    for (int i = 0; i < k; i++) {
        c[i] = c[i] + a[i] * b[i];
    }
    for (int i = 0; i < k; i++) {
        a[i] = cd(p[i] * p[i], 0);
        b[i] = cd(q[i] * q[i], 0);
    }
    fft(a, 1), fft(b, 1);
    for (int i = 0; i < k; i++) {
        c[i] = c[i] - cd(2, 0) * a[i] * b[i];
    }
    for (int i = 0; i < k; i++) {
        a[i] = cd(p[i], 0);
        b[i] = cd(q[i] * q[i] * q[i], 0);
    }
    fft(a, 1), fft(b, 1);
    for (int i = 0; i < k; i++) {
        c[i] = c[i] + a[i] * b[i];
    }
    fft(c, -1);
    for (int i = n - 1; i < m; i++) {
        if (c[i].r / k < 0.5) {
            w[++o] = i - n + 2;
        }
    }
    printf("%d\n", o);
    for (int i = 1; i <= o; i++) {
        printf("%d%c", w[i], i == o ? '\n' : ' ');
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_42068627/article/details/81065249