楼主最近在用blas库做矩阵的运算,用到了复数矩阵的乘法,然后去查cblas_cgemm()函数,可是转了一圈,没发现这个函数的举例说明,大都是cblas_sgemm()和cblas_dgemm()的例子,所以就自己根据官方文档去研究了下,终于研究明白了,记录下来,方便以后自己看, 希望对各位有一点点的帮助,废话不多说,上干货。
gemm求的是下面公式的值:
C=alpha*A*B+beta*C;//A,B,C分别为矩阵
cblas_cgemm()函数声明是这样的:
void cblas_cgemm(OPENBLAS_CONST enum CBLAS_ORDER Order, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransA, OPENBLAS_CONST enum CBLAS_TRANSPOSE TransB, OPENBLAS_CONST blasint M, OPENBLAS_CONST blasint N, OPENBLAS_CONST blasint K,
OPENBLAS_CONST float *alpha, OPENBLAS_CONST float *A, OPENBLAS_CONST blasint lda, OPENBLAS_CONST float *B, OPENBLAS_CONST blasint ldb, OPENBLAS_CONST float *beta, float *C, OPENBLAS_CONST blasint ldc);
好长一串,我们简化一下,cblas_cgemm(Order,TransA,TransB,M,N,K,alpha,A,lda,B,ldb,beta,C,ldc);//对应的参数类型与上面对比
所有这些数据是存放在一维指针数组中的,特别说明,此时的存放规则是:指针的奇数位存放实部,指针的偶数位存放虚部,也就是按照复数顺序存放每个复数的实部和虚部。
参数说明:
Order:说明矩阵是按行还是按列存放在指针数组中的,值有CblasRowMajor,CblasColMajor
TransA:说明A矩阵在乘之前是否转置,值有CblasNoTrans,CblasTrans
TransB:同TransA
M:矩阵A的行数,矩阵C的行数
N:矩阵B的列数,矩阵C的列数
K:矩阵A的列数,矩阵B的行数
alpha:矩阵A*B的系数
beta:矩阵C的系数(如果计算C=A*B,则alpha[2]={1.0,1.0},beta[2]={0.0,0.0})
A,B,C:矩阵A,B,C
lda,ldb,ldc:矩阵A,B,C的列数
计算出来的A*B存放在C中。
特别说明,此时的C中的实部和虚部并不是复数最后结果的实部和虚部,真正的实部应该是:本复数的(虚部+实部)/2;真正的虚部应该是:本复数的(虚部-实部)/2。经过计算才得到最终的实部和虚部。
图1是MATLAB计算的结果,图2是cblas的计算结果。
最后附上我的C代码,请大家参考。
-
#include <cblas.h>
-
#include <stdio.h>
-
void matcomplexMul(float *A,int row_a,int colum_a,float *B,int colum_b,float *C)
-
{
-
int i,j;
-
int M=row_a;// A row, C row
-
int N=colum_b;// B column,C column
-
int K=colum_a;//A column, B row
-
float alpha[2]={1.0,1.0};
-
float beta[2]={0.0,0.0};
-
int lda = K;// A column
-
int ldb = N;//B column
-
int ldc = N;//C column
-
printf("the Row matrix multiple...\n");
-
cblas_cgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
-
}
-
void main()
-
{
-
int i = 0,j=0;
-
float A[12] = { 1.0, 2.0, 1.0, -3.0, 4.0, -1.0 , 1.0, 2.0, 1.0, -3.0, 4.0, -1.0 };
-
float B[12] = { 1.0, 2.0, 1.0, -3.0, 4.0, -1.0 , 1.0, 2.0, 1.0, -3.0, 4.0, -1.0 };
-
float C[18];
-
int row_A=3;
-
int colum_A=2;
-
int colum_B=3;
-
printf("the matrix A is:\n");
-
for (i = 0; i<row_A; i++)//3
-
{
-
for(j=0;j<colum_A;j++)//2
-
{
-
printf("%f+%f*I\t",A[(i*colum_A+j)*2],A[(i*colum_A+j)*2+1]);
-
}
-
printf("\n");
-
}
-
printf("\n");
-
printf("the matrix B is:\n");
-
for (i = 0; i<colum_A; i++)
-
{
-
for(j=0;j<colum_B;j++)
-
{
-
printf("%f+%f*I\t",B[(i*colum_B+j)*2],B[(i*colum_B+j)*2+1]);
-
}
-
printf("\n");
-
}
-
printf("\n");
-
matcomplexMul(A,3,2,B,3,C);
-
for (i = 0; i<row_A; i++)
-
{
-
for(j=0;j<colum_B;j++)
-
{
-
printf("%f+%f*I\t",(C[(i*colum_B+j)*2+1]+C[(i*colum_B+j)*2])/2,(C[(i*colum_B+j)*2+1]-C[(i*colum_B+j)*2])/2);
-
}
-
printf("\n");
-
}
-
printf("\n");
-
}