前言
如果没有时间看下去,这里直接告诉你结论:
-
两者都是用来重塑tensor的shape的。
-
view只适合对满足连续性条件(contiguous)的tensor进行操作,并且该操作不会开辟新的内存空间,只是产生了对原存储空间的一个新别称和引用,返回值是视图。
-
reshape对适合对满足连续性条件(contiguous)的tensor进行操作返回值是视图,否则返回副本(此时等价于先调用contiguous()方法在使用view())
-
考虑内存的开销而且要确保重塑后的tensor与之前的tensor共享存储空间,那就使用view
-
view能干的reshape都能干 如果只是重塑一个tensor的shape 那就无脑选择reshape
pytorch Tensor 介绍
想要深入理解view与reshape的区别,首先要理解一些有关PyTorch张量存储的底层原理,比如tensor的头信息区(Tensor)和存储区 (Storage)以及tensor的步长Stride
Tensor 存储结构介绍
tensor数据采用头信息区(Tensor)和存储区 (Storage)分开存储的形式,如图1所示。变量名以及其存储的数据是分为两个区域分别存储的。比如,我们定义并初始化一个tensor,tensor名为A,A的形状size、步长stride、数据的索引等信息都存储在头信息区,而A所存储的真实数据则存储在存储区。另外,如果我们对A进行截取、转置或修改等操作后赋值给B,则B的数据共享A的存储区,存储区的数据数量没变,变化的只是B的头信息区对数据的索引方式。如果听说过浅拷贝和深拷贝的话,很容易明白这种方式其实就是浅拷贝。
代码示例如下:
import torch
a = torch.arange(5) # 初始化张量 a 为 [0, 1, 2, 3, 4]
b = a[2:] # 截取张量a的部分值并赋值给b,b其实只是改变了a对数据的索引方式
print('a:', a)
print('b:', b)
print('ptr of storage of a:', a.storage().data_ptr()) # 打印a的存储区地址
print('ptr of storage of b:', b.storage().data_ptr()) # 打印b的存储区地址,可以发现两者是共用存储区
print('==================================================================')
b[1] = 0 # 修改b中索引为1,即a中索引为3的数据为0
print('a:', a)
print('b:', b)
print('ptr of storage of a:', a.storage().data_ptr()) # 打印a的存储区地址,可以发现a的相应位置的值也跟着改变,说明两者是共用存储区
print('ptr of storage of b:', b.storage().data_ptr()) # 打印b的存储区地址
Tensor的步长(stride)属性
torch的tensor也是有步长属性的,说起stride属性是不是很耳熟?是的,卷积神经网络中卷积核对特征图的卷积操作也是有stride属性的,但这两个stride可完全不是一个意思哦。tensor的步长可以理解为从索引中的一个维度跨到下一个维度中间的跨度
我们看下如下例子:
import torch
a = torch.arange(6).reshape(2, 3) # 初始化张量 a
b = torch.arange(6).view(3, 2) # 初始化张量 b
print('a:', a)
print('stride of a:', a.stride()) # 打印a的stride
print('b:', b)
print('stride of b:', b.stride()) # 打印b的stride
Tensor View 理解
参考链接
大致意思是:
返回的张量共享相同的数据,并且必须具有相同数量的元素,但可能具有不同的大小。对于要查看的张量,新视图大小必须与其原始大小和步幅兼容,即每个新视图维度必须是原始维度的子空间,或者满足以下连续条件:
否则需要先使用contiguous()方法将原始tensor转换为满足连续条件的tensor,然后就可以使用view方法进行shape变换了。或者直接使用reshape方法进行维度变换,但这种方法变换后的tensor就不是与原始tensor共享内存了,而是被重新开辟了一个空间。
如何理解tensor是否满足连续条件呐?下面通过一系列例子来慢慢理解下
查看tensor的stride、size属性
如下例子:
我们可以看到结果是满足连续性的 stride[0] = 3 = 1X3
下面我们看看不满足连续性的例子:
import torch
a = torch.arange(9).reshape(3, 3) # 初始化张量a
b = a.permute(1, 0) # 对a进行转置
print('struct of b:\n', b)
print('size of b:', b.size()) # 查看b的shape
print('stride of b:', b.stride()) # 查看b的stride
''' 运行结果 '''
struct of b:
tensor([[0, 3, 6],
[1, 4, 7],
[2, 5, 8]])
size of b: torch.Size([3, 3])
stride of b: (1, 3) # 注:此时不满足连续性条件
输出a和b的存储区来看一下有没有什么不同:
import torch
a = torch.arange(9).reshape(3, 3) # 初始化张量a
print('ptr of storage of a: ', a.storage().data_ptr()) # 查看a的storage区的地址
print('storage of a: \n', a.storage()) # 查看a的storage区的数据存放形式
b = a.permute(1, 0) # 转置
print('ptr of storage of b: ', b.storage().data_ptr()) # 查看b的storage区的地址
print('storage of b: \n', b.storage()) # 查看b的storage区的数据存放形式
''' 运行结果 '''
ptr of storage of a: 1899603060672
storage of a:
0
1
2
3
4
5
6
7
8
[torch.LongStorage of size 9]
ptr of storage of b: 1899603060672
storage of b:
0
1
2
3
4
5
6
7
8
[torch.LongStorage of size 9]
由结果可以看出,张量a、b仍然共用存储区,并且存储区数据存放的顺序没有变化,这也充分说明了b与a共用存储区,b只是改变了数据的索引方式。那么为什么b就不符合连续性条件了呐(T-T)?其实原因很简单,我们结合图3来解释下:
Torch.reshape
作用:与view方法类似,将输入tensor转换为新的shape格式。
但是reshape方法更强大,可以认为a.reshape = a.view() + a.contiguous().view()。
即:在满足tensor连续性条件时,a.reshape返回的结果与a.view()相同,否则返回的结果与a.contiguous().view()相同
参考文章和链接
https://blog.csdn.net/Flag_ing/article/details/109129752
https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view
https://stackoverflow.com/questions/49643225/whats-the-difference-between-reshape-and-view-in-pytorch