先写一维的情况,然后改成多维的情况,这样比较好写且容易测试。numpy数组的运算符号与常规算术运算符号相统一,实在是太方便了!
def extract(A, shape, position, fill=0):
R = np.ones(shape) * fill
A_shape = np.array(A.shape)
R_shape = np.array(list(shape))
position = np.array(position)
A_start = position - R_shape // 2
A_stop = A_start + R_shape
R_start = np.maximum(0, -A_start).tolist()
R_stop = (R_shape - np.maximum(0, A_stop - A_shape)).tolist()
A_start = np.maximum(0, A_start).tolist()
A_stop = np.minimum(A_shape, A_stop).tolist()
R_slice = [slice(start, stop) for (start, stop) in zip(R_start, R_stop)]
A_slice = [slice(start, stop) for (start, stop) in zip(A_start, A_stop)]
R_slice = tuple(R_slice)
A_slice = tuple(A_slice)
R[R_slice] = A[A_slice]
return R
A = np.arange(100).reshape((10, 10))
R = extract(A, (5, 5), (1, 1))
print(A)
print(R)
[[ 0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24 25 26 27 28 29]
[30 31 32 33 34 35 36 37 38 39]
[40 41 42 43 44 45 46 47 48 49]
[50 51 52 53 54 55 56 57 58 59]
[60 61 62 63 64 65 66 67 68 69]
[70 71 72 73 74 75 76 77 78 79]
[80 81 82 83 84 85 86 87 88 89]
[90 91 92 93 94 95 96 97 98 99]]
[[ 0. 0. 0. 0. 0.]
[ 0. 0. 1. 2. 3.]
[ 0. 10. 11. 12. 13.]
[ 0. 20. 21. 22. 23.]
[ 0. 30. 31. 32. 33.]]