Numpy 之 axis

转自:Python、numpy 与 axis

这次和大家分享的是 numpy 中的 axis 这个东西。当初学的时候也没太在意,向来都是感觉差不多就直接过去了,没有去深究背后的一些逻辑。前些天被问起的时候一时懵懂,查了下资料后发现还有点意思,于是就打算写这么一篇专栏来分享一下所得。

要想学习 axis,首先要知道的就是 axis 的计数方式。我们在使用 numpy 的各种函数——比如说 np.sum——的时候,有一个参数就叫做 axis。那么这个参数的意思是什么呢?最直白地来说的话,就是“最外面的括号代表着 axis=0,依次往里的括号对应的 axis 的计数就依次加 1”。

举个例子,现在我们有一个矩阵:x= \begin{bmatrix} 0 & 1\\ 2 & 3 \end{bmatrix} ;在 Python,或说在 numpy 里面,这个矩阵是这样被表达出来的:x = [ [0, 1], [2, 3] ],然后 axis 的对应方式就是:

不管画风怎么变,很丑这一点都无法改变啊……

所以相应的运算就是:

对应的代码实现和运行结果如下:

可以看到,貌似出来的结果比我们推导的结果的括号要少一些。这是因为诸如 np.sum 这种函数中有一个参数叫 keepdims,它的默认值是 False,此时它会把多余的括号给删掉。假如我们把它设为 True 的话,就可以得到和推导中一致的结果了:

下面来看一个更“高维”一点的例子:

对应的代码实现和运行结果如下:

以及

可以看到结果和我们推导的确实一样。

现在我们知道哪个 axis 对应于数组中的哪些元素了,接下来还需要知道的就是 transpose 这个函数到底在背后干了什么。从纸面上来看,如果一个高维数组 x 的 shape 是 (2, 3, 4),那么 transpose 的作用就是把这个 shape 中各个数的顺序改一改。比如说:

d2334c4b05886e5c76840700009efb07278f039f.jpeg

但是 transpose 返回的结果究竟是如何得到的,可能就比较难理解了。幸运的是,这个回答 2非常好地阐明了这背后的原理。为了方便观众老爷们,我在这里就当一个搬运工 and 润色工。

首先是对这个 shape 的理解。直观地说,shape 中的各个数就是对应 axis 的元素个数。比如说上图中的 x,它画出来会是这个样子的:

字比画还丑呢……

如果我们换一种思路的话,以 axis=0 为例,由于我们现在整个数组里面一共有 24 个数,而 axis=0 只有两个元素,所以可以理解为在 axis=0 这个 axis 上,每隔 24 / 2 = 12 个数就跳一下。比如说上面这个图中就可以看出,两个橙色矩阵对应的数之间差的都是 12。

类似的,由于一个橙色矩阵中只有 24 / 2 = 12 个数,所以我们可以理解为在 axis=1 这个 axis 上,每隔 12 / 3 = 4 个数就跳一下。表现在图中,就是同一个橙色矩阵的两个相邻的蓝色向量对应的数之间差的都是 4。

再次类似的,由于一个蓝色向量中只有 12 / 3 = 4 个数,我们可以理解为在 axis=2 这个 axis 上,每隔 4 / 4 = 1 个数就跳一下。表现在图中……观众老爷们想必也知道是怎样的了。

所以我们现在可以定义一个新的东西,比如说叫做 strides 吧,它记录着每个 axis 上跳过的数。比如说上图对应的三维数组,它的 strides 就是 (12, 4, 1)。

那么接下来激动人心的时刻到了:transpose 的本质,其实就是对 strides 中各个数的顺序进行调换。举个例子:

在 transpose(1, 0, 2) 后,相应的 strides 会变成 (4, 12, 1)。而从上图可以看出,transpose 的结果确实满足:

  • axis=0 的 axis 上,每隔 4 个数跳一下
  • axis=1 的 axis 上,每隔 12 个数跳一下
  • axis=2 的 axis 上,每隔 1 个数跳一下

至此,transpose 背后的逻辑就理顺啦!撒花!

发布了39 篇原创文章 · 获赞 12 · 访问量 16万+

猜你喜欢

转载自blog.csdn.net/wdh315172/article/details/105412493