题目链接：LOJ #166（拉格朗日插值 2）
题目描述：输入多项式的次数 $n$、一个整数 $m$ 以及 $f(0),f(1),f(2),\ldots,f(n)$，输出 $f(m),f(m+1),f(m+2),\ldots,f(m+n)$（答案对 $998244353$ 取模）。
数据范围：$1\leq n\leq 10^5$，$n<m\leq 10^8$。
一道披着拉格朗日插值模板外衣的数论题……
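$$f(m+i)=\sum_{j=0}^{n}f(j)\dfrac{\prod_{0\le k\le n,\,k\ne j}(m+i-k)}{\prod_{0\le k\le n,\,k\ne j}(j-k)}$$
由于 $\prod_{k\ne j}(j-k)=(-1)^{n-j}\,j!\,(n-j)!$，且 $\prod_{k\ne j}(m+i-k)=\dfrac{(m+i)!}{(m+i-n-1)!}\cdot\dfrac{1}{m+i-j}$，得
$$f(m+i)=\dfrac{(m+i)!}{(m-n+i-1)!}\sum_{j=0}^{n}\dfrac{f(j)\,(-1)^{n-j}}{j!\,(n-j)!}\cdot\dfrac{1}{m+i-j}$$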
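预处理阶乘及阶乘逆元，以及 $mfac[i]=\prod_{j=0}^{i-1}(m-n+j)$、$mfac[i]$ 的逆元和 $m-n+i$ 的逆元（都可以线性求出，只需两次快速幂）。
令 $a_j=\dfrac{f(j)\,(-1)^{n-j}}{j!\,(n-j)!}\ (0\le j\le n)$，$b_k=\dfrac{1}{m-n+k}\ (0\le k\le 2n)$，则上式中的和式恰为卷积 $\sum_{j}a_j\,b_{n+i-j}=(a*b)_{n+i}$，前面的系数 $\dfrac{(m+i)!}{(m-n+i-1)!}=\dfrac{mfac[n+i+1]}{mfac[i]}$。
使用 NTT 计算该卷积，时间复杂度为 $O(n\log n)$。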
// Lagrange interpolation at shifted points: read n, m and f(0..n) of a
// degree-n polynomial, print f(m), f(m+1), ..., f(m+n) modulo 998244353.
//
//   f(m+t) = P(t) * sum_{j=0}^{n} a_j / (m+t-j),   P(t) = prod_{k=0}^{n} (m+t-k),
//   a_j    = f(j) * (-1)^(n-j) / (j! * (n-j)!).
// With b_k = 1 / (m-n+k) for 0 <= k <= 2n, the sum is the convolution
// coefficient (a*b)[n+t], so a single NTT product yields all n+1 answers in
// O(n log n). Assumes n < m and m + n < mod, so no denominator vanishes.
#include <cstdio>
#include <utility>

using LL = long long;

const int N = 1 << 19;          // capacity of every table / transform buffer
const int mod = 998244353;      // 119 * 2^23 + 1, allows power-of-two NTTs
const int G = 3;                // primitive root modulo mod
const int Gi = 332748118;       // G^{-1} modulo mod

// Returns a^b modulo `mod` by binary exponentiation (b >= 0).
inline int kasumi(int a, int b) {
    int res = 1;
    for (; b; b >>= 1, a = (LL)a * a % mod)
        if (b & 1) res = (LL)res * a % mod;
    return res;
}

int rev[N];

// Returns the smallest power of two strictly greater than len and fills
// rev[0..limit) with the bit-reversal permutation for that length.
inline int calrev(int len) {
    int limit = 1;
    while (limit <= len) limit <<= 1;
    // limit >> 1 is the top bit of an index (0 when limit == 1); using it
    // avoids a shift by a negative amount when len == 0.
    for (int i = 0; i < limit; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (limit >> 1) : 0);
    return limit;
}

// In-place iterative NTT of A[0..limit): type == 1 is the forward transform,
// type == -1 the inverse (including the 1/limit scaling). rev[] must have been
// filled by calrev() for this same limit; entries of A must lie in [0, mod).
inline void NTT(int *A, int limit, int type) {
    for (int i = 0; i < limit; ++i)
        if (i < rev[i]) std::swap(A[i], A[rev[i]]);
    for (int mid = 1; mid < limit; mid <<= 1) {
        // (2*mid)-th root of unity, or its inverse for the inverse transform.
        const int Wn = kasumi(type == 1 ? G : Gi, (mod - 1) / (mid << 1));
        for (int j = 0; j < limit; j += mid << 1) {
            int w = 1;
            for (int k = 0; k < mid; ++k, w = (LL)w * Wn % mod) {
                const int x = A[j + k];
                const int y = (LL)w * A[j + k + mid] % mod;
                A[j + k] = (x + y) % mod;
                A[j + k + mid] = (x - y + mod) % mod;
            }
        }
    }
    if (type == -1) {
        const int inv_limit = kasumi(limit, mod - 2);
        for (int i = 0; i < limit; ++i) A[i] = (LL)A[i] * inv_limit % mod;
    }
}

int fac[N], invfac[N], mfac[N], minvfac[N], minv[N];

// Precomputes, for 0 <= i <= 2n+1:
//   fac[i]  = i!,                          invfac[i]  = 1 / i!,
//   mfac[i] = prod_{j=0}^{i-1} (m-n+j),    minvfac[i] = 1 / mfac[i],
// and, for 1 <= i <= 2n+1, minv[i] = 1 / (m-n+i-1).
// Only two modular exponentiations are needed; everything else is linear.
inline void init(int n, int m) {
    const int top = 2 * n + 1;
    fac[0] = mfac[0] = 1;
    for (int i = 1; i <= top; ++i) {
        fac[i] = (LL)fac[i - 1] * i % mod;
        mfac[i] = (LL)mfac[i - 1] * (m - n + i - 1) % mod;
    }
    invfac[top] = kasumi(fac[top], mod - 2);
    minvfac[top] = kasumi(mfac[top], mod - 2);
    for (int i = top; i >= 1; --i) {
        invfac[i - 1] = (LL)invfac[i] * i % mod;
        minvfac[i - 1] = (LL)minvfac[i] * (m - n + i - 1) % mod;
        // 1/(m-n+i-1) = mfac[i-1] / mfac[i]
        minv[i] = (LL)minvfac[i] * mfac[i - 1] % mod;
    }
}

int n, m, f[N], A[N], B[N];

int main() {
    if (std::scanf("%d%d", &n, &m) != 2) return 1;
    // Tables are indexed up to 2n+1 and the transform length stays <= N.
    if (n < 0 || 2 * n + 2 > N) return 1;
    for (int i = 0; i <= n; ++i) {
        if (std::scanf("%d", f + i) != 1) return 1;
        f[i] %= mod;                    // tolerate values outside [0, mod)
        if (f[i] < 0) f[i] += mod;
    }
    init(n, m);

    // Only coefficients n..2n of a*b are read. With a cyclic length > 2n,
    // every wrapped-around term (true index > limit) lands below index n,
    // so those coefficients are exact.
    const int limit = calrev(2 * n);

    for (int i = 0; i <= n; ++i) {
        A[i] = (LL)f[i] * invfac[i] % mod * invfac[n - i] % mod;
        if (((n - i) & 1) && A[i]) A[i] = mod - A[i];   // (-1)^(n-i), stays < mod
    }
    for (int i = 0; i <= 2 * n; ++i) B[i] = minv[i + 1];  // 1 / (m-n+i)

    NTT(A, limit, 1);
    NTT(B, limit, 1);
    for (int i = 0; i < limit; ++i) A[i] = (LL)A[i] * B[i] % mod;
    NTT(A, limit, -1);

    for (int i = n; i <= 2 * n; ++i) {
        // prod_{k=0}^{n} (m + (i-n) - k) == mfac[i+1] / mfac[i-n]
        const int ans = (LL)mfac[i + 1] * minvfac[i - n] % mod * A[i] % mod;
        std::printf("%d ", ans);        // argument is int, matching %d
    }
    std::putchar('\n');
    return 0;
}