numpy.unravel_index()
函数的作用是获取一个/组int
类型的索引值在一个多维数组中的位置。
语法:
numpy.unravel_index(indices, dims)
举个简单的例子
求一个多维数组的最大元素的索引
A = np.random.randint(1,100,size=(2,3,5)) # 声明一个size=(3,3,3,2)的数组 print(A) array([[[98, 29, 32, 73, 90], [36, 52, 24, 2, 37], [66, 80, 23, 29, 98]], [[17, 32, 58, 99, 74], [53, 3, 20, 48, 28], [53, 7, 74, 34, 68]]]) ind_max = np.argmax(A) print(ind_max)
18 # 此时得到的索引是将A数组flattern(展成一维数组)后的索引,如何得到对应的原数组的索引呢? ind_max_src = np.unravel_index(ind_max, A.shape) print(ind_max_src) (1, 0, 3)
# 函数numpy.unravel_index(indices, dims)返回的索引值从0开始计数。 print(A[ind_max_src]) 99
如果np.unravel_index(indices, dims)的第一个参数是一个int型的数组
则返回该数组中每个元素(即flattern索引值)对应的原数组的索引。
idx = np.unravel_index((0,29),A.shape)
# 注意索引下标从0开始 print(idx) # (array([0, 1], dtype=int64), array([0, 2], dtype=int64), array([0, 4], dtype=int64)) print(idx[0]) print(idx[1]) print(idx[2]) [0 1] [0 2] [0 4] first_idx = (idx[0][0],idx[1][0],idx[2][0]) print(first_idx) print(A[first_idx]) (0, 0, 0) 98