在Python中轴是比较难懂概念,先从坐标轴说起。
n 维空间里有 n 个坐标轴,并且坐标轴互相垂直,每一个点相对于一条坐标轴都有唯一的一个坐标值。对同一条坐标轴来说,坐标值相同的点在同一个 n-1 维的“平面”上。任意取一个“平面”,我们就能定义“同一个坐标轴上的点”,这些点在“平面”上的投影相同,同一个坐标轴上的点组成的线是与坐标轴平行的。而所谓的延轴计算实际上是降维的过程,同一个坐标轴上的点合并成一个点,这样n维空间就变成了 n-1 维空间。
具体到 numpy 中的多维数组来说,轴即是元素坐标的索引。比如,第0轴即是第1个索引,延0轴计算就是去掉坐标中的第一个索引。过程就是
- 遍历其他索引的所有可能组合
- 取出一个组合,保持值不变,遍历第一个索引所有可能值
- 根据索引可以获得了同一个轴上的所有元素
- 对他们进行计算得到最后的元素
- 所有组合的最后结果组到一起就是最后的 n-1 维数组
沿轴计算过程,可以当做沿哪一个方向进行投影再进行计算。所以如果一个多维数组的 shape 是 (a1, a2, a3, a4), 那么延轴0计算最后的数组shape 是 (a2, a3, a4), 延轴1计算最后的数组shape是 (a1, a3, a4)
1,
1],[
2,
1],[
3,
1]],[[
4,
1],[
5,
1],[
6,
1]],[[
7,
1],[
8,
1],[
9,
1]]])
a = array([[[
(
3,
3,
2)
array([[[
1,
1],
[
2,
1],
[
3,
1]],
[[
4,
1],
[
5,
1],
[
6,
1]],
[[
7,
1],
[
8,
1],
[
9,
1]]])
0)
sum(a, axis=
array([[
12,
3],
[
15,
3],
[
18,
3]])
1)
sum(a, axis=
array([[
6,
3],
[
15,
3],
[
24,
3]])
2)
sum(a, axis=
array([[
2,
3,
4],
[
5,
6,
7],
[
8,
9,
10]])