import torch
import numpy as np
# Slicing and indexing
#
# `a` is filled with random integers, so the printed values differ on every
# run; the shapes are deterministic and are therefore checked with asserts.
a = torch.randint(1, 10, (4, 3, 5, 5))  # high is exclusive: values are 1..9

# An integer index removes the indexed dimension.
assert a[0].shape == torch.Size([3, 5, 5])
assert a[1, 2].shape == torch.Size([5, 5])

# `:` keeps a whole dimension: position 1 of dim 1 for every entry of dim 0.
print(a[:, 1])
assert a[:, 1].shape == torch.Size([4, 5, 5])

# Indexing every dimension with an integer yields a 0-dim (scalar) tensor.
assert a[0, 0, 0, 0].dim() == 0

# A slice keeps its dimension, so `:2` on dim 0 leaves an axis of length 2.
assert a[:2].shape == torch.Size([2, 3, 5, 5])

# Mixing slices and integers: slices keep their dims, while the integer on
# dim 2 drops it, giving (2, 3, 2).
print(a[:2, :, 1, :2])
assert a[:2, :, 1, :2].shape == torch.Size([2, 3, 2])
# index_select(dim, index) picks the positions listed in the 1-D integer tensor
# `index` along `dim`. Unlike a plain integer index, it keeps that dimension
# (with length len(index)), so selecting [1] on dim 0 equals the slice a[1:2].
print(a.index_select(0, torch.tensor([1])))
assert a.index_select(0, torch.tensor([1])).shape == torch.Size([1, 3, 5, 5])
assert torch.equal(a.index_select(0, torch.tensor([1])), a[1:2])

# Non-contiguous positions on dim 1: entries 0 and 2 of that axis.
assert a.index_select(1, torch.tensor([0, 2])).shape == torch.Size([4, 2, 5, 5])
# `...` (Ellipsis) expands to as many `:` as needed to cover the dimensions
# that are not indexed explicitly.
# a[0, ..., -1] is a[0, :, :, -1]: the last column of each trailing 5x5 slice
# of a[0] (assumes `a` has shape (4, 3, 5, 5), as created earlier).
print(a[0, ..., -1])
assert a[0, ..., -1].shape == torch.Size([3, 5])
assert torch.equal(a[0, ..., -1], a[0, :, :, -1])

# a[..., 2, 2] is a[:, :, 2, 2]: element (2, 2) of every trailing 5x5 slice.
print(a[..., 2, 2])
assert a[..., 2, 2].shape == torch.Size([4, 3])
assert torch.equal(a[..., 2, 2], a[:, :, 2, 2])
# Boolean-mask selection: `a.ge(5)` is True wherever a >= 5, and masked_select
# returns the matching elements as a 1-D tensor. Its length depends on the
# random contents of `a`, so only structural properties are asserted.
mask = a.ge(5)
selected = torch.masked_select(a, mask)
print(selected)
assert selected.dim() == 1
assert selected.numel() == int(mask.sum())
# Equivalent to boolean indexing with a same-shaped mask.
assert torch.equal(selected, a[mask])
# take() indexes the tensor as if it were flattened to 1-D in row-major order,
# so for [[1, 2, 3], [4, 5, 6]] the flat indices 0, 3 and 5 select 1, 4 and 6.
src = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(src)
picked = src.take(torch.tensor([0, 3, 5]))
print(picked)
assert picked.tolist() == [1, 4, 6]