问题:
设
和
是两个
阶矩阵,求它们的乘积矩阵C。这里,假设
是
的幂次方。
一、问题分析(模型、算法设计和正确性证明等)
实验要求使用分治法解决n阶矩阵(n是2的幂次方)相乘问题,因为n是2的幂次方,可以使用朴素分块矩阵乘法或者 Strassen 法,这里两种都尝试一下,顺便连蛮力法也放进去。
二、复杂度分析
蛮力法伪代码:
for i = 1 to n do:
for j = 1 to n do:
for k = 1 to n do:
C[i][j] = C[i][j] + A[i][k]・B[k][j]
显然时间复杂度为 .
朴素分块矩阵乘法伪代码:
Divide_And_Conquer(int[][]A,int[][]B,int n){
int [][]C = new int[n][n]; //定义一个新矩阵存放结果
if n==1:
C11=A11*B11;
else Divide A, B and C as in 4 equation:
n /= 2;
C11=Divide_And_Conquer(A11,B11,n) + Divide_And_Conquer(A12,B21,n);
C22=Divide_And_Conquer(A11,B12,n) + Divide_And_Conquer(A12,B22,n);
C21=Divide_And_Conquer(A21,B11,n) + Divide_And_Conquer(A22,B21,n);
C22=Divide_And_Conquer(A21,B22,n) + Divide_And_Conquer(A22,B22,n);
return C;
}
因为n/2 * n/2 的矩阵乘法进行了8次, n/2 * n/2的矩阵加法进行了4次所以复杂度为:
Strassen 法伪代码:
Strassen_DAC(int [][]A, int [][]B,int n){
int [][]C = new int[n][n]; //定义结果矩阵
if n==1:
C11 = A11*B11;
else Divide A, B, and C as in 4 equation:
n /= 2;
int [][]M1,M2,M3,M4,M5,M6,M7 = new int[n][n];
M1 = Strassen_DAC(A11, B12-B22, n);
M2 = Strassen_DAC(A11+A12, B22, n);
M3 = Strassen_DAC(A21+A22, B11, n);
M4 = Strassen_DAC(A22, B21-B11, n);
M5 = Strassen_DAC(A11+A12, B11+B12, n);
M6 = Strassen_DAC(A12-A22, B21+B22, n);
M7 = Strassen_DAC(A11-A21, B11+B12, n);
C11 = M5 + M4 - M2 + M6;
C12 = M1 + M2;
C21 = M3 + M4;
C22 = M5 + M1 -M3 -M7;
return C;
}
从伪代码明显可以看出,程序执行了7次n/2 * n/2的矩阵乘法,以及 18次n/2 *n/2的矩阵加减运算,所以复杂度为:
三、程序实现和测试过程和结果(主要描述出现的问题和解决方法)
算法的思路倒是不难,难的是具体实现的时候,矩阵的分块操作,容易绕不清楚。而且其中还有矩阵的加减运算,得出的结果最后还要把 合并成为一个矩阵 ,这些在伪代码里都没有给出来,但是复杂繁琐容易出bug的正是这些细节。
实验中使用Java语言编写将三种方法放到同一个类中,一下为类中的各个方法:
源码:
package root;
/**
* @author 宇智波Akali
* 这是三种矩阵乘法
* @date 2020.3.18
*/
public class Try {
//创建一个随机数构成的n*n矩阵
public static int[][] initializationMatrix(int n){
int[][] result = new int[n][n];//创建一个n*n矩阵
for(int i = 0;i < n;i++){
for(int j = 0;j < n;j++){
result[i][j] = (int)(Math.random()*10); //随机生成1~10之间的数
}
}
return result;
}
//蛮力法求矩阵相乘
public static int[][] BruteForce(int[][] p,int[][] q,int n){
int[][] result = new int[n][n];
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
result[i][j] = 0;
for(int k=0;k<n;k++){
result[i][j] += p[i][k]*q[k][j];
}
}
}
return result;
}
//分治法求矩阵相乘
public static int[][] DivideAndConquer(int[][] p,int[][] q,int n){
int[][] result = new int[n][n];//创建一个n*n矩阵
//当n为2时,用蛮力法求矩阵相乘,返回结果结果
if(n == 2){
result = BruteForce(p,q,n);
return result;
}
//当n大于3时,采用分治法,递归求最终结果
if(n > 2){
int m = n/2;
//将矩阵p分成四块
int[][] p1 = QuarterMatrix(p,n,1);
int[][] p2 = QuarterMatrix(p,n,2);
int[][] p3 = QuarterMatrix(p,n,3);
int[][] p4 = QuarterMatrix(p,n,4);
//将矩阵q分成四块
int[][] q1 = QuarterMatrix(q,n,1);
int[][] q2 = QuarterMatrix(q,n,2);
int[][] q3 = QuarterMatrix(q,n,3);
int[][] q4 = QuarterMatrix(q,n,4);
//将结果矩阵分成同等大小的四块
int[][] result1 = QuarterMatrix(result,n,1);
int[][] result2 = QuarterMatrix(result,n,2);
int[][] result3 = QuarterMatrix(result,n,3);
int[][] result4 = QuarterMatrix(result,n,4);
//最关键的步骤,递归调用DivideAndConquer()函数,并用公式相加
result1 = AddMatrix(DivideAndConquer(p1,q1,m),DivideAndConquer(p2,q3,m),m);//y=ae+bg
result2 = AddMatrix(DivideAndConquer(p1,q2,m),DivideAndConquer(p2,q4,m),m);//s=af+bh
result3 = AddMatrix(DivideAndConquer(p3,q1,m),DivideAndConquer(p4,q3,m),m);//t=ce+dg
result4 = AddMatrix(DivideAndConquer(p3,q2,m),DivideAndConquer(p4,q4,m),m);//u=cf+dh
//合并,将四块小矩阵合成整体
result = TogetherMatrix(result1,result2,result3,result4,m);//把分成的四个小矩阵合并成一个大矩阵
}
return result;
}
//strassen法
public static int[][] Strassen(int[][] p,int[][] q,int n){
int[][] result = new int[n][n];//创建一个n*n矩阵
if( n == 2){
result = BruteForce(p,q,n);
return result;
}
int m = n/2;
//将矩阵p分成四块
int[][] p1 = QuarterMatrix(p,n,1);
int[][] p2 = QuarterMatrix(p,n,2);
int[][] p3 = QuarterMatrix(p,n,3);
int[][] p4 = QuarterMatrix(p,n,4);
//将矩阵q分成四块
int[][] q1 = QuarterMatrix(q,n,1);
int[][] q2 = QuarterMatrix(q,n,2);
int[][] q3 = QuarterMatrix(q,n,3);
int[][] q4 = QuarterMatrix(q,n,4);
int[][] m1 = DivideAndConquer(AddMatrix(p1,p4,m),AddMatrix(q1,q4,m),m);
int[][] m2 = Strassen(AddMatrix(p3,p4,m),q1,m);
int[][] m3 = Strassen(p1,ReduceMatrix(q2,q4,m),m);
int[][] m4 = Strassen(p4,ReduceMatrix(q3,q1,m),m);
int[][] m5 = Strassen(AddMatrix(p1,p2,m),q4,m);
int[][] m6 = Strassen(ReduceMatrix(p3,p1,m),AddMatrix(q1,q2,m),m);
int[][] m7 = Strassen(ReduceMatrix(p2,p4,m),AddMatrix(q3,q4,m),m);
//将结果矩阵分成同等大小的四块
int[][] result1 = QuarterMatrix(result,n,1);
int[][] result2 = QuarterMatrix(result,n,2);
int[][] result3 = QuarterMatrix(result,n,3);
int[][] result4 = QuarterMatrix(result,n,4);
result1 = AddMatrix(ReduceMatrix(AddMatrix(m1,m4,m),m5,m),m7,m);
result2 = AddMatrix(m3,m5,m);
result3 = AddMatrix(m2,m4,m);
result4 = AddMatrix(AddMatrix(ReduceMatrix(m1,m2,m),m3,m),m6,m);
result = TogetherMatrix(result1,result2,result3,result4,m);//把分成的四个小矩阵合并成一个大矩阵
return result;
}
//获取矩阵的四分之一,number用来确定返回哪一个四分之一
public static int[][] QuarterMatrix(int[][] p,int n,int number){
int rows = n/2; //行数减半
int cols = n/2; //列数减半
int[][] result = new int[rows][cols];
switch(number){
//左上
case 1 :
{
for(int i=0;i<rows;i++)
for(int j=0;j<cols;j++)
result[i][j] = p[i][j];
break;
}
//右上
case 2 :
{
for(int i=0;i<rows;i++)
for(int j=0;j<n-cols;j++)
result[i][j] = p[i][j+cols];
break;
}
//左下
case 3 :
{
for(int i=0;i<n-rows;i++)
for(int j=0;j<cols;j++)
result[i][j] = p[i+rows][j];
break;
}
//右下
case 4 :
{
for(int i=0;i<n-rows;i++)
for(int j=0;j<n-cols;j++)
result[i][j] = p[i+rows][j+cols];
break;
}
default:
break;
}
return result;
}
//把均分为四分之一的矩阵,合成一个矩阵
public static int[][] TogetherMatrix(int[][] a,int[][] b,int[][] c,int[][] d,int n){
int[][] result = new int[2*n][2*n];
for(int i=0;i<2*n;i++){
for(int j=0;j<2*n;j++){
if(i<n){
if(j<n)
result[i][j] = a[i][j];
else
result[i][j] = b[i][j-n];
}else{
if(j<n)
result[i][j] = c[i-n][j];
else
result[i][j] = d[i-n][j-n];
}
}
}
return result;
}
//求两个矩阵相加结果
public static int[][] AddMatrix(int[][] p,int[][] q,int n){
int[][] result = new int[n][n];
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
result[i][j] = p[i][j]+q[i][j];
}
}
return result;
}
//求两个矩阵相减结果
public static int[][] ReduceMatrix(int[][] p,int[][] q,int n){
int[][] result = new int[n][n];
for(int i=0;i<n;i++){
for(int j=0;j<n;j++){
result[i][j] = p[i][j]-q[i][j];
}
}
return result;
}
//输出矩阵的函数
public static void PrintfMatrix(int[][] matrix,int n){
for(int i=0;i<n;i++){
for(int j=0;j<n;j++)
System.out.printf("% 4d",matrix[i][j]);
System.out.println();
}
}
public static void main(String args[]){
int[][] p = initializationMatrix(8);
int[][] q = initializationMatrix(8);
//输出生成的两个矩阵
System.out.println("p:");
PrintfMatrix(p,8);
System.out.println();
System.out.println("q:");
PrintfMatrix(q,8);
//输出分治法矩阵相乘后的结果
int[][] bru_result = BruteForce(p,q,8);//新建一个矩阵存放矩阵相乘结果
System.out.println();
System.out.println("\nA*B(蛮力法):");
PrintfMatrix(bru_result,8);
//输出分治法矩阵相乘后的结果
int[][] dac_result = DivideAndConquer(p,q,8);//新建一个矩阵存放矩阵相乘结果
System.out.println();
System.out.println("A*B(分治法):");
PrintfMatrix(dac_result,8);
//输出strassen法矩阵相乘后的结果
int[][] stra_result = Strassen(p,q,8);//新建一个矩阵存放矩阵相乘结果
System.out.println("\nA*B(strassen法):");
PrintfMatrix(stra_result,8);
}
}
运行结果: