个人见解: 是一种能在 时间内求多项式的乘积的算法,代码复杂度较小,缺点是常数巨大。
算法思路:我们发现点值表示法的多项式乘法是
的,于是我们可以绕过
而是使用
来求解多项式乘法。
现在的关键问题就是如何快速进行 和 操作。
算法流程:见 algocode 算法博客。
// Template
#include <cmath>
#include <cstdio>
#include <complex>
using namespace std;
typedef complex<double> cd;
const int maxn = 262145;
const double pi = acos(-1.);
cd a[maxn], b[maxn];
int n, m, s, bit, rev[maxn];
void prework() {
for (bit = 1, s = 2; 1 << bit <= n + m; bit++) {
s <<= 1;
}
for (int i = 0; i < 1 << bit; i++) {
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
}
void dft(cd *a, int n, int op) {
for (int i = 0; i < n; i++) {
if (i < rev[i]) {
swap(a[i], a[rev[i]]);
}
}
for (int step = 1; step < n; step <<= 1) {
cd wn0 = exp(cd(0, op * pi / step));
for (int j = 0; j < n; j += step << 1) {
cd wnk(1, 0);
for (int k = j; k < j + step; k++) {
cd x = a[k];
cd y = wnk * a[k + step];
a[k] = x + y;
a[k + step] = x - y;
wnk *= wn0;
}
}
}
if (op == -1) {
for (int i = 0; i < n; i++) {
a[i] /= n;
}
}
}
int main() {
scanf("%d %d", &n, &m);
prework();
for (int x, i = 0; i <= n; i++) {
scanf("%d", &x);
a[i] = cd(x, 0);
}
for (int x, i = 0; i <= m; i++) {
scanf("%d", &x);
b[i] = cd(x, 0);
}
dft(a, s, 1), dft(b, s, 1);
for (int i = 0; i < s; i++) {
a[i] *= b[i];
}
dft(a, s, -1); for (int i = 0; i <= n + m; i++) {
printf("%d%c", int(a[i].real() + 0.5), " \n"[i == n + m]);
}
return 0;
}
这题几乎是一道模版题。只需将 数组反向计算两个多项式的乘法即可。
#include <cmath>
#include <cstdio>
#include <complex>
#include <algorithm>
using namespace std;
const int maxn = 1 << 18 | 5;
const double pi = acos(-1.);
typedef complex<double> cd;
int n, m, l, r[maxn];
cd a[maxn], b[maxn];
void fft(cd *a, int op) {
for (int i = 0; i < m; i++) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < m; k <<= 1) {
cd wn0(cos(pi / k), op * sin(pi / k));
for (int i = 0; i < m; i += k << 1) {
cd wnk = 1;
for (int j = i; j < i + k; j++, wnk *= wn0) {
cd x = a[j], y = wnk * a[j + k];
a[j] = x + y, a[j + k] = x - y;
}
}
}
if (op == -1) {
for (int i = 0; i < m; i++) {
a[i] /= m;
}
}
}
int main() {
scanf("%d", &n);
for (int x, y, i = 0; i < n; i++) {
scanf("%d %d", &x, &y);
a[i] = x, b[i] = y;
}
for (m = 1; m < n << 1; m <<= 1) {
l++;
}
for (int i = 0; i < m; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
}
reverse(a, a + n);
fft(a, 1);
fft(b, 1);
for (int i = 0; i < m; i++) {
a[i] *= b[i];
}
fft(a, -1);
for (int i = n - 1; ~i; i--) {
printf("%d\n", int(a[i].real() + 0.5));
}
return 0;
}
例题二:【BZ3527】力
推得
记 ,则 。用 计算即可。
#include <cmath>
#include <cstdio>
#include <complex>
#include <iostream>
#include <algorithm>
using namespace std;
typedef double db;
typedef complex<db> cd;
const int maxn = 1 << 18 | 5;
const db pi = acos(-1.);
int n, m, l, r[maxn];
db f[maxn], g[maxn], a[maxn], b[maxn];
void dft(cd *a, int op) {
for (int i = 0; i < m; i++) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < m; k <<= 1) {
cd wn0(cos(pi / k), op * sin(pi / k));
for (int i = 0; i < m; i += k << 1) {
cd wnk(1, 0);
for (int j = i; j < i + k; j++, wnk *= wn0) {
cd x = a[j], y = wnk * a[j + k];
a[j] = x + y, a[j + k] = x - y;
}
}
}
}
void fft(db *a, db *b, db *c) {
static cd p[maxn], q[maxn];
for (int i = 0; i < m; i++) {
p[i] = a[i], q[i] = b[i];
}
dft(p, 1), dft(q, 1);
for (int i = 0; i < m; i++) {
p[i] *= q[i];
}
dft(p, -1);
for (int i = 1; i <= n; i++) {
c[i] = p[i].real() / m;
}
}
int main() {
scanf("%d", &n);
for (m = 1; m <= n << 1; m <<= 1) {
l++;
}
for (int i = 0; i < m; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
}
for (int i = 1; i <= n; i++) {
scanf("%lf", f + i);
g[i] = 1. / i / i;
}
fft(f, g, a);
reverse(f + 1, f + n + 1);
fft(f, g, b);
for (int i = 1; i <= n; i++) {
printf("%.3lf\n", a[i] - b[n - i + 1]);
}
return 0;
}
例题三:BZOJ 4827 礼物
一开始的思路是枚举加上的值,然后直接用 计算。
// Time Limit Exceeded
// Time Used: 30 Sec
#include <cmath>
#include <cstdio>
#include <iostream>
using namespace std;
const int maxn = 50005;
const int maxm = 1 << 17 | 5;
const double pi = acos(-1.);
int n, m, bit, s, x[maxn], y[maxn], r[maxm];
struct cd {
double r, i;
cd() {}
cd(double real, double imag) {
r = real, i = imag;
}
cd operator+(cd x) {
return cd(r + x.r, i + x.i);
}
cd operator-(cd x) {
return cd(r - x.r, i - x.i);
}
cd operator*(cd x) {
return cd(r * x.r - i * x.i, r * x.i + i * x.r);
}
} a[maxm], b[maxm];
int max(int a, int b) {
return a > b ? a : b;
}
void prework() {
for (s = 1; s < n * 2; s <<= 1) bit++;
for (int i = 0; i < s; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
}
void fft(cd *a, int dft) {
for (int i = 0; i < s; i++) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < s; k <<= 1) {
cd wn0(cos(pi / k), dft * sin(pi / k));
for (int i = 0; i < s; i += k << 1) {
cd wnk(1, 0);
for (int j = i; j < i + k; j++, wnk = wnk * wn0) {
cd x = a[j], y = wnk * a[j + k];
a[j] = x + y, a[j + k] = x - y;
}
}
}
}
int solve(int *x, int *y) {
for (int i = 0; i < s; i++) {
a[i] = b[i] = cd(0, 0);
}
for (int i = 0; i < n * 2 - 1; i++) {
a[i] = cd(x[i % n], 0);
}
for (int i = 0; i < n; i++) {
b[i] = cd(y[n - 1 - i], 0);
}
fft(a, 1);
fft(b, 1);
for (int i = 0; i < s; i++) {
a[i] = a[i] * b[i];
}
fft(a, -1);
int sum = 0;
for (int i = 0; i < n; i++) {
sum += x[i] * x[i] + y[i] * y[i];
}
int res = int(a[n - 1].r / s + 0.5);
for (int i = n; i < n * 2 - 1; i++) {
int x = int(a[i].r / s + 0.5);
res = max(res, x);
}
return sum - 2 * res;
}
int main() {
scanf("%d %d", &n, &m);
prework();
for (int i = 0; i < n; i++) {
scanf("%d", x + i);
}
for (int i = 0; i < n; i++) {
scanf("%d", y + i);
}
int ans = solve(x, y);
for (int i = 1; i <= m; i++) {
for (int j = 0; j < n; j++) {
x[j] += i;
}
ans = min(ans, solve(x, y));
for (int j = 0; j < n; j++) {
x[j] -= i;
}
}
for (int i = 1; i <= m; i++) {
for (int j = 0; j < n; j++) {
y[j] += i;
}
ans = min(ans, solve(x, y));
for (int j = 0; j < n; j++) {
y[j] -= i;
}
}
printf("%d\n", ans);
return 0;
}
后来发现其实加上的值和之前的差异值是独立的。只需一次 即可。
#include <cmath>
#include <cstdio>
#include <iostream>
using namespace std;
const int maxn = 50005;
const int maxm = 1 << 17 | 5;
const int inf = 1000000000;
const double pi = acos(-1.);
int n, m, bit, s, x[maxn], y[maxn], r[maxm];
struct cd {
double r, i;
cd() {}
cd(double real, double imag) {
r = real, i = imag;
}
cd operator+(cd x) {
return cd(r + x.r, i + x.i);
}
cd operator-(cd x) {
return cd(r - x.r, i - x.i);
}
cd operator*(cd x) {
return cd(r * x.r - i * x.i, r * x.i + i * x.r);
}
} a[maxm], b[maxm];
int max(int a, int b) {
return a > b ? a : b;
}
int min(int a, int b) {
return a < b ? a : b;
}
void prework() {
for (s = 1; s < n * 2; s <<= 1) bit++;
for (int i = 0; i < s; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (bit - 1));
}
}
void fft(cd *a, int dft) {
for (int i = 0; i < s; i++) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < s; k <<= 1) {
cd wn0(cos(pi / k), dft * sin(pi / k));
for (int i = 0; i < s; i += k << 1) {
cd wnk(1, 0);
for (int j = i; j < i + k; j++, wnk = wnk * wn0) {
cd x = a[j], y = wnk * a[j + k];
a[j] = x + y, a[j + k] = x - y;
}
}
}
}
int solve(int *x, int *y) {
for (int i = 0; i < s; i++) {
a[i] = b[i] = cd(0, 0);
}
for (int i = 0; i < n * 2 - 1; i++) {
a[i] = cd(x[i % n], 0);
}
for (int i = 0; i < n; i++) {
b[i] = cd(y[n - 1 - i], 0);
}
fft(a, 1);
fft(b, 1);
for (int i = 0; i < s; i++) {
a[i] = a[i] * b[i];
}
fft(a, -1);
int sum = 0;
for (int i = 0; i < n; i++) {
sum += x[i] * x[i] + y[i] * y[i];
}
int res = int(a[n - 1].r / s + 0.5);
for (int i = n; i < n * 2 - 1; i++) {
int x = int(a[i].r / s + 0.5);
res = max(res, x);
}
return sum - 2 * res;
}
int main() {
scanf("%d %d", &n, &m);
prework();
for (int i = 0; i < n; i++) {
scanf("%d", x + i);
}
for (int i = 0; i < n; i++) {
scanf("%d", y + i);
}
int sum = 0, res = inf, ans = solve(x, y);
for (int i = 0; i < n; i++) {
sum += x[i] - y[i];
}
for (int i = -m; i <= m; i++) {
res = min(res, n * i * i + 2 * sum * i);
}
printf("%d\n", ans + res);
return 0;
}
例题四:BZOJ 3513 idiots
考虑算出不合法的方案数。记 为长度为 的木棒的个数, 表示两根木棒长度加起来长度 的组数。 可以用 计算。则方案数为: 。注意细节。
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
typedef long long ll;
const int maxn = 1 << 18 | 5;
const double pi = acos(-1.);
int tasks, m, n, k, l, r[maxn], c[maxn];
struct cd {
double r, i;
cd() {
r = 0, i = 0;
}
cd(double real, double imag) {
r = real, i = imag;
}
cd operator+(cd x) {
return cd(r + x.r, i + x.i);
}
cd operator-(cd x) {
return cd(r - x.r, i - x.i);
}
cd operator*(cd x) {
return cd(r * x.r - i * x.i, r * x.i + i * x.r);
}
} a[maxn];
void fft(cd *a, int op) {
for (int i = 0; i < m; i++) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < m; k <<= 1) {
cd wn0(cos(pi / k), op * sin(pi / k));
for (int i = 0; i < m; i += k << 1) {
cd wnk = cd(1, 0);
for (int j = i; j < i + k; j++, wnk = wnk * wn0) {
cd x = a[j], y = wnk * a[j + k];
a[j] = x + y, a[j + k] = x - y;
}
}
}
if (op == -1) {
for (int i = 0; i < m; i++) {
a[i].r /= m;
}
}
}
int main() {
for (scanf("%d", &tasks); tasks--; ) {
memset(c, 0, sizeof(c));
n = 0, scanf("%d", &k);
for (int x, i = 1; i <= k; i++) {
scanf("%d", &x);
n = max(n, x), c[x]++;
}
for (m = 1, l = 0; m <= n << 1; m <<= 1) l++;
for (int i = 0; i < m; i++) {
a[i] = cd(c[i], 0);
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
}
fft(a, 1);
for (int i = 0; i < m; i++) {
a[i] = a[i] * a[i];
}
fft(a, -1);
ll ans = 0, sum = 0;
for (int i = 1; i <= n; i++) {
sum += ll(a[i].r + 0.5);
if (!(i & 1)) {
sum -= c[i >> 1];
}
if (c[i]) {
ans += sum * c[i];
}
}
printf("%.7lf\n", 1. - ans * 3. / k / (k - 1) / (k - 2));
}
return 0;
}
我们讲一个悲伤的故事。
从前有一个贫穷的樵夫在河边砍柴。
这时候河里出现了一个水神,夺过了他的斧头,说:
“这把斧头,是不是你的?”
樵夫一看:“是啊是啊!”
水神把斧头扔在一边,又拿起一个东西问:
“这把斧头,是不是你的?”
樵夫看不清楚,但又怕真的是自己的斧头,只好又答:“是啊是啊!”
水神又把手上的东西扔在一边,拿起第三个东西问:
“这把斧头,是不是你的?”
樵夫还是看不清楚,但是他觉得再这样下去他就没法砍柴了。
于是他又一次答:“是啊是啊!真的是!”
水神看着他,哈哈大笑道:
“你看看你现在的样子,真是丑陋!”
之后就消失了。
樵夫觉得很坑爹,他今天不仅没有砍到柴,还丢了一把斧头给那个水神。
——选自题目描述
记价值为
的斧头个数为
。
于是答案等于
。
#include <cmath>
#include <cstdio>
#include <complex>
using namespace std;
typedef complex<double> cd;
const int maxn = 1 << 18 | 5;
const double pi = acos(-1.);
int n, m, l, r[maxn];
cd a[maxn], b[maxn], c[maxn], ans[maxn];
int max(int a, int b) {
return a > b ? a : b;
}
void fft(cd *a, int op) {
for (int i = 0; i < m; i++) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int k = 1; k < m; k <<= 1) {
cd wn0(cos(pi / k), op * sin(pi / k));
for (int i = 0; i < m; i += k << 1) {
cd wnk = 1;
for (int j = i; j < i + k; j++, wnk *= wn0) {
cd x = a[j], y = wnk * a[j + k];
a[j] = x + y, a[j + k] = x - y;
}
}
}
if (op == -1) {
for (int i = 0; i < m; i++) {
a[i] /= m;
}
}
}
int main() {
int q, x;
for (scanf("%d", &q); q--; ) {
scanf("%d", &x);
a[x] += 1;
b[x * 2] += 1;
c[x * 3] += 1;
n = max(n, x * 3);
}
for (m = 1; m <= n; m <<= 1) {
l++;
}
for (int i = 0; i < m; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
}
fft(a, 1);
fft(b, 1);
fft(c, 1);
for (int i = 0; i < m; i++) {
ans[i] = a[i] + (a[i] * a[i] - b[i]) / 2. + (a[i] * a[i] * a[i] - 3. * a[i] * b[i] + 2. * c[i]) / 6.;
}
fft(ans, -1);
for (int i = 0; i < m; i++) {
int x(ans[i].real() + 0.5);
if (x) {
printf("%d %d\n", i, x);
}
}
return 0;
}
例题六:BZOJ 4259 残缺的字符串
如果模式串的第
为是 *
,则
,否则
。同理我们可以通过被匹配串得到
。
发现两个长度为
的串能够匹配的充要条件是:
将 串反转,用 计算卷积即可。
#include <cmath>
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int maxn = 1 << 20 | 5;
const double pi = acos(-1.);
int n, m, k, l, r[maxn], p[maxn], q[maxn], o, w[maxn];
char s[maxn];
struct cd {
double r, i;
cd() {
r = 0, i = 0;
}
cd(double real, double imag) {
r = real, i = imag;
}
cd operator+(const cd &x) {
return cd(r + x.r, i + x.i);
}
cd operator-(const cd &x) {
return cd(r - x.r, i - x.i);
}
cd operator*(const cd &x) {
return cd(r * x.r - i * x.i, r * x.i + i * x.r);
}
} a[maxn], b[maxn], c[maxn];
void fft(cd *a, int dft) {
for (int i = 0; i < k; i++) {
if (i < r[i]) {
swap(a[i], a[r[i]]);
}
}
for (int step = 1; step < k; step <<= 1) {
cd wn0(cos(pi / step), dft * sin(pi / step));
for (int i = 0; i < k; i += step << 1) {
cd wnk(1, 0);
for (int j = i; j < i + step; j++, wnk = wnk * wn0) {
cd x = a[j], y = wnk * a[j + step];
a[j] = x + y, a[j + step] = x - y;
}
}
}
}
int main() {
scanf("%d %d", &n, &m);
scanf("%s", s);
reverse(s, s + n);
for (int i = 0; i < n; i++) {
p[i] = s[i] == '*' ? 0 : s[i] - 'a' + 1;
}
scanf("%s", s);
for (int i = 0; i < m; i++) {
q[i] = s[i] == '*' ? 0 : s[i] - 'a' + 1;
}
for (k = 1; k < n + m; k <<= 1) l++;
for (int i = 0; i < k; i++) {
r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
}
for (int i = 0; i < k; i++) {
a[i] = cd(p[i] * p[i] * p[i], 0);
b[i] = cd(q[i], 0);
}
fft(a, 1), fft(b, 1);
for (int i = 0; i < k; i++) {
c[i] = c[i] + a[i] * b[i];
}
for (int i = 0; i < k; i++) {
a[i] = cd(p[i] * p[i], 0);
b[i] = cd(q[i] * q[i], 0);
}
fft(a, 1), fft(b, 1);
for (int i = 0; i < k; i++) {
c[i] = c[i] - cd(2, 0) * a[i] * b[i];
}
for (int i = 0; i < k; i++) {
a[i] = cd(p[i], 0);
b[i] = cd(q[i] * q[i] * q[i], 0);
}
fft(a, 1), fft(b, 1);
for (int i = 0; i < k; i++) {
c[i] = c[i] + a[i] * b[i];
}
fft(c, -1);
for (int i = n - 1; i < m; i++) {
if (c[i].r / k < 0.5) {
w[++o] = i - n + 2;
}
}
printf("%d\n", o);
for (int i = 1; i <= o; i++) {
printf("%d%c", w[i], i == o ? '\n' : ' ');
}
return 0;
}