NTT算法中素数的选取

     设这个素数为N,N-1必须要可以被小于n的2的次方整除,因为经常需要除以2的次方长度,必须要可以被整除,所以N可以表示为m*2^k+1的长度,后面我试了一下其他2^k>=ntt长度的素数,发现结果并不正确,后来才发现N还必须大于最大卷积结果N*max{a[i]}*{b[i]},否则mod就会产生错误.后面我又看了一下本原元的存在性证明,那个大佬用一句话就证明了,说是什么群论里的基础知识.没看懂顿时觉得自己对数论很小白.感觉编码课上好像学过什么本原元的,但是没有要证明什么的.好像只有一些结论没有过程.

    不过我觉得还是FFT适用的场景更多,虽然精度上有一点问题.

    贴一个两三个月前写的NTT题,路过需要的可以参考一下(hihocoder 1388,内含fft板子)

#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cmath>
#pragma warning(disable:4996)
using namespace std;
const int g = 3;
const long long int mymod = 31525197391593473LL;
int revg;
const double pi = acos(-1.0);
const int bign = 100033;
struct complex
{
	double r, v;
	complex()
	{ 
		r = 0.0;
		v = 0.0;
	}


	complex(double r1, double v1)
	{
		r = r1;
		v = v1;
	}
	complex operator+(const complex& ano)
	{
		return complex(r + ano.r,  v + ano.v);
	}
	complex operator-(const complex& ano)
	{
		return complex(r - ano.r, v - ano.v);
	}
	complex operator*(const complex& ano)
	{
		return complex(r * ano.r - v * ano.v, r * ano.v + v * ano.r);
	}
	complex operator/(double r1)
	{
		return complex(r / r1, v / r1);
	}
}a[4*bign],b[4*bign];


long long int mymul(long long int ta, long long int tb)
{
	long long int res = 0;
	for (;tb; tb >>= 1)
	{
		if (tb & 1)
			res = (res + ta) % mymod;
		ta = (ta << 1) % mymod;
	}
	return res;
}


long long mul(long long x, long long y) {
	return (x * y - (long long)(x / (long double)mymod * y + 1e-3) * mymod + mymod) % mymod;
}
long long int quickpow(long long ta, long long int tb)
{
	long long res = 1;
	for (; tb; tb >>= 1)
	{
		if (tb & 1)
			res = mul(res, ta);
		ta = mul(ta, ta);
	}
	return res;
}
int rev(int num, int len)
{
	int ans = 0;
	for (int i = 1,j = 0; j<len; i<<=1,j++)
	{
		if (i & num)
		{
			ans += (1<<(len - 1 - j));
		}
	}
	return ans;
}


void NTT(long long int c[], int len, int on)
{
	int tlen = (1 << len);


	for (int i = 0; i < tlen; i++)
	{
		int j = rev(i, len);
		if (j > i)
			swap(c[i], c[j]);
	}


	//long long int tg = g;
	//if (on == -1)
	//	tg = quickpow(g, mymod - 2);
	for (int i = 2; i <= tlen; i <<= 1)
	{
		long long int wn = quickpow(g,(mymod-1)/i);
		if (on == -1)
			wn = quickpow(wn, mymod-2);
		for (int j = 0; j < tlen; j += i)
		{
			long long int w = 1;
			for (int k = j; k < j + i / 2; k++)
			{
				long long int u = c[k];
				long long int v = mul(w , c[k + i / 2]);
				c[k] = (u + v) % mymod;
				c[k + i / 2] = (u - v + mymod) % mymod;
			//	long long int oldw = w;
				w = mul(w, wn);
			}
		}
	}
	if (on == -1)
	{
		//for (int i = 0; i < tlen / 2; i++)
		//{
		//	swap(c[i], c[tlen - i - 1]);
		//}
		for (int i = 0; i < tlen; i++)
			c[i] = mul(c[i], quickpow(tlen, mymod - 2));
	}
}


void FFT(complex c[], int len, int on)
{	
	int tlen = (1 << len);
	
	for (int i = 0; i < tlen; i++)
	{
		int j = rev(i,len);
		if (j > i)
			swap(c[i], c[j]);
	}


	for (int i = 2; i <= tlen; i <<= 1)
	{
		complex wn(cos(2*pi/i),on * sin(2*pi/i));
		for (int j = 0; j < tlen; j += i)
		{
			complex w(1,0);
			for (int k = j; k < j + i/2; k++)
			{
				complex u =  c[k];
				complex v = (w * c[k + i / 2]);
				c[k] = u + v;
				c[k + i / 2] = u - v;
				w = w * wn;
			}
		}
	}
	if (on == -1)
	{
		for (int i = 0; i < tlen; i++)
		{
			c[i].r /= tlen;
			c[i].v /= tlen;
		}
	}
}
long long int a1[4 * bign], b1[4 * bign];
int main()
{
	int T;
	//printf("%lld\n", quickpow(g, mymod - 1));
	scanf("%d", &T);
	
	while (T--)
	{
		int n;
		long long int ans = 0;
		scanf("%d", &n);
		for (int i = 0; i < n; i++)
		{
			int x;
			scanf("%d", &x);
			ans += 1ll * x * x;
			a1[i] = x;
			a1[i + n] = a1[i];
		}
		for (int i = n - 1; i >= 0; i--)
		{
			int x;
			scanf("%d", &x);
			ans += 1ll * x * x;
			b1[i] = x;
		}
		int len = 0;
		int j;
		for (j = 1; j < 2 * n; j <<= 1,len++);
		for (int i = n; i < j; i++)
			b1[i] = 0;
		for (int i = 2 * n; i < j; i++)
			a1[i] = 0;
		NTT(a1, len, 1);
		NTT(b1, len, 1);
		for (int i = 0; i < j; i++)
			a1[i] = mul(a1[i],b1[i])%mymod;
		NTT(a1, len, -1);
		long long int tmax = 0;
		for (int i = n - 1; i < 2 * n - 1; i++)
		{
			long long int tmp = a1[i];
			if (tmp > tmax)
				tmax = tmp;
		}
		printf("%lld\n", ans - 2* tmax);
	}
}


/*
5
5
1000000 1000001 1000002 1000003 1000004
1000003 1000004 1000000 1000001 1000002
*/

猜你喜欢

转载自blog.csdn.net/dx888888/article/details/80281545
NTT