- 算法说明
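Strassen 矩阵乘法将每个矩阵划分为 4 个大小相同的子块,分别按照下面的公式计算;公式中再遇到矩阵乘法时,递归调用本算法。公式如下:

$$
\begin{aligned}
M_1 &= (A_{11} + A_{22})(B_{11} + B_{22}) \\
M_2 &= (A_{21} + A_{22})\,B_{11} \\
M_3 &= A_{11}\,(B_{12} - B_{22}) \\
M_4 &= A_{22}\,(B_{21} - B_{11}) \\
M_5 &= (A_{11} + A_{12})\,B_{22} \\
M_6 &= (A_{21} - A_{11})(B_{11} + B_{12}) \\
M_7 &= (A_{12} - A_{22})(B_{21} + B_{22})
\end{aligned}
$$

$$
\begin{aligned}
C_{11} &= M_1 + M_4 - M_5 + M_7 \\
C_{12} &= M_3 + M_5 \\
C_{21} &= M_2 + M_4 \\
C_{22} &= M_1 - M_2 + M_3 + M_6
\end{aligned}
$$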
- 源代码
#include <cstdio>
// Side length of the square matrices that main() allocates, fills, prints and multiplies.
const int n = 2;
// Allocates an uninitialised size x size matrix as a table of row pointers.
static int **strassenAlloc(int size) {
    int **mat = new int *[size];
    for (int i = 0; i < size; i++) {
        mat[i] = new int[size];
    }
    return mat;
}

// Releases a matrix obtained from strassenAlloc: every row first, then the row table.
static void strassenFree(int **mat, int size) {
    for (int i = 0; i < size; i++) {
        delete[] mat[i];
    }
    delete[] mat;
}

// Returns a freshly allocated copy of the size x size block of src whose
// top-left element is src[row0][col0].
static int **strassenQuadrant(int **src, int row0, int col0, int size) {
    int **block = strassenAlloc(size);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            block[i][j] = src[row0 + i][col0 + j];
        }
    }
    return block;
}

// out = x + y, element-wise, for size x size matrices.
static void strassenAdd(int **x, int **y, int **out, int size) {
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            out[i][j] = x[i][j] + y[i][j];
        }
    }
}

// out = x - y, element-wise, for size x size matrices.
static void strassenSub(int **x, int **y, int **out, int size) {
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            out[i][j] = x[i][j] - y[i][j];
        }
    }
}

// Classical triple-loop product c = a * b, used when the size cannot be halved evenly.
static void strassenClassical(int **a, int **b, int **c, int size) {
    // Accumulate into scratch storage so that c may be the same matrix as a or b.
    int **out = strassenAlloc(size);
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            int sum = 0;
            for (int k = 0; k < size; k++) {
                sum += a[i][k] * b[k][j];
            }
            out[i][j] = sum;
        }
    }
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            c[i][j] = out[i][j];
        }
    }
    strassenFree(out, size);
}

// Computes the matrix product c = a * b for m x m integer matrices using
// Strassen's seven-multiplication recursion.
//
// a, b : input matrices, each given as m row pointers to m ints (not modified).
// c    : output matrix with the same layout; every element is overwritten.
//        c may be the same matrix as a or b, because the inputs are copied
//        into quadrants before c is written.
// m    : side length. m <= 0 is a no-op. Whenever the current size is odd
//        (and greater than 1) that level is multiplied classically, so any
//        positive size yields the correct product.
void strassen(int **a, int **b, int **c, int m) {
    if (m <= 0) {
        return;
    }
    if (m == 1) {
        c[0][0] = a[0][0] * b[0][0];
        return;
    }
    if (m % 2 != 0) {
        strassenClassical(a, b, c, m);
        return;
    }

    // Halve the size of THIS call (the parameter), not a global dimension.
    const int mid = m / 2;

    // Quadrants: X11 top-left, X12 top-right, X21 bottom-left, X22 bottom-right.
    int **a11 = strassenQuadrant(a, 0, 0, mid);
    int **a12 = strassenQuadrant(a, 0, mid, mid);
    int **a21 = strassenQuadrant(a, mid, 0, mid);
    int **a22 = strassenQuadrant(a, mid, mid, mid);
    int **b11 = strassenQuadrant(b, 0, 0, mid);
    int **b12 = strassenQuadrant(b, 0, mid, mid);
    int **b21 = strassenQuadrant(b, mid, 0, mid);
    int **b22 = strassenQuadrant(b, mid, mid, mid);

    // Reusable operand buffers for the sums/differences fed into each product.
    int **lhs = strassenAlloc(mid);
    int **rhs = strassenAlloc(mid);

    int **p1 = strassenAlloc(mid);
    int **p2 = strassenAlloc(mid);
    int **p3 = strassenAlloc(mid);
    int **p4 = strassenAlloc(mid);
    int **p5 = strassenAlloc(mid);
    int **p6 = strassenAlloc(mid);
    int **p7 = strassenAlloc(mid);

    // P1 = (A11 + A22)(B11 + B22)
    strassenAdd(a11, a22, lhs, mid);
    strassenAdd(b11, b22, rhs, mid);
    strassen(lhs, rhs, p1, mid);

    // P2 = (A21 + A22) B11
    strassenAdd(a21, a22, lhs, mid);
    strassen(lhs, b11, p2, mid);

    // P3 = A11 (B12 - B22)
    strassenSub(b12, b22, rhs, mid);
    strassen(a11, rhs, p3, mid);

    // P4 = A22 (B21 - B11)
    strassenSub(b21, b11, rhs, mid);
    strassen(a22, rhs, p4, mid);

    // P5 = (A11 + A12) B22
    strassenAdd(a11, a12, lhs, mid);
    strassen(lhs, b22, p5, mid);

    // P6 = (A21 - A11)(B11 + B12)
    strassenSub(a21, a11, lhs, mid);
    strassenAdd(b11, b12, rhs, mid);
    strassen(lhs, rhs, p6, mid);

    // P7 = (A12 - A22)(B21 + B22)
    strassenSub(a12, a22, lhs, mid);
    strassenAdd(b21, b22, rhs, mid);
    strassen(lhs, rhs, p7, mid);

    // Assemble the four quadrants of the result directly into c.
    for (int i = 0; i < mid; i++) {
        for (int j = 0; j < mid; j++) {
            c[i][j] = p1[i][j] + p4[i][j] - p5[i][j] + p7[i][j];              // C11
            c[i][j + mid] = p3[i][j] + p5[i][j];                              // C12
            c[i + mid][j] = p2[i][j] + p4[i][j];                              // C21
            c[i + mid][j + mid] = p1[i][j] - p2[i][j] + p3[i][j] + p6[i][j];  // C22
        }
    }

    // Release every temporary (rows and row tables).
    int **scratch[] = {a11, a12, a21, a22, b11, b12, b21, b22,
                       lhs, rhs, p1, p2, p3, p4, p5, p6, p7};
    const int scratchCount = (int)(sizeof scratch / sizeof scratch[0]);
    for (int k = 0; k < scratchCount; k++) {
        strassenFree(scratch[k], mid);
    }
}
// Prints a size x size matrix: one row per line, each entry followed by a space.
static void printMatrix(int **mat, int size) {
    for (int i = 0; i < size; i++) {
        for (int j = 0; j < size; j++) {
            printf("%d ", mat[i][j]);
        }
        printf("\n");
    }
}

// Demo driver: builds two n x n matrices, prints them, multiplies them with
// strassen() and prints the result. Returns 0.
int main() {
    int **a = new int *[n];
    int **b = new int *[n];
    int **c = new int *[n];
    for (int i = 0; i < n; i++) {
        a[i] = new int[n];
        b[i] = new int[n];
        c[i] = new int[n];
    }

    // a holds 1, 2, 3, ... in row-major order; b holds the same values plus 4.
    int tempNum = 1;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            a[i][j] = tempNum;
            b[i][j] = tempNum + 4;
            tempNum++;
        }
    }

    printMatrix(a, n);
    printMatrix(b, n);

    strassen(a, b, c, n);
    printMatrix(c, n);

    // Memory from new[] must be released with delete[], row by row and then
    // the row tables. `delete a, b, c;` is a comma expression that frees only a.
    for (int i = 0; i < n; i++) {
        delete[] a[i];
        delete[] b[i];
        delete[] c[i];
    }
    delete[] a;
    delete[] b;
    delete[] c;
    return 0;
}
- 运行结果