【笔记】torch 乘法总结

【笔记】torch 乘法总结

一、乘号(*) 和 torch.mul()

element-wise 即对应元素相乘

例子:

>>> a = torch.randn(2,3)
>>> b = torch.randn(2,1)
>>> res = a * b
>>> res
tensor([[-0.9672, -0.1052,  0.1392],
        [-0.8552,  0.8967, -0.6433]])

特别地,如果是( 2 × 1 × 3 2\times 1\times 3 2×1×3)和( 2 × 4 × 3 2\times 4\times 3 2×4×3)这种情况,也可以相乘,结果是( 2 × 4 × 3 2\times 4\times 3 2×4×3)。相当于用前一个tensor沿着后一个tensor的第1维度expand(broadcast机制)。

例子:

>>> a = torch.randn(2,3)
>>> b = torch.randn(2,1)
>>> res1 = a * b
>>> res1
tensor([[ 0.9199,  0.4053,  0.0789],
        [ 2.1330, -0.5653,  0.4760]])
>>> res2 = torch.mul(a,b)
>>> res2
tensor([[ 0.9199,  0.4053,  0.0789],
        [ 2.1330, -0.5653,  0.4760]])	# torch.mul() 和 * 效果相同
>>> res3 = a * b.expand(2,3)
>>> res3
tensor([[ 0.9199,  0.4053,  0.0789],	# 两个tensor维度不一致时,
        [ 2.1330, -0.5653,  0.4760]])	# 会自动进行expand

二、torch.mm() 与 torch.bmm()

矩阵乘法。

  • torch.mm(mat1, mat2) m a t 1 mat1 mat1 m a t 2 mat2 mat2 进行矩阵乘法,要求输入只能是2维矩阵。

  • torch.bmm(mat1, mat2)专门进行batch形式的矩阵乘法。要求(1)输入只能是3维矩阵( b a t c h , d 1 , d 2 batch, d_1, d_2 batch,d1,d2);(2)第0维度相同。

三、torch.matmul()

矩阵乘法,支持broadcast。

下面是torch.matmul(mat1, mat2)的适用情形:

  • m a t 1 mat1 mat1 m a t 2 mat2 mat2 都是一维向量,则结果返回的是一个标量

  • (标准的二维矩阵乘法)若 m a t 1 ∈ R m × n mat1 \in \mathbb{R}^{m \times n} mat1Rm×n, m a t 2 ∈ R n × d mat2 \in \mathbb{R}^{n \times d} mat2Rn×d,则返回两个矩阵乘积 o u t p u t ∈ R m × d output\in \mathbb{R}^{m\times d} outputRm×d

  • m a t 1 ∈ R n mat1 \in \mathbb{R}^{n} mat1Rn 是一维向量,则会为 m a t 1 mat1 mat1 添加一个维度,变成 m a t 1 ∈ R 1 × n mat1 \in \mathbb{R}^{1 \times n} mat1R1×n,然后与矩阵2 m a t 2 ∈ R n × d mat2 \in \mathbb{R}^{n \times d} mat2Rn×d 进行二维矩阵乘法。并在最后的结果中移除添加的那一维度 => R d \mathbb{R}^{d} Rd

  • 若至少有一个参与运算的矩阵的维度大于2,则进行batch矩阵乘法

    这里会把矩阵的后两个维度视作矩阵维度(matrix dimensions),参与矩阵运算,其他维度视作batch维度,进行broadcast处理。

    两个例子:

    (1) mat1 是一个 size 为 ( j × 1 × n × n j \times 1 \times n \times n j×1×n×n) 的 tensor ,mat2 是一个 size 为 ( k × n × n k \times n \times n k×n×n) 的 tensor, out 将会是一个 ( j × k × n × n ) j \times k \times n \times n) j×k×n×n) 的 tensor. 这里 ( n × n n \times n n×n) 部分是矩阵维度,( k k k) 和 ( j × k j \times k j×k) 是 batch 维度。

    (2) mat1 是一个 size 为 ( j × 1 × n × m j \times 1 \times n \times m j×1×n×m) 的 tensor , mat2 是一个 size 为 ( k × m × p k \times m \times p k×m×p) 的 tensor, 这些输入的 tensors 支持 broadcasting 机制,即使最后两个矩阵维度是不同的。 输出out 将会是一个 size 为 ( j × k × n × p j \times k \times n \times p j×k×n×p) 的 tensor.

来源:torch.matmul — PyTorch 1.8.1 documentation

猜你喜欢

转载自blog.csdn.net/weixin_45850137/article/details/116501102