矩阵乘法三种方法(蛮力法、分治法、strassen)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/weixin_41668995/article/details/88847334
package algorithms;

public class strassen {
    //创建一个随机数构成的nxn矩阵
    public static int[][] initializationMatrix(int n){
        int[][] result = new int[n][n];//创建一个nxn矩阵
        for(int i = 0;i < n;i++){
            for(int j = 0;j < n;j++){
                result[i][j] = (int)(Math.random()*10); //随机生成1~10之间的数
            }
        }           
        return result;           
    }
   
    //蛮力法求矩阵相乘
    public static int[][] BruteForce(int[][] p,int[][] q,int n){
        int[][] result = new int[n][n];
        for(int i=0;i<n;i++){
            for(int j=0;j<n;j++){
                result[i][j] = 0;
                for(int k=0;k<n;k++){
                    result[i][j] += p[i][k]*q[k][j];
                }
            }
        }               
        return result;
    }
   
    //分治法求矩阵相乘
    public static int[][] DivideAndConquer(int[][] p,int[][] q,int n){
        int[][] result = new int[n][n];//创建一个nxn矩阵
        //当n为2时,用蛮力法求矩阵相乘,返回结果结果
        if(n == 2){
            result = BruteForce(p,q,n);           
            return result;
        }
       
        //当n大于3时,采用分治法,递归求最终结果
        if(n > 2){
            int m = n/2;
            
            //将矩阵p分成四块
            int[][] p1 = QuarterMatrix(p,n,1);
            int[][] p2 = QuarterMatrix(p,n,2);
            int[][] p3 = QuarterMatrix(p,n,3);
            int[][] p4 = QuarterMatrix(p,n,4);
            
            //将矩阵q分成四块
            int[][] q1 = QuarterMatrix(q,n,1);
            int[][] q2 = QuarterMatrix(q,n,2);
            int[][] q3 = QuarterMatrix(q,n,3);
            int[][] q4 = QuarterMatrix(q,n,4);
            
            //将结果矩阵分成同等大小的四块
            int[][] result1 = QuarterMatrix(result,n,1);
            int[][] result2 = QuarterMatrix(result,n,2);
            int[][] result3 = QuarterMatrix(result,n,3);
            int[][] result4 = QuarterMatrix(result,n,4);
           
            //最关键的步骤,递归调用DivideAndConquer()函数,并用公式相加
            result1 = AddMatrix(DivideAndConquer(p1,q1,m),DivideAndConquer(p2,q3,m),m);//y=ae+bg
            result2 = AddMatrix(DivideAndConquer(p1,q2,m),DivideAndConquer(p2,q4,m),m);//s=af+bh
            result3 = AddMatrix(DivideAndConquer(p3,q1,m),DivideAndConquer(p4,q3,m),m);//t=ce+dg
            result4 = AddMatrix(DivideAndConquer(p3,q2,m),DivideAndConquer(p4,q4,m),m);//u=cf+dh
            
            //合并,将四块小矩阵合成整体
            result = TogetherMatrix(result1,result2,result3,result4,m);//把分成的四个小矩阵合并成一个大矩阵
        }
        return result;
    }
    
    //strassen法
    public static int[][] StrassenMethod(int[][] p,int[][] q,int n){
        int[][] result = new int[n][n];//创建一个nxn矩阵
        int m = n/2;
        
        //将矩阵p分成四块
        int[][] p1 = QuarterMatrix(p,n,1);
        int[][] p2 = QuarterMatrix(p,n,2);
        int[][] p3 = QuarterMatrix(p,n,3);
        int[][] p4 = QuarterMatrix(p,n,4);
        
        //将矩阵q分成四块
        int[][] q1 = QuarterMatrix(q,n,1);
        int[][] q2 = QuarterMatrix(q,n,2);
        int[][] q3 = QuarterMatrix(q,n,3);
        int[][] q4 = QuarterMatrix(q,n,4);
        
        int[][] m1 = DivideAndConquer(AddMatrix(p1,p4,m),AddMatrix(q1,q4,m),m);
        int[][] m2 = DivideAndConquer(AddMatrix(p3,p4,m),q1,m);
        int[][] m3 = DivideAndConquer(p1,ReduceMatrix(q2,q4,m),m);
        int[][] m4 = DivideAndConquer(p4,ReduceMatrix(q3,q1,m),m);
        int[][] m5 = DivideAndConquer(AddMatrix(p1,p2,m),q4,m);
        int[][] m6 = DivideAndConquer(ReduceMatrix(p3,p1,m),AddMatrix(q1,q2,m),m);
        int[][] m7 = DivideAndConquer(ReduceMatrix(p2,p4,m),AddMatrix(q3,q4,m),m);
        
        //将结果矩阵分成同等大小的四块
        int[][] result1 = QuarterMatrix(result,n,1);
        int[][] result2 = QuarterMatrix(result,n,2);
        int[][] result3 = QuarterMatrix(result,n,3);
        int[][] result4 = QuarterMatrix(result,n,4);
       
        result1 = AddMatrix(ReduceMatrix(AddMatrix(m1,m4,m),m5,m),m7,m);
        result2 = AddMatrix(m3,m5,m);
        result3 = AddMatrix(m2,m4,m);
        result4 = AddMatrix(AddMatrix(ReduceMatrix(m1,m2,m),m3,m),m6,m);
        
        result = TogetherMatrix(result1,result2,result3,result4,m);//把分成的四个小矩阵合并成一个大矩阵
        
        return result;
    }
    
    
    
    
    //获取矩阵的四分之一,number用来确定返回哪一个四分之一
    public static int[][] QuarterMatrix(int[][] p,int n,int number){
        int rows = n/2;   //行数减半
        int cols = n/2;   //列数减半
        int[][] result = new int[rows][cols];
        switch(number){
           //左上
           case 1 :
           {
               for(int i=0;i<rows;i++)
                   for(int j=0;j<cols;j++)
                       result[i][j] = p[i][j];
               break;
           }
           //右上
           case 2 :
           {
               for(int i=0;i<rows;i++)
                   for(int j=0;j<n-cols;j++)
                       result[i][j] = p[i][j+cols];
               break;
           }
           //左下
           case 3 :
           {
               for(int i=0;i<n-rows;i++)
                   for(int j=0;j<cols;j++)
                       result[i][j] = p[i+rows][j];
               break;
           }
           //右下
           case 4 :
           {
               for(int i=0;i<n-rows;i++)
                   for(int j=0;j<n-cols;j++)
                       result[i][j] = p[i+rows][j+cols];
               break;
           }
           default:
               break;
        }
       
        return result;
     }
   
    //把均分为四分之一的矩阵,合成一个矩阵
    public static int[][] TogetherMatrix(int[][] a,int[][] b,int[][] c,int[][] d,int n){
        int[][] result = new int[2*n][2*n];
        for(int i=0;i<2*n;i++){
            for(int j=0;j<2*n;j++){
                if(i<n){
                    if(j<n)
                        result[i][j] = a[i][j];
                    else
                        result[i][j] = b[i][j-n];
                }else{
                    if(j<n)
                        result[i][j] = c[i-n][j];
                    else
                        result[i][j] = d[i-n][j-n];
                }
            }
        }
        return result;
    }
   
   
    //求两个矩阵相加结果
    public static int[][] AddMatrix(int[][] p,int[][] q,int n){
        int[][] result = new int[n][n];
        for(int i=0;i<n;i++){
            for(int j=0;j<n;j++){
                result[i][j] = p[i][j]+q[i][j];
            }
        }
        return result;
    }
    
  //求两个矩阵相减结果
    public static int[][] ReduceMatrix(int[][] p,int[][] q,int n){
        int[][] result = new int[n][n];
        for(int i=0;i<n;i++){
            for(int j=0;j<n;j++){
                result[i][j] = p[i][j]-q[i][j];
            }
        }
        return result;
    }
    
    //输出矩阵的函数
    public static void PrintfMatrix(int[][] matrix,int n){
        for(int i=0;i<n;i++){
            for(int j=0;j<n;j++)
                System.out.printf("% 4d",matrix[i][j]);
            System.out.println();
        }
       
    }
   
    public static void main(String args[]){
        int[][] p = initializationMatrix(8);
        int[][] q = initializationMatrix(8);
        //输出生成的两个矩阵
        System.out.println("p:");
        PrintfMatrix(p,8);
        System.out.println();
        System.out.println("q:");
        PrintfMatrix(q,8);

        //输出分治法矩阵相乘后的结果
        int[][] bru_result = BruteForce(p,q,8);//新建一个矩阵存放矩阵相乘结果
        System.out.println();
        System.out.println("\np*q(蛮力法):");
        PrintfMatrix(bru_result,8);
        
        //输出分治法矩阵相乘后的结果
        int[][] dac_result = DivideAndConquer(p,q,8);//新建一个矩阵存放矩阵相乘结果
        System.out.println();
        System.out.println("p*q(分治法):");
        PrintfMatrix(dac_result,8);
        
        //输出strassen法矩阵相乘后的结果
        int[][] stra_result = StrassenMethod(p,q,8);//新建一个矩阵存放矩阵相乘结果
        System.out.println("\np*q(strassen法):");
        PrintfMatrix(stra_result,8);
        
    }


}

猜你喜欢

转载自blog.csdn.net/weixin_41668995/article/details/88847334