寻找两个有序数组的中位数
描述
困难
给定两个大小为 m 和 n 的有序数组 nums1 和 nums2。
请你找出这两个有序数组的中位数,并且要求算法的时间复杂度为 O(log(m + n))。
你可以假设 nums1 和 nums2 不会同时为空。
示例
nums1 = [1, 3]
nums2 = [2]
则中位数是 2.0
nums1 = [1, 2]
nums2 = [3, 4]
则中位数是 (2 + 3)/2 = 2.5
解题
中位数是一串数据中间的数,如果这串数据长是奇数,那么就是中间的数,如果是偶数,中位数是中间两个数的平均数
设有数据
索引 | 0 | 1 | 2 | 3 | 4 | 5 |
---|---|---|---|---|---|---|
数组1 nums1 | 3 | 5 | 8 | 9 | ||
数组2 nums2 | 1 | 2 | 7 | 10 | 11 | 12 |
其中 m=4, n=6, size=4+6=10
该中位数是两个数组合并排好序后中间两个数的平均值
则在合并后数组的中位数左右的数据个数是确定,是所有元素个数的一半,即size/2=5个
假设在数组1中切一刀,将数组1分成两部分,如在3和5之间,则左边元素个数为1,右边个数为3
所以,在数组2中的一刀应切在10和11之间,使得数组2左边元素个数为4,右边个数为2
将数组1左边的元素记作left_1, 右边的记作right_1
将数组2左边的元素记作left_2, 右边的记作right_2
则left_1和left_2组成了中位数左边的部分
right_1和right_2组成了中位数右边的部分
即
left 1 2 3 7 10
right 5 8 9 11 12
显然有问题
假设这一刀的位置是正确的,那么
left_1 <= right_2
left_2 <= right_1
这样就能确保最终左边的元素小于右边的元素
所以,只需在数组1中找出这一刀的正确位置
使用二分查找,将数组1切割的位置记为cut1, 将数组2切割的位置记为cut2, cut2=(size/2)-cut1
cut1,cut2表示nums1,nums2左边元素的个数,cut1=1表示切在3和5之间,cut2=4表示切在10和11之间
切割情况:
- left_1 > right_2 cut1左移,使数组1中更多的数被分配到右边
- left_2 > right_1 cut1右移,使数组1中更多的数被分配到左边
- 其他 (left_1<=right_2, left_2<=right_1) 即cut1的位置正确,停止程序
注意
-
存在一些边界条件,cut的位置在边缘,即cut1=0或m,使得nums1全在合并后数组的右边或左边。将min和max分别加在nums1和nums2两端,就可以统一考虑,当cut在边缘时,直接输出即可
-
使用二分查找,不是简单的cut–或cut++,而是需要将cut1的区域分为记下来,用[cutL, cutR]表示,一开的范围为[cutL, cutR] = [0, m]
- 当left_1 > right_2时,cut1左移,cut1范围变成[cutL, cut1-1],下次cut1的位置为cut1=(cutR-cutL)/2+cutL
- 当left_2 > right_1时,cut1右移,cut1范围变成[cut1+1, cutR],下次cut1的位置为cut1=(cutR-cutL)/2+cutL
-
当元素个数为奇数时,中间的元素为min(right_1, right_2)直接输出
-
代码while中的left_1, left_2, right_1, right_2
如果cut1=1
索引 0 1 2 3 4 5 cut1 数组1 nums1 3 5 8 9 数组2 nums2 1 2 7 10 11 12 cut2 所以
left_1 = 3
left_2 = 10
right_1 = 5
right_2 = 11
此时left_2>right_1,cut1右移如果cut1在边界上,如
索引 0 1 2 3 4 5 cut1 数组1 nums1 10 11 12 13 数组2 nums2 1 2 3 4 5 12 cut2 所以
left_1 = min_value
left_2 = 5
right_1 = 10
right_2 = 12
可以直接输出结果,即(5+10)/2
借鉴leetcode大佬的代码
class Solution:
def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
if len(nums1)>len(nums2): # 确保nums1短
nums1, nums2 = nums2, nums1
m, n = len(nums1), len(nums2)
if m == 0:
return (nums2[n//2]+nums2[(n-1)//2]) / 2
size = m+n
cutL, cutR = 0, m
cut1 = m // 2
cut2 = size // 2 - cut1
min_value = -2**31
max_value = 2**31
while True:
cut1 = (cutR-cutL)//2 + cutL
cut2 = size // 2 - cut1
# 解决边界问题
left_1 = nums1[cut1-1] if cut1 !=0 else min_value
left_2 = nums2[cut2-1] if cut2 !=0 else min_value
right_1 = nums1[cut1] if cut1 != m else max_value
right_2 = nums2[cut2] if cut2 != n else max_value
if left_1 > right_2:
cutR = cut1-1
elif left_2 > right_1:
cutL = cut1+1
else:
if size%2 == 0:
left = left_1 if left_1 > left_2 else left_2
right = right_1 if right_1 < right_2 else right_2
return (left+right)/2
else:
return right_1 if right_1 < right_2 else right_2