Pytorch(2) 拼接,切分,索引,变换

                                                                                    张量的基本操作

                                                 pytorch

1.拼接

torch.cat()

  • 声明
torch.cat(tensors, dim=0, out=None) → Tensor
  • 功能:将张量按维度dim进行拼接

torch.stack()

  • 声明
torch.stack(tensors, dim=0, out=None) → Tensor
  • 功能:在新创建的维度dim进行拼接

注意:cat()不会扩展张量的维度,而stack()会拓展张量的维度。

2.切分:

torch.chunk()

  • 声明:
torch.chunk(input, chunks, dim=0) → List of Tensors
  • 功能:将张量按维度dim进行平均切分,返回张量列表。
  • 成员变量:
  1. input:要切分的张量
  2. chunks:要切分的份数
  3. dim:要切分的维度

注意:若不能整除,最后一项张量小于其他张量 

torch.split()

  • 声明:
torch.split(tensor, split_size_or_sections, dim=0)
  • 功能:将张量按维度dim进行切分,返回张量列表。
  • 成员变量:
  1. tensor:要切分的张量
  2. split_size_or_sections:为int时,表示每一份的长度;为list时,按list元素切分
  3. dim:要切分的维度

注意:如果使用list作为参数,则元素总和等于切分前的数量。

3.索引

torch.index_select()

  • 声明:
torch.index_select(input, dim, index, out=None) → Tensor
  • 功能:在维度dim上,按index索引数据,返回依index索引数据拼接的张量
  • 成员变量:
  1. input:要索引的张量
  2. dim:要索引的维度
  3. index:要索引数据的序号(LongTensor

torch.masked_select()

  • 声明:
torch.masked_select(input, mask, out=None) → Tensor
  • 功能:按mask中的true进行索引,返回一维张量
  • 成员变量:
  1. input:要索引的张量
  2. mask:与input同形状的布尔类型张量(ByteTensor

torch.ge(),gt(),le(),lt()

  • 声明:
torch.ge(input, other, out=None) → Tensor
torch.gt(input, other, out=None) → Tensor
torch.le(input, other, out=None) → Tensor
torch.lt(input, other, out=None) → Tensor
  •  功能:生成一个 input >= other,input > other,input <= other,input < other的bool Tensor.
  • 成员变量:
  1. input (Tensor) – the tensor to compare

  2. other (Tensor or python:float) – the tensor or value to compare

 4.变换:

torch.reshape()

  • 声明:
torch.reshape(input, shape) → Tensor
  • 功能:变换张量形状

注意:当张量在内存中是连续时,新张量与input共享数据内存

  • 成员变量:
  1. input:要变换的张量
  2. shape:新张量的形状(tuple of python)

注意:shape=(-1,2)中的-1表示我们不关心的维度,由系统决定。

torch.transpose()

  • 声明:
torch.transpose(input, dim0, dim1) → Tensor
  • 功能:交换张量的两个维度
  • 成员变量:
  1. input:要变换的张量
  2. dim0:要交换的维度
  3. dim1:要交换的维度

torch.t()

  • 声明:
torch.t(input) → Tensor
  • 功能:2维张量装置,对矩阵而言等价于torch.transpose(input,0,1)
  • 成员变量:
  1. input:要变换的张量

torch.squeeze()

  • 声明:
torch.squeeze(input, dim=None, out=None) → Tensor
  • 功能:压缩长度为1的维度(轴)
  • 成员变量:
  1. dim:若为None,移除所有长度1的轴;若指定维度,当且仅当该轴长度为1时,可以被移除;

torch.unsqueeze()

  • 声明:
unsqueeze_(dim) → Tensor
  • 功能:y依据dim扩展维度
  • 成员变量:
  1. dim:扩展的维度
发布了44 篇原创文章 · 获赞 26 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/He3he3he/article/details/102602256
今日推荐