前些天学了一堆多项式的算法,今天总结一下。
乘法
朴素的NTT就不说了,主要说一下三模数NTT。
三模数NTT,顾名思义,就是选取三个适合做NTT的模数,然后把它们用CRT合并起来得到的答案再去对我们要求的模数取模。
为了方便,这三个模数分别是998244353, 1004535809和469762049。它们的原根都是3,且它们减去1的值都有超过20个2的因子。
如果你觉得能用int128的话就请忽略下面所有的表述
但是,在用CRT合并的时候,我们遇到一个麻烦:这三个模数的乘积太大了。我们令三个模数分别为p1, p2, p3,对三个模数取模得到的答案分别为a1, a2, a3,不做取模的原本的答案为Ans,则有:
先用CRT合并前两个模数的答案,得到
设
那么
即
此时k, M, A我们都已经得到,就直接对我们要求的模数取模就好了,即
Code
inline void exntt(int *a, int *b)
{
int len = 1, bit = 0;
while (len < n + m) len <<= 1, bit++;
for (int i = 0; i < len; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << bit - 1;
for (int i = 0; i < 3; i++) {
for (int j = 0; j < len; j++) c[j] = a[j], d[j] = b[j];
ntt(c, len, 1, p[i]);
ntt(d, len, 1, p[i]);
for (int j = 0; j < len; j++) ans[i][j] = 1ll * c[j] * d[j] % p[i];
ntt(ans[i], len, -1, p[i]);
}
for (int i = 0; i < n + m - 1; i++) {
ll A = (mul(1ll * ans[0][i] * p[1] % M, ksm(p[1] % p[0], p[0] - 2, p[0]), M) +
mul(1ll * ans[1][i] * p[0] % M, ksm(p[0] % p[1], p[1] - 2, p[1]), M)) % M;
ll k = ((ans[2][i] - A) % p[2] + p[2]) % p[2] * ksm(M % p[2], p[2] - 2, p[2]) % p[2];
printf("%lld ", ((k % mod) * (M % mod) % mod + A % mod) % mod);
}
}
逆元
多项式求逆是个好东西,基本上所有除了乘法外的东西都要用到求逆。
我们要求
假如我们已经求得
则有
把 移到左边然后平方得到
不好处理,所以同乘以一个A,即
我们就得到了
于是倍增求即可。
Code
inline void poly_inv(int *a, int *b, int n)
{
if (n == 1) {
b[0] = ksm(a[0], mod - 2);
return;
}
poly_inv(a, b, n + 1 >> 1);
int len = 1;
while (len < n << 1) len <<= 1;
for (int i = 0; i < n; i++) tmp[i] = a[i];
for (int i = n; i < len; i++) tmp[i] = 0;
ntt(tmp, len, 1);
ntt(b, len, 1);
for (int i = 0; i < len; i++)
b[i] = 1ll * b[i] * (2 - 1ll * b[i] * tmp[i] % mod + mod) % mod;
ntt(b, len, -1);
for (int i = n; i < len; i++) b[i] = 0;
}
开方
多项式开方的思路和求逆一样,都是倍增的利用上一次的答案求下一次的。
简单推一下式子吧:
这里出现了两组解,我们只保留
还是套路的平方
其实 就是
那么最后 就求出来了
Code
void poly_sqrt(int *a, int *b, int n)
{
if (n == 1) {
b[0] = 1; //一般情况下a[0] = 1
return;
}
poly_sqrt(a, b, n + 1 >> 1);
int len = 1;
while (len < n << 1) len <<= 1;
memset(c, 0, sizeof c);
poly_inv(b, c, n);
for (int i = 0; i < n; i++) tmp[i] = a[i];
for (int i = n; i < len; i++) tmp[i] = 0;
ntt(tmp, len, 1);
ntt(b, len, 1);
ntt(c, len, 1);
for (int i = 0; i < len; i++) b[i] = 1ll * (tmp[i] + 1ll * b[i] * b[i] % mod) * c[i] % mod * inv2 % mod;
ntt(b, len, -1);
for (int i = n; i < len; i++) b[i] = 0;
}
除法和取模
已经有了求逆,那么多项式除法和取模又是什么?和逆元有什么区别吗?
举个例子:
由小学数学得到
然而, 并不等于 ,所以 乘以 的逆元也并不等于 ,这就是除法和求逆的区别。
那除法怎么求呢?
把 都看成多项式,我们发现是取模的结果 产生了上述影响。
如果我们把被除数多项式( )、除数多项式( )和商多项式( )都翻转一下,就可以得到
我们发现,此时取模的结果 也被翻转成了100,而100的位数较高,不会产生导致除法和求逆结果不同的影响。
所以,我们只要用 去乘上 的逆元就可以得到 ,然后把 翻转回来就得到了 。
什么?你说取模?
你都求出了 ,你还不会用 来求 吗?
Code
inline void poly_div(int *a, int *b, int *d, int *r, int n, int m)
{
for (int i = 0; i < n; i++) a1[i] = a[n - i - 1];
for (int i = 0; i < m; i++) b1[i] = b[m - i - 1];
int l = n - m + 1;
poly_inv(b1, d, l);
int len = 1;
while (len < n + l) len <<= 1;
ntt(a1, len, 1);
ntt(d, len, 1);
for (int i = 0; i < len; i++) d[i] = 1ll * a1[i] * d[i] % mod;
ntt(d, len, -1);
for (int i = l; i < len; i++) d[i] = 0;
for (int i = 0; i < l >> 1; i++) swap(d[i], d[l - i - 1]);
for (int i = 0; i < m; i++) r[i] = b[i];
for (int i = 0; i < l; i++) b1[i] = d[i];
for (int i = l; i < m; i++) b1[i] = 0;
ntt(r, len, 1);
ntt(b1, len, 1);
for (int i = 0; i < len; i++) r[i] = 1ll * r[i] * b1[i] % mod;
ntt(r, len, -1);
for (int i = 0; i < n; i++) r[i] = (a[i] - r[i] + mod) % mod;
}
求ln
exp
这两个不想写了……留坑待填