DeepLearning4j实战1--ND4J矩阵操作

版权声明:本文为博主原创文章,转载请注明出处。 https://blog.csdn.net/qq_41089021/article/details/85045730

本文示例源码地址:https://github.com/tianlanlandelan/DL4JTest/blob/master/src/test/java/com/dl4j/demo/Nd4jTest.java

maven安装DL4J

pom文件引入:

		<dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native</artifactId>
            <version>1.0.0-beta3</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>1.0.0-beta3</version>
        </dependency>

创建矩阵

		//生成一个全0二维矩阵
        INDArray tensorA =  Nd4j.zeros(4,5);
        println("全0二维矩阵",tensorA);

        //生成一个全1二维矩阵
        INDArray tensorB =  Nd4j.ones(4,5);
        println("全0二维矩阵",tensorB);

        //生成一个全1二维矩阵
        INDArray tensorC =  Nd4j.rand(4,5);
        println("随机二维矩阵",tensorC);

运行结果:

====全0二维矩阵===
[[         0,         0,         0,         0,         0], 
 [         0,         0,         0,         0,         0], 
 [         0,         0,         0,         0,         0], 
 [         0,         0,         0,         0,         0]]
====全0二维矩阵===
[[    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000]]
====随机二维矩阵===
[[    0.5017,    0.9461,    0.3255,    0.2155,    0.9273], 
 [    0.0239,    0.5130,    0.8028,    0.5011,    0.3680], 
 [    0.3644,    0.0864,    0.0342,    0.4126,    0.5553], 
 [    0.2027,    0.7989,    0.6696,    0.0402,    0.7059]]

矩阵运算–拼接

println("水平拼接若干矩阵,矩阵必须有相同的行数", Nd4j.hstack(tensorA,tensorB));
println("垂直拼接若干矩阵,矩阵必须有相同的列数", Nd4j.vstack(tensorA,tensorB));

运行结果:

[[         0,         0,         0,         0,         0,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [         0,         0,         0,         0,         0,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [         0,         0,         0,         0,         0,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [         0,         0,         0,         0,         0,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000]]
====垂直拼接若干矩阵,矩阵必须有相同的列数===
[[         0,         0,         0,         0,         0], 
 [         0,         0,         0,         0,         0], 
 [         0,         0,         0,         0,         0], 
 [         0,         0,         0,         0,         0], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000]]

矩阵运算-加减

注意,每个运算函数都有一个加i的函数,如 add和addi,加i的函数运算后会覆盖掉原矩阵

        println("矩阵元素加上一个标量",tensorA.add(10));
        println("矩阵相加",tensorA.add(tensorB));
        println("矩阵元素加上标量后覆盖原矩阵tensorA",tensorA.addi(10));
        println("矩阵相减",tensorA.sub(tensorB));

运行结果:

====矩阵元素加上一个标量===
[[   10.0000,   10.0000,   10.0000], 
 [   10.0000,   10.0000,   10.0000]]
====矩阵相加===
[[    0.2202,    0.1473,    0.1217], 
 [    0.8428,    0.6761,    0.8127]]
====矩阵元素加上标量后覆盖原矩阵tensorA===
[[   10.0000,   10.0000,   10.0000], 
 [   10.0000,   10.0000,   10.0000]]
====矩阵相减===
[[    9.7798,    9.8527,    9.8783], 
 [    9.1572,    9.3239,    9.1873]]

矩阵运算-乘除

 		 println("矩阵对应元素相乘",tensorA.mul(tensorB));
        println("矩阵元素除以一个标量",tensorA.div(2));
        println("矩阵对应元素相除",tensorA.div(tensorB));
        /*
        矩阵A*B=C
        需要注意:
        1、当矩阵A的列数等于矩阵B的行数时,A与B可以相乘。
        2、矩阵C的行数等于矩阵A的行数,C的列数等于B的列数。( A:2,3; B:3,4; C:2,4 )
        3、乘积C的第m行第n列的元素等于矩阵A的第m行的元素与矩阵B的第n列对应元素乘积之和。
         */
        println("矩阵相乘",tensorA.mmul(tensorB));

运算结果:

====矩阵对应元素相乘===
[[    2.2015,    1.4728,    1.2173], 
 [    8.4281,    6.7608,    8.1272]]
====矩阵元素除以一个标量===
[[    5.0000,    5.0000,    5.0000], 
 [    5.0000,    5.0000,    5.0000]]
====矩阵对应元素相除===
[[   45.4231,   67.8989,   82.1506], 
 [   11.8650,   14.7911,   12.3043]]
====矩阵相乘===
[[    4.8916,   23.3161], 
 [    4.8916,   23.3161]]

矩阵运算-翻转

	println("矩阵转置",tensorB.transpose());
    println("矩阵转置后替换原矩阵tensorB",tensorB.transposei());

运算结果:

====矩阵转置===
[[    0.2202,    0.8428], 
 [    0.1473,    0.6761], 
 [    0.1217,    0.8127]]
====矩阵转置后替换原矩阵tensorB===
[[    0.2202,    0.8428], 
 [    0.1473,    0.6761], 
 [    0.1217,    0.8127]]

三维矩阵

三维矩阵和二维矩阵操作一样:

        //创建一个三维矩阵 2*3*4
        INDArray tensor3d_1 = Nd4j.create(new int[]{2,3,4});
        println("创建空的三维矩阵",tensor3d_1);
        //创建一个随机的三维矩阵 2*3*4
        INDArray tensor3d_2 =  Nd4j.rand(new int[]{2,3,4});
        println("创建随机三维矩阵",tensor3d_2);
        //矩阵的每个元素减去一个标量后覆盖原矩阵
        println("矩阵元素减去一个标量",tensor3d_1.subi(-5));
        //矩阵相减
        println("三维矩阵相减",tensor3d_1.sub(tensor3d_2));

运算结果:

====创建空的三维矩阵===
[[[         0,         0,         0,         0], 
  [         0,         0,         0,         0], 
  [         0,         0,         0,         0]], 

 [[         0,         0,         0,         0], 
  [         0,         0,         0,         0], 
  [         0,         0,         0,         0]]]
====创建随机三维矩阵===
[[[    0.7030,    0.0575,    0.3288,    0.8928], 
  [    0.7067,    0.4539,    0.6318,    0.8632], 
  [    0.2914,    0.7980,    0.3350,    0.8783]], 

 [[    0.8559,    0.7396,    0.6039,    0.1946], 
  [    0.5336,    0.9253,    0.4747,    0.2658], 
  [    0.9690,    0.3269,    0.0520,    0.1754]]]
====矩阵元素减去一个标量===
[[[    5.0000,    5.0000,    5.0000,    5.0000], 
  [    5.0000,    5.0000,    5.0000,    5.0000], 
  [    5.0000,    5.0000,    5.0000,    5.0000]], 

 [[    5.0000,    5.0000,    5.0000,    5.0000], 
  [    5.0000,    5.0000,    5.0000,    5.0000], 
  [    5.0000,    5.0000,    5.0000,    5.0000]]]
====三维矩阵相减===
[[[    4.2970,    4.9425,    4.6712,    4.1072], 
  [    4.2933,    4.5461,    4.3682,    4.1368], 
  [    4.7086,    4.2020,    4.6650,    4.1217]], 

 [[    4.1441,    4.2604,    4.3961,    4.8054], 
  [    4.4664,    4.0747,    4.5253,    4.7342], 
  [    4.0310,    4.6731,    4.9480,    4.8246]]]

猜你喜欢

转载自blog.csdn.net/qq_41089021/article/details/85045730