pytorch:tensor的合并与分割

import torch

cat

  • dim后面接的数字表示从哪个维度上进行拼接,注意:其他维度的size必须要两两相等
a = torch.rand(30,6)#a记录了班上前30个同学的6门考试成绩
b = torch.rand(25,6)#b记录了班上后30个同学的6门考试成绩
torch.cat([a,b],dim=0).shape 
torch.Size([55, 6])
a = torch.rand(5,3) #5个同学,a只记录了4门课的成绩
b = torch.rand(5,2)#还是这5个同学,b记录了另外2门课的成绩
c = torch.rand(5,1)#还是这5个同学,c记录了另外1门课的成绩
torch.cat([a,b,c],dim=1).shape #合并成一张成绩单,包含每个人的6门成绩
torch.Size([5, 6])

stack

  • stack在一起的tensor的每个dim的size都必须要一样
  • dim后面接的数字表示要在该维度前面增加一个维度,该维度的size是stack起来的tensor数目
a = torch.rand(30,6)
b = torch.rand(30,6)
torch.stack([a,b],dim=0).shape
torch.Size([2, 30, 6])
a = torch.rand(2,3,4)
b = torch.rand(2,3,4)
c = torch.rand(2,3,4)
torch.stack([a,b,c],dim=1).shape
torch.Size([2, 3, 3, 4])

split

  • dim表示要从那个维度上切分
  • 只输入数字,则表示每组有多少个tensor
  • 如果输入的是一个list,那么就按照list里面的数字进行切分
a = torch.rand(6,20,5) #a表示有6个班级,每个班有20个人,每个人有5门考试成绩
a1,a2,a3=a.split(2,dim=0) #按照顺序,每两个班级分为一组
print(a1.shape,a2.shape,a3.shape)
torch.Size([2, 20, 5]) torch.Size([2, 20, 5]) torch.Size([2, 20, 5])
a = torch.rand(6,20,5) #a表示有6个班级,每个班有20个人,每个人有5门考试成绩
a1,a2,a3=a.split([1,2,3],dim=0) #按照顺序,第一个班级分为一组,第二、三个班级分成一组,第4、5、6个班级分成一组
print(a1.shape,a2.shape,a3.shape)
torch.Size([1, 20, 5]) torch.Size([2, 20, 5]) torch.Size([3, 20, 5])

chunk

  • dim表示要从那个维度上切分
  • 输入的数字表示在这个维度上要切分成多少组
a = torch.rand(6,20,5) #a表示有6个班级,每个班有20个人,每个人有5门考试成绩
a1,a2,a3=a.chunk(3,dim=0) #3表示分成3组tensor
print(a1.shape,a2.shape,a3.shape)
torch.Size([2, 20, 5]) torch.Size([2, 20, 5]) torch.Size([2, 20, 5])
发布了43 篇原创文章 · 获赞 1 · 访问量 759

猜你喜欢

转载自blog.csdn.net/weixin_41391619/article/details/104683632