基本
一个 和 的矩阵相乘的时间复杂度为 ,得到的结果为 的矩阵。
由于矩阵乘法有结合律,因此对于一个 阶方阵 ,可以利用快速幂在 的时间内计算 。
Strassen 矩阵乘法
应用了分治的想法,把原来的 和 的矩阵相乘,优化成了 7 个 的矩阵相乘。
由 Master Theorem, 的解为 。
算法细节参见其他相关材料。
循环矩阵乘法
循环矩阵在计算一类和马尔可夫链有关的概率问题时会被用到。
阶循环矩阵是满足下面条件的
阶方阵:每一行都由上一行循环右移一位得到。如
阶循环矩阵:
可以证明,对于两个 阶循环矩阵 ,满足 都是循环矩阵。
对于循环矩阵而言,由于其每一行都是相同的,因此只需要一行就可以保存整个矩阵的信息。这使得循环矩阵的乘法的时间复杂度可以从一般的 降低到 。
我们使用第一行代表矩阵。对于两个
阶循环矩阵
,它们的第一行为
(base-0),那么
的第一行
满足:
以行的形式保存时,用行向量来乘矩阵比较方便。设行向量为 ,要乘的矩阵是 ,那么结果向量 恰好等于 的第一行,其中 是以 为第一行的循环矩阵 同 相乘的结果。
列向量的话较为繁琐,要经过两次转化。
例:牛牛的粉丝
题意:一个环上有 个节点,第 个节点上有 个人。每一轮每一个人独立行动,有 概率不动, 概率顺时针移动到下一个点, 概率逆时针移动到下一个点。问 轮后每个点上人数的期望。(原题:牛客练习赛 68 D)
可以看出,本题所表示的随机过程是一个马尔可夫链,且概率转移矩阵是一个循环矩阵。因此可以用快速幂计算最终的期望。时间复杂度为 。
由于答案就是每一个点各自答案的线性组合,因此快速幂的“底”直接使用了 ,而不是单位矩阵 。
#include <bits/stdc++.h>
#define MOD 998244353
using namespace std;
typedef long long ll;
inline int modadd(int x, int y){
return (x + y >= MOD ? x + y - MOD: x + y);
}
int poww(int a, int b){
int res = 1;
while (b > 0){
if (b & 1) res = 1ll * res * a % MOD;
a = 1ll * a * a % MOD, b >>= 1;
}
return res;
}
int n, a, b, c, x[505];
int mat[505], tmp[505];
ll k;
void mul(int *u, int *v){
memset(tmp, 0, sizeof(tmp));
for (int i = 0; i < n; ++i)
for (int j = 0; j < n; ++j)
tmp[(i + j) % n] = modadd(tmp[(i + j) % n], 1ll * u[i] * v[j] % MOD);
for (int i = 0; i < n; ++i)
u[i] = tmp[i];
}
void init(){
scanf("%d%lld%d%d%d", &n, &k, &a, &b, &c);
for (int i = 0; i < n; ++i)
scanf("%d", &x[i]);
int s = a + b + c;
s = poww(s, MOD - 2);
mat[0] = 1ll * c * s % MOD;
mat[n - 1] = 1ll * b * s % MOD;
mat[1] = 1ll * a * s % MOD;
}
void solve(){
while (k > 0){
if (k & 1ll) mul(x, mat);
mul(mat, mat), k >>= 1;
}
for (int i = 0; i < n; ++i)
printf("%d%c", x[i], (i == n - 1 ? '\n': ' '));
}
int main(){
init();
solve();
return 0;
}
稀疏矩阵乘法
对于稀疏矩阵,可以在更低的时间复杂度内完成乘法运算。
我们用三元组来记录矩阵内的非零元素,即 表示在矩阵第 行第 列的值为 。那么对于两个用二元组列表表示的稀疏矩阵,只要计算它们之间每一对三元组 的乘积,并将乘积累加到结果的对应位置即可。
下的矩阵乘法
可以用 bitset 进行加速。
对于两个位向量,它们的内积就是与运算之后向量中 1 的个数。
下的矩阵乘法
也可以用 bitset 进行加速。
可以用两个位向量表示一个 向量1,利用这种向量内积的快速计算来加速矩阵乘法。具体运算方法可以参考下面的代码。
例:HDU 4920
本题是一个 下的矩阵乘法模板题,实现如下。
#include <bits/stdc++.h>
using namespace std;
int n, a[805][805];
bitset<805> r1[805], r2[805];
bitset<805> c1[805], c2[805];
bitset<805> tmp[4];
void init(){
for (int i = 0; i < n; ++i){
r1[i].reset();
r2[i].reset();
for (int j = 0, t; j < n; ++j){
scanf("%d", &t);
t %= 3;
if (t == 1) r1[i][j] = 1;
else if (t == 2) r2[i][j] = 1;
}
}
for (int i = 0; i < n; ++i)
for (int j = 0, t; j < n; ++j)
scanf("%d", &t), a[i][j] = t % 3;
for (int j = 0; j < n; ++j){
c1[j].reset();
c2[j].reset();
for (int i = 0; i < n; ++i){
if (a[i][j] == 1) c1[j][i] = 1;
else if (a[i][j] == 2) c2[j][i] = 1;
}
}
}
void solve(){
for (int i = 0; i < n; ++i)
for (int j = 0; j < n; ++j){
tmp[0] = r1[i] & c2[j];
tmp[1] = r2[i] & c1[j];
tmp[2] = r1[i] & c1[j];
tmp[3] = r2[i] & c2[j];
tmp[0] |= tmp[1];
tmp[2] |= tmp[3];
int res = (2 * tmp[0].count() + tmp[2].count()) % 3;
printf("%d%c", res, (j == n - 1 ? '\n': ' '));
}
}
int main(){
while (scanf("%d", &n) == 1){
init();
solve();
}
return 0;
}