2.7 花哨的索引
import numpy as np

# A generator with a fixed seed, so the printed sample below is reproducible.
rand = np.random.RandomState(seed=42)

# Ten random integers drawn from the half-open range [0, 100).
x = rand.randint(0, 100, size=10)
print(x)
[51 92 14 71 60 20 82 86 74 74]
# Accessing elements one at a time with plain integer indices:
[x[3], x[7], x[2]]
[71, 86, 14]
# Fancy indexing: pass a list of indices to fetch all of them in one operation.
ind = [3, 7, 2]
x[ind]
array([71, 86, 14])
利用花哨的索引时,结果的形状与索引数组的形状一致,而不是与被索引数组的形状一致。
# With a 2x2 index array, the result is 2x2 as well: it follows the shape of
# the index array, not the (1-D) shape of x.
ind = np.array([[3, 7], [4, 5]])
x[ind]
array([[71, 86], [60, 20]])
# A 3x4 matrix holding the values 0..11 laid out row by row.
X = np.arange(3 * 4).reshape(3, 4)
X
array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]])
对二维数组使用花哨的索引时,第一个索引数组对应行,第二个对应列;如果行索引与列索引的形状不同,会先按广播规则把它们广播成相同形状,再进行索引。
# Row and column index arrays are paired element-wise:
# this selects X[0, 2], X[1, 1] and X[2, 3].
row = np.array([0, 1, 2])
col = np.array([2, 1, 3])
X[row, col]
array([ 2, 5, 11])
X[row[:, np.newaxis], col] # row[:, np.newaxis] is 3x1 and col (shape (3,)) acts as 1x3; they broadcast to 3x3 before indexing
array([[ 2, 1, 3], [ 6, 5, 7], [10, 9, 11]])
# Display the two index arrays that were broadcast together above.
row[:, np.newaxis], col
(array([[0], [1], [2]]), array([2, 1, 3]))
X[2, [2, 0, 1]] # fancy index combined with a simple integer index (row 2)
array([10, 8, 9])
X[1:, [2, 0, 1]] # fancy index combined with a slice (rows 1 onward)
array([[ 6, 4, 5], [10, 8, 9]])
# Boolean mask over the 4 columns: True keeps columns 0 and 2.
mask = np.array([1, 0, 1, 0], dtype=bool)
X[row[:, np.newaxis], mask] # fancy index combined with a boolean mask
array([[ 0, 2], [ 4, 6], [ 8, 10]])
示例:选择随机点
花哨的索引的常见用途是从一个矩阵中选择行的子集。例如有一个 $N \times D$ 的矩阵,表示 $D$ 维空间中的 $N$ 个点。以下是一个由二维正态分布的点组成的数组:
# Zero-mean 2-D Gaussian whose two coordinates are positively correlated
# (off-diagonal covariance 2 > 0).
mean = [0, 0]
cov = [[1, 2],
       [2, 5]]
# Draw 100 points from the seeded generator; each row is one (x, y) point.
X = rand.multivariate_normal(mean=mean, cov=cov, size=100)
X.shape
(100, 2)
该数组为100行2列的二维数组,画出散点:
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set()
plt.scatter(X[:, 0], X[:, 1]);
X # contents of the 2-D array (100 rows x 2 columns)
array([[-0.644508 , -0.46220608], [ 0.7376352 , 1.21236921], [ 0.88151763, 1.12795177], [ 2.04998983, 5.97778598], [-0.1711348 , -2.06258746], [ 0.67956979, 0.83705124], [ 1.46860232, 1.22961093], [ 0.35282131, 1.49875397], [-2.51552505, -5.64629995], [ 0.0843329 , -0.3543059 ], [ 0.19199272, 1.48901291], [-0.02566217, -0.74987887], [ 1.00569227, 2.25287315], [ 0.49514263, 1.18939673], [ 0.0629872 , 0.57349278], [ 0.75093031, 2.99487004], [-3.0236127 , -6.00766046], [-0.53943081, -0.3478899 ], [ 1.53817376, 1.99973464], [-0.50886808, -1.81099656], [ 1.58115602, 2.86410319], [ 0.99305043, 2.54294059], [-0.87753796, -1.15767204], [-1.11518048, -1.87508012], [ 0.4299908 , 0.36324254], [ 0.97253528, 3.53815717], [ 0.32124996, 0.33137032], [-0.74618649, -2.77366681], [-0.88473953, -1.81495444], [ 0.98783862, 2.30280401], [-1.2033623 , -2.04402725], [-1.51101746, -3.2818741 ], [-2.76337717, -7.66760648], [ 0.39158553, 0.87949228], [ 0.91181024, 3.32968944], [-0.84202629, -2.01226547], [ 1.06586877, 0.95500019], [ 0.44457363, 1.87828298], [ 0.35936721, 0.40554974], [-0.90649669, -0.93486441], [-0.35790389, -0.52363012], [-1.33461668, -3.03203218], [ 0.02815138, 0.79654924], [ 0.37785618, 0.51409383], [-1.06505097, -2.88726779], [ 2.32083881, 5.97698647], [ 0.47605744, 0.83634485], [-0.35490984, -1.03657119], [ 0.57532883, -0.79997124], [ 0.33399913, 2.32597923], [ 0.6575612 , -0.22389518], [ 1.3707365 , 2.2348831 ], [ 0.07099548, -0.29685467], [ 0.6074983 , 1.47089233], [-0.34226126, -1.10666237], [ 0.69226246, 1.21504303], [-0.31112937, -0.75912097], [-0.26888327, -1.89366817], [ 0.42044896, 1.85189522], [ 0.21115245, 2.00781492], [-1.83106042, -2.91352836], [ 0.7841796 , 1.97640753], [ 0.10259314, 1.24690575], [-1.91100558, -3.66800923], [ 0.13143756, -0.07833855], [-0.1317045 , -1.64159158], [-0.14547282, -1.34125678], [-0.51172373, -1.40960773], [ 0.69758045, 0.72563649], [ 0.11677083, 0.88385162], [-1.16586444, -2.24482237], [-2.23176235, -2.63958101], [ 0.37857234, 
0.69112594], [ 0.87475323, 3.400675 ], [-0.86864365, -3.03568353], [-1.03637857, -1.18469125], [-0.53334959, -0.37039911], [ 0.30414557, -0.5828419 ], [-1.47656656, -2.13046298], [-0.31332021, -1.7895623 ], [ 1.12659538, 1.49627535], [-1.19675798, -1.51633442], [-0.75210154, -0.79770535], [ 0.74577693, 1.95834451], [ 1.56094354, 2.9330816 ], [-0.72009966, -1.99780959], [-1.32319163, -2.61218347], [-2.56215914, -6.08410838], [ 1.31256297, 3.13143269], [ 0.51575983, 2.30284639], [ 0.01374713, -0.11539344], [-0.16863279, 0.39422355], [ 0.12065651, 1.13236323], [-0.83504984, -2.38632016], [ 1.05185885, 1.98418223], [-0.69144553, -1.56919875], [-1.2567603 , -1.125898 ], [ 0.09619333, -0.64335574], [-0.99658689, -2.35038099], [-1.21405259, -1.77693724]])
X[0] # row 0: the first point, as its pair of coordinates
array([-0.644508 , -0.46220608])
X[0, 0] # first coordinate (horizontal axis) of the first point
-0.6445079962363565
X[:, 0] # first coordinate of every point, as a 1-D array
array([-0.644508 , 0.7376352 , 0.88151763, 2.04998983, -0.1711348 , 0.67956979, 1.46860232, 0.35282131, -2.51552505, 0.0843329 , 0.19199272, -0.02566217, 1.00569227, 0.49514263, 0.0629872 , 0.75093031, -3.0236127 , -0.53943081, 1.53817376, -0.50886808, 1.58115602, 0.99305043, -0.87753796, -1.11518048, 0.4299908 , 0.97253528, 0.32124996, -0.74618649, -0.88473953, 0.98783862, -1.2033623 , -1.51101746, -2.76337717, 0.39158553, 0.91181024, -0.84202629, 1.06586877, 0.44457363, 0.35936721, -0.90649669, -0.35790389, -1.33461668, 0.02815138, 0.37785618, -1.06505097, 2.32083881, 0.47605744, -0.35490984, 0.57532883, 0.33399913, 0.6575612 , 1.3707365 , 0.07099548, 0.6074983 , -0.34226126, 0.69226246, -0.31112937, -0.26888327, 0.42044896, 0.21115245, -1.83106042, 0.7841796 , 0.10259314, -1.91100558, 0.13143756, -0.1317045 , -0.14547282, -0.51172373, 0.69758045, 0.11677083, -1.16586444, -2.23176235, 0.37857234, 0.87475323, -0.86864365, -1.03637857, -0.53334959, 0.30414557, -1.47656656, -0.31332021, 1.12659538, -1.19675798, -0.75210154, 0.74577693, 1.56094354, -0.72009966, -1.32319163, -2.56215914, 1.31256297, 0.51575983, 0.01374713, -0.16863279, 0.12065651, -0.83504984, 1.05185885, -0.69144553, -1.2567603 , 0.09619333, -0.99658689, -1.21405259])
先随机选出 20 个不重复的索引值,再用花哨的索引从原始数组中选出这些索引对应的行:
# Pick 20 distinct row indices at random (replace=False: no index repeats).
# Draw from the seeded `rand` generator instead of the global np.random so this
# selection is reproducible like the rest of the notebook. The index array
# recorded below came from an unseeded run and will differ on re-execution.
indices = rand.choice(X.shape[0], 20, replace=False)
indices
array([94, 76, 22, 0, 77, 36, 32, 58, 54, 70, 50, 92, 44, 38, 65, 46, 79, 68, 67, 71])
# Fancy indexing on the first axis pulls out the chosen rows, keeping both columns.
selection = X[indices, :]
selection.shape
(20, 2)
# Background: every point, drawn semi-transparent. X.T unpacks into its x and y columns.
plt.scatter(*X.T, alpha=0.3)
# Foreground: circle each selected point with a large hollow blue marker.
plt.scatter(*selection.T, s=200, edgecolor='b', facecolor='none');
用花哨的索引修改值
# Assigning through a fancy index writes the value to every selected position at once.
x = np.arange(0, 10)
i = np.asarray([2, 1, 8, 4])
x[i] = 99
x
array([ 0, 99, 99, 3, 99, 5, 6, 7, 99, 9])
x[i] -= 10 # augmented assignment through the fancy index: each selected element is reduced by 10
x
array([ 0, 89, 89, 3, 89, 5, 6, 7, 89, 9])
x[[0, 0]] # the index list repeats 0, so element 0 is read twice
array([0, 0])
# NOTE(review): with a repeated index, both assigned values target x[0]. The output
# below shows the last one (6) kept, but NumPy does not in general guarantee which
# write wins for repeated indices; don't rely on it.
x[[0, 0]] = [4, 6] # repeated index: 4 is written to x[0], then overwritten by 6
x
array([ 6, 89, 89, 3, 89, 5, 6, 7, 89, 9])