加速K进制不进位加法卷积的类FWT方法

并不是什么新方法,老早就有了。

背景

二进制的异或FWT实际上在做这样一件事:
构造一个行列式不为0的Len阶方阵T,使得 T A T B = T C TA\cdot TB=TC

将矩阵乘法拆开看,这个矩阵满足: x , i , j      T ( x , i ) T ( x , j ) = T ( x , i j ) \forall {x,i,j}~~~~T(x,i)T(x,j)=T(x,i\oplus j)

只要单个位的数满足上述性质,我们不妨将T(x,y)构造成其每一位的T的乘积,显然也满足性质。
于是只需要找出一个2*2的矩阵便可以了。
对于二进制异或,这个矩阵可以是 [ [ 1 , 1 ] , [ 1 , 1 ] ] [[1,1],[1,-1]] ,也可以是 [ [ 1 , 1 ] , [ 1 , 1 ] ] [[1,-1],[1,1]]

原理

K进制下对应的矩阵或许有很多,但是我们可以利用范德蒙德矩阵快速构造出一个
[ 1 1 1 . . . 1 1 w k 1 w k 2 . . . w k k 1 1 w k 2 w k 4 . . . w k 2 ( k 1 ) . . . . . . . . . . . . . . . 1 w k k 1 w k 2 ( k 1 ) . . . w k ( k 1 ) ( k 1 ) ] \begin{bmatrix} 1& 1 & 1& ... & 1\\ 1& w_k^1& w_k^2& ... & w_k^{k - 1}\\ 1& w_k^2 & w_k^4& ... & w_k^{2(k - 1)}\\ ...& ...& ...& ...& ...\\ 1& w_k^{k - 1}& w_k^{2(k - 1)} & ... & w_k^{(k - 1)(k - 1)} \end{bmatrix}

其中 w k w_k 是K次单位根。由于是范德蒙德矩阵,因此是有逆的。

1 k [ 1 1 1 . . . 1 1 w k 1 w k 2 . . . w k ( k 1 ) 1 w k 2 w k 4 . . . w k 2 ( k 1 ) . . . . . . . . . . . . . . . 1 w k ( k 1 ) w k 2 ( k 1 ) . . . w k ( k 1 ) ( k 1 ) ] \frac{1}{k} \begin{bmatrix} 1& 1 & 1& ... & 1\\ 1& w_k^{-1}& w_k^{-2}& ... & w_k^{-(k - 1)}\\ 1& w_k^{-2} & w_k^{-4}& ... & w_k^{-2(k - 1)}\\ ...& ...& ...& ...& ...\\ 1& w_k^{-(k - 1)}& w_k^{-2(k - 1)} & ... & w_k^{-(k - 1)(k - 1)} \end{bmatrix}
其实就是单位根取倒数,然后除去K.

朴素算法

有了上述矩阵及其逆矩阵之后,我们可以暴力计算 T A TA , T B TB 来得到 T C TC ,再得到C.
时间复杂度 O ( n 2 k ) O(n^2k) ,与暴力无异。

类FWT加速

我们将原数一位位的变化成贡献到的位置,并将系数逐步乘上。
具体地,做完低i位后,设x=a+b,b是低i位,那么a[x]的值就是所有原高位与a相同的位置对所有低位与b相同的位置的贡献之和。转移就是变化原数一位,乘上相应系数,加到对应位置下一次迭代。

也可以理解为分治,推式子。
时间复杂度 O ( n K l o g K n ) O(nKlog_Kn)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int mo = 1e9 + 9, K = 11, G = 13, N = 1e5;
int n, k;
ll T[K][K], NT[K][K];
ll a[N], b[N], w1, c[N], ans[N];
ll ksm(ll x, ll y) {
	ll ret = 1; for (; y; y>>=1) {
		if (y & 1) ret = ret * x % mo;
		x = x * x % mo;
	}
	return ret;
}

int wi;
void fast_trans(ll T[K][K], ll *a) {
	static ll b[N];
	for(int m = 1; m < n; m *= k) {
		memset(b, 0, sizeof b);
		for(int i = 0; i < n; i += m * k) {
			for(int j = 0; j < m; j++) {
				for(int u = 0, wu = 0; u < k; u++, wu += m) {
					for(int v = 0, wv = 0; v < k; v++, wv += m) {
						b[i + j + wu] = (b[i + j + wu] + a[i + j + wv] * T[u][v]) % mo;
					}
				}
			}
		}
		memcpy(a, b, sizeof b);
	}
}

void trans(ll T[K][K], ll *a) {
	fast_trans(T, a);
	return;

	static ll res[N];
	memset(res, 0, sizeof res);
	for(int i = 0; i < n; i++) {
		for(int j = 0; j < n; j++) {
			ll z = 1, ti = i, tj = j;
			for(int w = 0; w < wi; w++) {
				z = z * T[ti % k][tj % k] % mo;
				ti /= k, tj /= k;
			}
			res[i] = (res[i] + z * a[j]) % mo;
		}
	} 
	memcpy(a, res, sizeof res);
}

int main() {
	freopen("a.in","r",stdin);
	freopen("b.out","w",stdout);
	cin>>n>>k;
	int y = n;
	while (y) {
		wi++, y/=k; 
	}
	wi--;
	for(int i = 0; i < n; i++) scanf("%d", &a[i]);
	for(int i = 0; i < n; i++) scanf("%d", &b[i]);
	w1 = ksm(G, (mo - 1) / k);
	for(int i = 0; i < k; i++) {
		ll z = 1, w = ksm(w1, i);
		for(int j = 0; j < k; j++, z = z * w % mo) {
			T[j][i] = z;
			NT[j][i] = ksm(T[j][i], mo - 2);
		}
	}
	trans(T, a);
	trans(T, b);
	for(int i = 0; i < n; i++) c[i] = a[i] * b[i] % mo;
	trans(NT, c);
	ll ny = ksm(k, (ll) wi * (mo - 2) % (mo - 1));
	for(int i = 0; i < n; i++) {
		c[i] = c[i] * ny % mo, c[i] = (c[i] + mo) % mo;
		printf("%d ",c[i]);
	}
	cout<<endl;
}

发布了266 篇原创文章 · 获赞 93 · 访问量 8万+

猜你喜欢

转载自blog.csdn.net/jokerwyt/article/details/100050235