题目:https://www.luogu.org/problemnew/show/P4238
方法:https://www.cnblogs.com/TimelyRain/p/10010233.html
https://www.cnblogs.com/xiefengze1/p/9107752.html
感觉下面那个博客的转化方法更简单?
设 \( B^{'}(x) \) 为 \( A(x) \) 在模 \( x^{\left\lceil n/2 \right\rceil} \) 意义下的逆元,\( B(x) \) 为 \( A(x) \) 在模 \( x^{n} \) 意义下的逆元。
\( A(x)*B^{'}(x) \equiv 1 (mod x^{\left\lceil n/2 \right\rceil}) \)
\( A(x)*B^{'}(x) - 1 \equiv 0 (mod x^{\left\lceil n/2 \right\rceil}) \)
\( A^{2}(x)*B^{'2}(x) - 2*A(x)*B^{'}(x) + 1 \equiv 0 (mod x^{n}) \)
\( 2*A(x)*B^{'}(x) - A^{2}(x)*B^{'2}(x) - 1 \equiv 0 (mod x^{n}) \)
又有\( A(x)*B(x) \equiv 1 (mod x^{n}) \)
\( A(x)*B(x) - 1 \equiv 0 (mod x^{n}) \)
则 \( A(x)*B(x) \equiv 2*A(x)*B^{'}(x) - A^{2}(x)*B^{'2}(x) (mod x^{n}) \)
\( B(x) \equiv 2*B^{'}(x) - A(x)*B^{'2}(x) (mod x^{n}) \)
那个地方之所以可以两边平方,是因为卷积的时候新出来的 n/2 项的每一项的求和式子里相乘的两项总有至少一项的次数是 <= \( \left\lceil n/2 \right\rceil \)的;这样的项因为模 \( x^{\left\lceil n/2 \right\rceil} \)是0,所以系数是0,所以平方后整个式子还同余于0。
也就是说不是 \( \left\lceil n/2 \right\rceil \),其他的也行,只要能满足平方后还是0。
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; const int N=1e5+5,M=N<<2,mod=998244353; int a[M],b[M],tp[M],len,r[M]; int rdn() { int ret=0;bool fx=1;char ch=getchar(); while(ch>'9'||ch<'0'){if(ch=='-')fx=0;ch=getchar();} while(ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar(); return fx?ret:-ret; } void upd(int &x){x>=mod?x-=mod:0;} int pw(int x,int k) {int ret=1;while(k){if(k&1)ret=(ll)ret*x%mod;x=(ll)x*x%mod;k>>=1;}return ret;} void ntt(int *a,bool fx) { for(int i=0;i<len;i++) if(i<r[i])swap(a[i],a[r[i]]); for(int R=2;R<=len;R<<=1) { int Wn=pw(3,(mod-1)/R); fx?Wn=pw(Wn,mod-2):0;/// for(int i=0,m=R>>1;i<len;i+=R)//+=R... for(int j=0,w=1;j<m;j++,w=(ll)w*Wn%mod) { int x=a[i+j],y=(ll)w*a[i+m+j]%mod; a[i+j]=x+y; a[i+m+j]=x+mod-y; upd(a[i+j]); upd(a[i+m+j]); } } if(!fx)return; int inv=pw(len,mod-2);/// for(int i=0;i<len;i++)a[i]=(ll)a[i]*inv%mod; } void solve(int n) { if(n==1){b[0]=pw(a[0],mod-2);return;} solve(n+1>>1); for(len=1;len<=n<<1;len<<=1); //n<<1 not n+1<<1//has no influence for(int i=0;i<len;i++)r[i]=(r[i>>1]>>1)+((i&1)?len>>1:0); for(int i=0;i<n;i++)tp[i]=a[i]; /////not a[i] for(int i=n;i<len;i++)tp[i]=0; ntt(tp,0); ntt(b,0); for(int i=0;i<len;i++) b[i]=((b[i]<<1)-(ll)tp[i]*b[i]%mod*b[i])%mod+mod,upd(b[i]); ntt(b,1); for(int i=n;i<len;i++)b[i]=0; } int main() { int n;n=rdn();for(int i=0;i<n;i++)a[i]=rdn(); solve(n); for(int i=0;i<n;i++)printf("%d ",b[i]);puts(""); return 0; }