我们已经知道了可以使用FFT和NTT在
的时间内计算多项式乘法,那对于其它的运算呢?
首先,我们需要知道多项式逆元这个根本性的问题。
多项式求逆
考虑两个多项式
,那么我们称
意义下的逆元,即它们乘积的前
项只有常数项为1,其它都是0.
那么我们如何计算逆元呢?考虑如果我们已经计算出了
,那么我们能不能推导出满足
的多项式
呢?
首先将两式相减,可以得到
由于等式左边多项式的前
项系数都是0,可以把这个玩意儿平方:
又由于
,带入可以得到:
提取多项式
:
即
。
于是我们会发现一个多项式是否有逆元完全取决于该多项式的常数项是否有逆元。由于这个算法每次可以将问题规模减半,总的时间复杂度就是
例题 洛谷3711 仓鼠的数学题
题目链接
首先可以想到使用伯努利数来化简算式。于是就有
最开始有个1是因为题目中把
也算为1.然后我们考虑继续化简:
由于我们要求的是多项式,考虑枚举x的指数:
然后就是很经典的翻转卷积模型了。
于是后面就变成了显然的一个卷积,直接NTT求解即可。那我们如何求出伯努利数呢?
考虑伯努利数的指数型生成函数:
因此一个多项式求逆就可以求出伯努利数了,总复杂度
。
(就这里贴一个代码,放上求逆元的板子,后面就不放代码了……)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1 << 19 | 5, mod = 998244353, G = 3;
int modpow(ll a, int b){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
void rader(int *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void ntt(int *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
int x = a[j], y = w * a[j + hh] % mod;
a[j] = (x + y) % mod;
a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev){
ll inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = a[i] * inv % mod;
}
}
int B[maxn], a[maxn], fact[maxn], f[maxn], g[maxn], temp[maxn], rev[maxn], n;
void get_inv(int *a, int *b, int n){
if(n == 1){
b[0] = modpow(a[0], mod - 2);
return;
}
get_inv(a, b, n >> 1);
int t = n << 1;
for(int i = 0; i < n; i++) temp[i] = a[i];
for(int i = n; i < t; i++) temp[i] = 0;
ntt(temp, t, 0), ntt(b, t, 0);
for(int i = 0; i < t; i++)
temp[i] = (ll)b[i] * (2 - (ll)b[i] * temp[i] % mod + mod) % mod;
ntt(temp, t, 1);
for(int i = 0; i < n; i++) b[i] = temp[i];
for(int i = n; i < t; i++) b[i] = 0;
}
int main(){
scanf("%d", &n);
fact[0] = 1;
int t = 1;
while(t <= n) t <<= 1;
for(int i = 1; i <= t; i++) fact[i] = (ll)fact[i - 1] * i % mod;
rev[t] = modpow(fact[t], mod - 2);
for(int i = t - 1; i >= 0; i--) rev[i] = (ll)rev[i + 1] * (i + 1) % mod;
for(int i = 0; i < t; i++) a[i] = rev[i + 1];
get_inv(a, B, t);
for(int i = 0; i <= n; i++) scanf("%d", a + i);
for(int i = 0; i < t; i++) B[i] = (ll)B[i] * fact[i] % mod;
for(int i = 0; i <= n; i++){
f[i] = (ll)B[i] * rev[i] % mod;
if(i & 1) f[i] = mod - f[i];
g[i] = (ll)a[n - i] * fact[n - i] % mod;
}
printf("%d ", a[0]);
ntt(f, t << 1, 0), ntt(g, t << 1, 0);
for(int i = 0; i < t << 1; i++) f[i] = (ll)f[i] * g[i] % mod;
ntt(f, t << 1, 1);
for(int i = 1; i <= n + 1; i++)
printf("%d ", (ll)f[n + 1 - i] * rev[i] % mod);
return 0;
}
例题 BZOJ2259异化多肽
题目链接
考虑令
表示分子量为
时的组合方案数。我们可以考虑枚举最后一个原子量是什么,令
为原子量集合,于是可以得到如下递推式:
再令
,即每个数字是否在集合中出现过,出现过为1,否则为0.于是递推式就可以改写成:
令
,则有
,移项可以得到
于是直接多项式求逆即可。复杂度
。
例题 51nod1514美妙的序列
题目链接
首先暴力dp肯定是n方的,直接舍弃掉。我们需要发现这一题的隐含性质。
首先来考虑一下什么时候这个序列不满足要求,也就是说我们可以找到一个分割点使得右边的最小值大于左边的最大值。于是左边必然是某一个1~k的排列;并且,如果能找到某个位置满足这个位置左边是一个1~k的排列,那么这个序列一定不满足要求。
也就是说,一个序列是美妙的当且仅当它的任何一个前缀都不是1~k的排列。
于是我们设
表示长度为n的答案,那么我们可以考虑反过来求,也就是求所有不美妙的序列。于是枚举最右端可以构成1~k的排列的部分(也就是最靠右不能满足要求的位置),然后把序列剖成两半,左边一半必然是一个排列,右边一半必然是一个美妙的序列(如果不是,就意味着还存在一个更靠右的排列)。
于是就可以dp了:
也就是可以得到:
由于有
的特殊情况,我们还有
于是令
的生成函数,
,于是就有
,移项可以得到:
多项式求逆即可,复杂度
。
多项式求ln
什么?多项式能算无理数?玄妙啊……不说这么多了,我们怎么对一个多项式
呢?
我们考虑对结果求一个导数,就可以得到
,后面那个东西是可以直接多项式求逆算出来的,于是算完之后再求一个原函数(积分一下)就OK了。由于用到一次多项式乘法和多项式求逆,复杂度
。
例题 洛谷P3784 [SDOI2017]遗忘的集合
题目链接
我们考虑普通情况下是如何通过集合S求出f函数的。构造生成函数(n为集合大小):
换到这道题目中,我们已知
,于是求出满足要求的
函数:
考虑两边同时取对数,得到
使用泰勒展开式展开ln函数,得到
枚举
的值,得到:
于是我们只要对函数
求一个ln就可以得到右边多项式的系数,不妨假设得到的第
项系数为
。于是得到
这不是一个显然的莫比乌斯反演嘛!!!于是莫比乌斯函数搞一发:
然后这道题就做完了……复杂度
。
也是附一个多项式求ln的板子吧……
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1 << 19 | 5;
const double PI = acos(-1);
struct C{
double x, y;
C operator+(const C &c) const {return (C){x + c.x, y + c.y};}
C operator-(const C &c) const {return (C){x - c.x, y - c.y};}
C operator*(const C &c) const {return (C){x * c.x - y * c.y, x * c.y + y * c.x};}
} tab[maxn], ffta[maxn], fftb[maxn], fftc[maxn];
void rader(C *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void fft(C *a, int n, int f){
rader(a, n);
for(int vis = 2; vis <= n; vis <<= 1){
int hh = vis >> 1;
for(int i = 0; i < n; i += vis)
for(int j = i; j < i + hh; j++){
const C x = a[j], y = a[j + hh] * tab[n / vis * (j - i)];
a[j] = x + y, a[j + hh] = x - y;
}
}
if(f == -1){
C inv = (C){1.0 / n, 0};
for(int i = 0; i < n; i++) a[i] = a[i] * inv;
}
}
void mul(int *a, int *b, int n, int p){
int len = n << 1;
memset(ffta, 0, sizeof(C) * len);
memset(fftb, 0, sizeof(C) * len);
memset(fftc, 0, sizeof(C) * len);
for(int i = 0; i < n; i++){
a[i] %= p, b[i] %= p;
ffta[i] = (C){a[i] >> 15, a[i] & 0x7FFF};
fftb[i] = (C){b[i] >> 15, b[i] & 0x7FFF};
}
for(int i = 0; i < len; i++)
tab[i] = (C){cos(2 * i * PI / len), sin(2 * i * PI / len)};
fft(ffta, len, 1), fft(fftb, len, 1);
for(int i = 0; i < len; i++){
int j = (len - i) & (len - 1);
C p = (C){(ffta[i].x + ffta[j].x) * 0.5, (ffta[i].y - ffta[j].y) * 0.5};
C q = (C){(fftb[i].x + fftb[j].x) * 0.5, (fftb[i].y - fftb[j].y) * 0.5};
fftc[i] = p * q;
tab[i].y = -tab[i].y;
}
for(int i = 0; i < len; i++) ffta[i] = ffta[i] * fftb[i];
fft(ffta, len, -1), fft(fftc, len, -1);
for(int i = 0; i < len; i++){
ll x = ((ll)(fftc[i].x + 0.5) % p << 30) % p;
ll y = ((ll)(ffta[i].y + 0.5) % p << 15) % p;
ll z = (ll)(fftc[i].x - ffta[i].x + 0.5) % p;
a[i] = (x + y + z + p) % p;
}
}
int modpow(ll a, int b, int c){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % c;
a = a * a % c;
}
return res;
}
int a[maxn], b[maxn], temp[maxn], n, p, len;
void get_inv(int *a, int *b, int n, int p){
if(n == 1) return (void)(b[0] = modpow(a[0], p - 2, p));
get_inv(a, b, n >> 1, p);
for(int i = 0; i < n; i++) temp[i] = 2 * b[i] % p;
mul(b, b, n >> 1, p);
mul(b, a, n, p);
for(int i = 0; i < n; i++) b[i] = (temp[i] - b[i] + p) % p;
for(int i = n; i < n << 1; i++) b[i] = 0;
}
void get_ln(int *a, int *b, int n, int p){
get_inv(a, b, n, p);
for(int i = 0; i < n - 1; i++)
a[i] = (ll)a[i + 1] * (i + 1) % p;
a[n - 1] = 0;
mul(b, a, n, p);
for(int i = n - 2; i >= 0; i--)
b[i + 1] = (ll)b[i] * modpow(i + 1, p - 2, p) % p;
b[0] = 0;
}
int vis[maxn], mu[maxn], prime[maxn];
int main(){
scanf("%d%d", &n, &p);
for(len = 1; len <= n; len <<= 1);
a[0] = 1;
for(int i = 1; i <= n; i++) scanf("%d", a + i);
get_ln(a, b, len, p);
mu[1] = 1;
for(int i = 2, cnt = 0; i <= n; i++){
if(!vis[i]){
mu[i] = -1;
prime[cnt++] = i;
}
for(int j = 0; j < cnt; j++){
int p = prime[j], mul = i * p;
if(mul > n) break;
vis[mul] = 1;
if(i % p == 0) break;
mu[mul] = -mu[i];
}
}
memset(vis, 0, sizeof(vis));
int cnt = 0;
for(int i = 1; i <= n; i++){
for(int j = 1; j * i <= n; j++)
vis[j * i] = ((ll)(mu[i] + p) * b[j] % p * j + vis[j * i]) % p;
if(vis[i] > 0) ++cnt;
}
printf("%d\n", cnt);
for(int i = 1; i <= n; i++) if(vis[i]) printf("%d ", i);
return 0;
}
多项式求exp
上面既然介绍过了多项式求ln,那么逆运算应该也可以的吧……令
,那么就有
。于是考虑牛顿迭代,就有:
每次迭代都使精度(次数)翻倍,因此复杂度
,常数巨大,实测是一次fft的20倍以上。
###例题 洛谷P4389 付公主的背包
题目链接
显然我们也可以像上题那样定义生成函数:
技巧差不多,两边同时求对数并展开:
如果v互不相同,那么后面那个求和可以在
的时间内用类似调和级数的复杂度预处理出来;但是如果v有相同的怎么办?如果含有相同的v就一起算呗……于是复杂度还是
的。求完那个系数之后再做一遍exp就可以求出
了。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1 << 18 | 5, mod = 998244353, G = 3;
int modpow(ll a, int b){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
void rader(int *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void ntt(int *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
int x = a[j], y = a[j + hh] * w % mod;
a[j] = (x + y) % mod;
a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev){
int inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = (ll)a[i] * inv % mod;
}
}
int temp[maxn], aa[maxn], bb[maxn], ln[maxn], n, m;
void get_inv(int *a, int *b, int n){
if(n == 1) return (void)(b[0] = modpow(a[0], mod - 2), b[1] = 0);
get_inv(a, b, n >> 1);
for(int i = 0; i < n; i++) temp[i] = a[i];
int t = n << 1;
for(int i = n; i < t; i++) temp[i] = b[i] = 0;
ntt(temp, t, 0), ntt(b, t, 0);
for(int i = 0; i < t; i++)
b[i] = (2 - (ll)b[i] * temp[i] % mod + mod) * b[i] % mod;
ntt(b, t, 1);
for(int i = n; i < t; i++) b[i] = 0;
}
void get_ln(int *a, int *b, int n){
int t = n << 1;
get_inv(a, b, n);
for(int i = 0; i < n; i++) temp[i] = a[i];
for(int i = n; i < t; i++) temp[i] = 0;
a = temp;
for(int i = 0; i < n - 1; i++)
a[i] = (ll)a[i + 1] * (i + 1) % mod;
a[n] = 0;
ntt(a, t, 0), ntt(b, t, 0);
for(int i = 0; i < t; i++) b[i] = (ll)b[i] * a[i] % mod;
ntt(b, t, 1);
for(int i = n - 2; i >= 0; i--)
b[i + 1] = (ll)b[i] * modpow(i + 1, mod - 2) % mod;
b[0] = 0;
for(int i = n; i < t; i++) b[i] = 0;
}
void get_exp(int *a, int *b, int n){
if(n == 1) return (void)(b[0] = 1);
get_exp(a, b, n >> 1);
get_ln(b, ln, n);
int t = n << 1;
for(int i = 0; i < n; i++){
temp[i] = (!i + mod - ln[i] + a[i]) % mod;
temp[i + n] = 0;
}
ntt(temp, t, 0), ntt(b, t, 0);
for(int i = 0; i < t; i++) b[i] = (ll)b[i] * temp[i] % mod;
ntt(b, t, 1);
for(int i = n; i < t; i++) b[i] = 0;
}
int v[maxn], rev[maxn], rpos, ppos;
const int maxr = 2000000;
char str[maxr], prt[maxr];
char readc(){
if(!rpos) fread(str, 1, maxr, stdin);
char c = str[rpos++];
if(rpos == maxr) rpos = 0;
return c;
}
int read(){
int x; char c;
while((c = readc()) < '0' || c > '9');
x = c - '0';
while((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
return x;
}
void print(int x){
if(x > 0){
char sta[10];
int tp = 0;
for(; x; x /= 10) sta[tp++] = x % 10 + '0';
while(tp) prt[ppos++] = sta[--tp];
} else prt[ppos++] = '0';
prt[ppos++] = '\n';
}
int main(){
n = read(), m = read();
for(int i = 0; i < n; i++) v[i] = read();
sort(v, v + n);
rev[1] = 1;
for(int i = 2; i <= m; i++)
rev[i] = mod - (ll)mod / i * rev[mod % i] % mod;
for(int i = 0; i < n;){
int j = 0;
while(i + j < n && v[i + j] == v[i]) ++j;
for(int k = v[i]; k <= m; k += v[i])
aa[k] = (aa[k] + (ll)j * rev[k / v[i]]) % mod;
i = i + j;
}
int len = 1;
while(len <= m) len <<= 1;
get_exp(aa, bb, len);
for(int i = 1; i <= m; i++) print(bb[i]);
fwrite(prt, 1, ppos, stdout);
return 0;
}
一般情况的多项式运算
在上面一节,我们通过牛顿迭代使用ln计算出了exp。其实类似的,使用牛顿迭代可以使用一个运算求出其逆运算。
下面是牛顿迭代的一般公式:
不断迭代即可求出
的根。比如多项式开方:令
迭代:
于是多项式开方就在
的时间内被解决,牛顿迭代可以活用于各种场景。
例题 CF438E The Child and Binary Tree
题目链接
我们考虑dp,设
表示权值和为
时的方案数,显然我们可以枚举根节点的权值,再枚举左右子树的权值:
其中C是题目给的正整数集合。一样的套路,设
中出现过,
那么dp方程就可以写成:
因此我们就可以定义它们的生成函数分别为
最后面那个形式就可以牛顿迭代了:
于是在
的时间内也可以求解了。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1 << 18 | 5, mod = 998244353, G = 3, inv2 = (mod + 1) >> 1;
int modpow(ll a, int b){
ll res = 1;
for(; b; b >>= 1){
if(b & 1) res = res * a % mod;
a = a * a % mod;
}
return res;
}
void rader(int *a, int n){
for(int i = 1, j = n >> 1; i < n - 1; i++){
if(i < j) swap(a[i], a[j]);
int k = n >> 1;
for(; j >= k; k >>= 1) j -= k;
if(j < k) j += k;
}
}
void ntt(int *a, int n, int rev){
rader(a, n);
for(int h = 2; h <= n; h <<= 1){
int hh = h >> 1, wn = modpow(G, rev ? mod - 1 - (mod - 1) / h : (mod - 1) / h);
for(int i = 0; i < n; i += h){
ll w = 1;
for(int j = i; j < i + hh; j++){
int x = a[j], y = a[j + hh] * w % mod;
a[j] = (x + y) % mod, a[j + hh] = (x - y + mod) % mod;
w = w * wn % mod;
}
}
}
if(rev){
ll inv = modpow(n, mod - 2);
for(int i = 0; i < n; i++) a[i] = inv * a[i] % mod;
}
}
int tmp1[maxn], tmp2[maxn], tmp3[maxn], A[maxn], B[maxn];
void get_inv(int *a, int *b, int n){//used tmp1!!!
if(n == 1) return (void)(b[0] = modpow(a[0], mod - 2), b[1] = 0);
get_inv(a, b, n >> 1);
int t = n << 1;
for(int i = 0; i < n; i++) tmp1[i] = a[i];
for(int i = n; i < t; i++) tmp1[i] = b[i] = 0;
ntt(tmp1, t, 0), ntt(b, t, 0);
for(int i = 0; i < t; i++) b[i] = (2 - (ll)b[i] * tmp1[i] % mod + mod) * b[i] % mod;
ntt(b, t, 1);
for(int i = n; i < t; i++) b[i] = 0;
}
void get_ans(int *a, int *b, int n){//used tmp1,tmp2,tmp3!!!
if(n == 1) return (void)(b[0] = 1, b[1] = 0);
get_ans(a, b, n >> 1);
int t = n << 1;
for(int i = 0; i < n; i++) tmp1[i] = a[i], tmp2[i] = b[i];
for(int i = n; i < t; i++) tmp1[i] = tmp2[i] = 0;
ntt(tmp1, t, 0), ntt(tmp2, t, 0);
for(int i = 0; i < t; i++){
ll x = tmp1[i], y = tmp2[i];
tmp1[i] = (x * y % mod * y + mod - 1) % mod;
tmp2[i] = (2 * x * y - 1 + mod) % mod;
}
ntt(tmp1, t, 1), ntt(tmp2, t, 1);
for(int i = 0; i < n; i++) tmp3[i] = tmp1[i];
for(int i = n; i < t; i++) tmp3[i] = 0;
get_inv(tmp2, b, n);
ntt(tmp3, t, 0), ntt(b, t, 0);
for(int i = 0; i < t; i++) b[i] = (ll)b[i] * tmp3[i] % mod;
ntt(b, t, 1);
for(int i = n; i < t; i++) b[i] = 0;
}
const int maxr = 10000000;
char str[maxr], prt[maxr];
int rpos, ppos, mmx;
char readc(){
if(!rpos) mmx = fread(str, 1, maxr, stdin);
if(rpos == mmx) return 0;
char c = str[rpos++];
if(rpos == maxr) rpos = 0;
return c;
}
int read(){
int x; char c;
while((c = readc()) < '0' || c > '9');
x = c - '0';
while((c = readc()) >= '0' && c <= '9') x = x * 10 + c - '0';
return x;
}
int print(int x){
if(x){
static char sta[10];
int tp = 0;
for(; x; x /= 10) sta[tp++] = x % 10 + '0';
while(tp > 0) prt[ppos++] = sta[--tp];
} else prt[ppos++] = '0';
prt[ppos++] = '\n';
}
int main(){
int n = read(), m = read();
for(int i = 0; i < n; i++) ++A[read()];
int len = 1;
while(len <= m) len <<= 1;
get_ans(A, B, len);
for(int i = 1; i <= m; i++) print(B[i]);
fwrite(prt, 1, ppos, stdout);
return 0;
}