矩阵处理

有关矩阵处理的函数实在是太多了,在这里写一下,方便以后回忆
torch.clamp
clamp(input,min,max,out=None)-> Tensor,将input中的元素限制在[min,max]范围内并返回一个Tensor
可以看一下栗子:

a = torch.randn(2,2)
print(a)
c = a.clamp(min = 0)
print(c)

输出。将所有小于min的数全部替换成min

tensor([[ 0.6887, -2.4910],
        [-1.1766, -0.9142]])
tensor([[0.6887, 0.0000],
        [0.0000, 0.0000]])

当把min替换成max后

c = a.clamp(max = 0)

看一下输出

tensor([[0.1676, 1.2737],
        [0.8410, 0.2963]])
tensor([[0., 0.],
        [0., 0.]])

torch.diag
取矩阵的对角元素,组成一个新的tensor,输出

b =a.diag()

输出:

tensor([[ 0.8293, -1.3060],
        [-0.4743,  0.8271]])
tensor([0.8293, 0.8271])

torch.eye
torch.eye(n, m=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
可以返回参数维度的单位矩阵

d = torch.eye(3)
print(d)

输出:

tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])

torch.max
按照维度取最大值,返回一个tuple(tensor,dtype)

vvv = torch.randn(2,2)
print(vvv)
vv = vvv.max(1)
print(vv)

等于1,是按照行取最大;等于0,按照列取最大,输出:

tensor([[ 0.5618,  0.7352],
        [-0.0405, -1.2378]])
(tensor([ 0.7352, -0.0405]), tensor([1, 0]))

所以取max(1)[0]就是取元祖的第一个元素

扫描二维码关注公众号,回复: 5707246 查看本文章

torch.masked_fill_(mask,value)
在mask值为1的地方用value填充

print(d)
cost_im = d.masked_fill_(d, 3)
print(cost_im)

输出:

tensor([[1, 0],
        [0, 1]], dtype=torch.uint8)
tensor([[3, 0],
        [0, 3]], dtype=torch.uint8)

另外的矩阵调用

a = torch.randn(2,2)
print(a)

输出:

tensor([[-0.4149, -1.3434],
        [ 0.2611, -0.4930]])
print(d) 
cost_im = a.masked_fill_(d, 3)
print(cost_im)

将原矩阵为1的地方用3代替,其余用a中的对应元素填充,输出:

tensor([[1, 0],
        [0, 1]], dtype=torch.uint8)
tensor([[ 3.0000, -1.3434],
        [ 0.2611,  3.0000]])

np.argsort(d, axis=1)
argsort()函数的作用是将数组按照从小到大的顺序排序,并按照对应的索引值输出。
argsort()函数中,当axis=0时,按列排列;当axis=1时,按行排列。如果省略默认按行排列。

inds = np.argsort(d, axis=1)
print(inds)
[[11  5]
 [25 11]]
[[1 0]
 [1 0]]

np.where(condition,x,y)和np.where()
满足条件(condition),输出x,不满足输出y。
只有条件 (condition),没有x和y,则输出满足条件 (即非0) 元素的坐标 (等价于numpy.nonzero)。这里的坐标以tuple的形式给出,通常原数组有多少维,输出的tuple中就包含几个数组,分别对应符合条件元素的各维坐标。
eg:

array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]])

>>> np.where(a > 5)
(array([0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2]),
 array([2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2, 0, 0, 0, 1, 1, 1, 2, 2, 2]),
 array([0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2]))

每个数都用三维表示,例如符合条件的6,7,8,坐标就分别是(0,2,0),(0,2,1),(0,2,2);可以竖着看,每一个坐标是一个点
矩阵定义
为什么这么简单的一直记不住!!np.array()!!!

猜你喜欢

转载自blog.csdn.net/weixin_38267508/article/details/86597167