简介
接上一篇,实现矩阵相乘。
式子
矩阵相乘式子如下:
但考虑到运算量的问题,一个更通用的式子如下:
本文实现该通用式子
A矩阵为
B矩阵为
C矩阵为
乘法公式
实现
void matmul(const char *tr, int n, int k, int m, double alpha,
const double *A, const double *B, double beta, double *C)
{
double d;
for (int i = 0; i < n; i++) {
for (int j = 0; j < k; j++) {
d = 0.0;
for (int x = 0; x < m; x++) {
d += A[MAT_INDEX(n, m, i, x)] * B[MAT_INDEX(m, k, x, j)];
}
C[MAT_INDEX(n, k, i, j)]= alpha*d + beta*C[MAT_INDEX(n, k, i, j)];
}
}
}
通过乘法公式,对照MAT_INDEX宏定义,可以相当清晰地了解矩阵的规模以及下标关系,可读性相当高。
测试
int main() {
double *A = mat(2, 3);
for (int i = 0; i < 6; i++) {
A[i] = i;
}
printf("\r\n");
matprint(A, 2, 3);
double *B = mat(3, 4);
for (int i = 0; i < 12; i++) {
B[i] = i;
}
printf("\r\n");
matprint(B, 3, 4);
double *C = mat(2, 4);
matmul(2, 4, 3, 1.0, A, B, 0.0, C);
printf("\r\n");
matprint(C, 2, 4);
getchar();
return 1;
}
优化
在实际应用中,A和B一般包含转置运算,即i和j交换,可直接在上面的基础上进行实现。
定义N为不转置,T为转置
则有
A和A^T的元素关系为
因此
A[MAT_INDEX(n, m, i, x)]
对应的转置元素如为
A[MAT_INDEX(m, n, x, i)]
实现为
void matmul(const char *tr, int n, int k, int m, double alpha,
const double *A, const double *B, double beta, double *C)
{
double d;
int f = tr[0] == 'N' ? (tr[1] == 'N' ? 1 : 2) : (tr[1] == 'N' ? 3 : 4);
for (int i = 0; i < n; i++) {
for (int j = 0; j < k; j++) {
d = 0.0;
switch (f) {
case 1: for (int x = 0; x < m; x++) d += A[MAT_INDEX(n, m, i, x)] * B[MAT_INDEX(m, k, x, j)]; break;
case 2: for (int x = 0; x < m; x++) d += A[MAT_INDEX(n, m, i, x)] * B[MAT_INDEX(k, m, j, x)]; break;
case 3: for (int x = 0; x < m; x++) d += A[MAT_INDEX(m, n, x, i)] * B[MAT_INDEX(m, k, x, j)]; break;
case 4: for (int x = 0; x < m; x++) d += A[MAT_INDEX(m, n, x, i)] * B[MAT_INDEX(k, m, j, x)]; break;
}
C[MAT_INDEX(n, k, i, j)] = alpha*d + beta*C[MAT_INDEX(n, k, i, j)];
}
}
}