张量的拼接
torch.cat()
torch.cat(tensors,
dim=0,
out=None)
功能: 将张量按维度dim进行拼接
- tensors: 张量序列
- dim : 要拼接的维度
t = torch.ones((2, 3))
q = torch.zeros((2, 3))
t0 = torch.cat([t, q], 0)
t1 = torch.cat((t, q), dim=1)
print(t0, t0.shape)
print(t1, t1.shape)
tensor([[1., 1., 1.],
[1., 1., 1.],
[0., 0., 0.],
[0., 0., 0.]]) torch.Size([4, 3])
tensor([[1., 1., 1., 0., 0., 0.],
[1., 1., 1., 0., 0., 0.]]) torch.Size([2, 6])
torch.stack()
torch.stack(tensors,
dim=0,
out=None)
功能: 在新创建的维度dim上进行拼接
- tensors:张量序列
- dim :要拼接的维度
t = torch.ones((3, 4))
q = torch.zeros((3, 4))
t0 = torch.stack([t, q], dim=0)
t1 = torch.stack([t, q], dim=1)
t2 = torch.stack([t, q], dim=2)
print(t0, t0.shape)
print(t1, t1.shape)
print(t2, t2.shape)
tensor([[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]],
[[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]]) torch.Size([2, 3, 4])
tensor([[[1., 1., 1., 1.],
[0., 0., 0., 0.]],
[[1., 1., 1., 1.],
[0., 0., 0., 0.]],
[[1., 1., 1., 1.],
[0., 0., 0., 0.]]]) torch.Size([3, 2, 4])
tensor([[[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.]],
[[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.]],
[[1., 0.],
[1., 0.],
[1., 0.],
[1., 0.]]]) torch.Size([3, 4, 2])
张量的切分
torch.chunk()
torch.chunk(input,
chunks,
dim=0)
功能: 将张量按维度dim进行平均切分
返回值: 张量列表
注意事项: 若不能整除,最后一份张量小于其他张量
- input: 要切分的张量
- chunks : 要切分的份数
- dim : 要切分的维度
t = torch.ones((2, 7))
list_t = torch.chunk(t, chunks=3, dim=1)
for i, ten in enumerate(list_t):
print("第{}个张量:\n{}".format(i+1, ten), ten.shape)
第1个张量:
tensor([[1., 1., 1.],
[1., 1., 1.]]) torch.Size([2, 3])
第2个张量:
tensor([[1., 1., 1.],
[1., 1., 1.]]) torch.Size([2, 3])
第3个张量:
tensor([[1.],
[1.]]) torch.Size([2, 1])
torch.split()
torch.split(tensor,
split_size_or_sections,
dim=0)
功能: 将张量按维度dim进行切分
返回值: 张量列表
- tensor: 要切分的张量
- split_size_or_sections : 为int时,表示每一份的长度;为list时,按list元素切分
- dim : 要切分的维度
t = torch.ones((2, 7))
list_t = torch.split(t, 3, dim=1)
for i, ten in enumerate(list_t):
print("第{}个张量:\n{}".format(i+1, ten), ten.shape)
print("\n")
list_t = torch.split(t, [3, 4], dim=1)
for i, ten in enumerate(list_t):
print("第{}个张量:\n{}".format(i+1, ten), ten.shape)
第1个张量:
tensor([[1., 1., 1.],
[1., 1., 1.]]) torch.Size([2, 3])
第2个张量:
tensor([[1., 1., 1.],
[1., 1., 1.]]) torch.Size([2, 3])
第3个张量:
tensor([[1.],
[1.]]) torch.Size([2, 1])
第1个张量:
tensor([[1., 1., 1.],
[1., 1., 1.]]) torch.Size([2, 3])
第2个张量:
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.]]) torch.Size([2, 4])
张量的比较
torch.ge(),torch.gt(),torch.le(),torch.lt()
torch.ge(input,
other,
out=None)
功能: input中逐元素与other进行比较,满足:ge >=; gt >; le <=; lt <时,返回True
返回值: 与input同形状的布尔类型张量
- input:被比较的张量
- other:可以是张量,数值,布尔,input中逐元素与其进行比较
t = torch.randint(0, 10, [3, 3])
m = t.ge(5)
print(t)
print(m)
tensor([[1, 6, 5],
[6, 5, 4],
[0, 4, 4]])
tensor([[False, True, True],
[ True, True, False],
[False, False, False]])
张量的索引
torch.index_select()
torch.index_select(input,
dim,
index,
out=None)
功能: 在维度dim上,按index索引数据
返回值: 索引得到的数据拼接的张量
- input: 要索引的张量
- dim: 要索引的维度
- index : 要索引数据的序号组成的张量,dtype须为torch.long
t = torch.randint(0, 10, [3, 3])
idx = torch.tensor([0, 2], dtype=torch.long)
sel = torch.index_select(t, 0, idx)
print(t)
print(sel)
tensor([[0, 7, 0],
[8, 3, 1],
[2, 7, 9]])
tensor([[0, 7, 0],
[2, 7, 9]])
torch.masked_select()
torch.masked_select(input,
mask,
out=None)
功能: 按mask中的True进行索引
返回值: 一维张量
- input: 要索引的张量
- mask: 与input同形状的布尔类型张量
t = torch.randint(0, 10, [3, 3])
mask = t.ge(5)
sel = torch.masked_select(t, mask)
print(t)
print(mask)
print(sel)
tensor([[1, 6, 5],
[6, 5, 4],
[0, 4, 4]])
tensor([[False, True, True],
[ True, True, False],
[False, False, False]])
tensor([6, 5, 6, 5])
张量的变换
torch.reshape()
torch.reshape(input,
shape)
功能: 变换张量形状
注意事项: 当张量在内存中是连续的时,新张量与input共享数据内存。这种共享与out不同,out是整个tensor都共享内存,相当于别名;reshape是仅data共享内存。改变一个张量的数据,另一个张量会跟着改变
- input: 要变换的张量
- shape: 新张量的形状
t = torch.randperm(8)
re1 = torch.reshape(t, (2, 4))
re2 = torch.reshape(t, (-1, 4))
print(t)
print(re1)
print(re2)
t[0] = 100
re2[1, 1] = 100
print(id(t.data), id(re1.data), id(re2.data))
print(re1)
tensor([0, 7, 2, 6, 3, 5, 4, 1])
tensor([[0, 7, 2, 6],
[3, 5, 4, 1]])
tensor([[0, 7, 2, 6],
[3, 5, 4, 1]])
3039469824264 3039469824264 3039469824264
tensor([[100, 7, 2, 6],
[ 3, 100, 4, 1]])
torch.transpose()
torch.transpose(input,
dim0,
dim1)
功能: 交换张量的两个维度。在图像的预处理中常用,有时读取的图像数据是(c, h, w),但是我们常用的是(h, w, c),就需要用此方法把channel和width变换,再把width和height变换
- input: 要变换的张量
- dim0: 要交换的维度
- dim1: 要交换的维度
t = torch.rand((2, 3, 4))
tr = torch.transpose(t, 1, 0)
print(t, t.shape)
print(tr, tr.shape)
tensor([[[0.4973, 0.0644, 0.7269, 0.9305],
[0.4711, 0.1117, 0.1751, 0.4904],
[0.9865, 0.7374, 0.9201, 0.5733]],
[[0.4911, 0.4571, 0.9985, 0.7298],
[0.5078, 0.0928, 0.1655, 0.8740],
[0.8735, 0.7616, 0.0533, 0.4300]]]) torch.Size([2, 3, 4])
tensor([[[0.4973, 0.0644, 0.7269, 0.9305],
[0.4911, 0.4571, 0.9985, 0.7298]],
[[0.4711, 0.1117, 0.1751, 0.4904],
[0.5078, 0.0928, 0.1655, 0.8740]],
[[0.9865, 0.7374, 0.9201, 0.5733],
[0.8735, 0.7616, 0.0533, 0.4300]]]) torch.Size([3, 2, 4])
torch.t()
torch.t(input)
功能: 2维张量转置,对矩阵而言,等价于torch.transpose(input, 0, 1)
torch.squeeze()
torch.squeeze(input,
dim=None,
out=None)
功能: 压缩长度为1的维度(轴)
- dim: 若为None,移除所有长度为1的轴; 若指定维度,当且仅当该轴长度为1时,可以被移除
t = torch.rand((1, 2, 3, 1))
sq = torch.squeeze(t)
sq0 = torch.squeeze(t, 0)
sq1 = torch.squeeze(t, 1)
print(t.shape)
print(sq.shape)
print(sq0.shape)
print(sq1.shape)
torch.Size([1, 2, 3, 1])
torch.Size([2, 3])
torch.Size([2, 3, 1])
torch.Size([1, 2, 3, 1])
torch.unsqueeze()
torch.usqueeze(input,
dim,
out=None)
功能:依据dim扩展维度
- dim: 扩展的维度
t = torch.rand((2, 3))
sq = torch.unsqueeze(t, 0)
print(t.shape)
print(sq.shape)
torch.Size([2, 3])
torch.Size([1, 2, 3])