【堆Heap】二叉堆 BinaryHeap 简介及源码实现、最小堆解决 TOPK 问题

在此致谢小码哥的恋上数据结构,堪称经典中的经典。
在此致谢小码哥的恋上数据结构,堪称经典中的经典。
在此致谢小码哥的恋上数据结构,堪称经典中的经典。

堆(Heap)

堆的出现

设计一种数据结构,用来存放整数,要求提供3个接口

  • 添加元素
  • 获取最大值
  • 删除最大值

用已学过的数据结构来对比一下时间复杂度:
在这里插入图片描述
为了迎合此需求出现了一种新的数据结构 ——

  • 获取最大值: O ( l o g n ) O(logn)
  • 删除最大值: O ( l o g n ) O(logn)
  • 添加元素: O ( l o g n ) O(logn)

堆简介

(Heap)是一种树状的数据结构,常见的堆实现有

  • 二叉堆(Binary Heap,完全二叉堆
  • 多叉堆(D-heap、D-ary Heap)
  • 索引堆(Index Heap)
  • 二项堆(Binomial Heap)
  • 斐波那契堆(Fibonacci Heap)
  • 左倾堆(Leftist Heap,左式堆
  • 斜堆(Skew Heap)

堆的重要性质:任意节点的值总是 \geq \leqslant ) 子节点的值

  • 如果任意节点的值总是 \geq 子节点的值,称为:最大堆、大根堆、大顶堆
  • 如果任意节点的值总是 \leqslant 子节点的值,称为:最小堆、小根堆、小顶堆

堆中的元素必须具备可比较性(跟二叉搜索树一样)
在这里插入图片描述

二叉堆(Binary Heap)

二叉堆的逻辑结构就是一棵完全二叉树,所以也叫完全二叉堆

鉴于完全二叉树的一些特性,二叉堆的底层(物理结构)一般用数组实现即可

在这里插入图片描述
索引 i 的规律(n 是元素数量)

  • 如果 i = 0,它是节点
  • 如果 i > 0,它的节点的索引为 floor( (i - 1) / 2 )
  • 如果 2 i + 1 n 1 2i + 1 \leqslant n - 1 ,它的子节点的索引为 2 i + 1 2i + 1
  • 如果 2 i + 1 > n 1 2i + 1 >n - 1 ,它无左子节点
  • 如果 2 i + 2 n 1 2i + 2 \leqslant n - 1 ,它的子节点的索引为 2 i + 1 2i + 1
  • 如果 2 i + 2 > n 1 2i + 2 > n - 1 ,它无右子节点

获取最大值

主要就是检测数组是否不能为空,为空则抛出异常。

public E get() {
	emptyCheck();
	return elements[0];
}
private void emptyCheck() {
	if (size == 0) {
		throw new IndexOutOfBoundsException("Heap is empty");
	}
}

最大堆 — 添加

思路是:

  • 将添加的节点放入数组最后

  • 循环执行上滤操作(Sift Up),即:

    • 如果 node > 父节点
      与父节点交换位置
    • 如果node <= 父节点,或者 node 没有父节点
      退出循环
  • 时间复杂度: O ( l o g n ) O(logn)

在这里插入图片描述在这里插入图片描述

/**
 * 最大堆 — 添加
 */
public void add(E element) {
	// 检测传入的元素非空
	elementNotNullCheck(element);
	// 扩容操作,确保容量大于当前元素个数大小
	ensureCapacity(size + 1);
	// 将要添加的元素放到数组最后
	elements[size++] = element;
	// 上滤操作
	siftUp(size - 1);
}
/**
 * 让index位置的元素上滤
 */
private void siftUp(int index) {
	E e = elements[index]; // 要添加的节点
	while (index > 0) { // 直到比较到根结点
		int pindex = (index - 1) >> 1; // 父节点索引
		E p = elements[pindex]; // 添加节点的父节点
		if (compare(e, p) <= 0) return;
		
		// 交换index、pindex位置的内容
		E tmp = elements[index];
		elements[index] = elements[pindex];
		elements[pindex] = tmp;
		
		// 重新赋值index
		index = pindex;
	}
}

最大堆 — 添加优化

和上面的区别在于:

  • 并不是每次节点的值 > 父节点的值就直接交换
  • 而是将新添加的节点备份,确定最终位置才放上去

在这里插入图片描述

/**
 * 最大堆 — 添加
 */
public void add(E element) {
	// 检测传入的元素非空
	elementNotNullCheck(element);
	// 扩容操作,确保容量大于当前元素个数大小
	ensureCapacity(size + 1);
	// 将要添加的元素放到数组最后
	elements[size++] = element;
	// 上滤操作
	siftUp(size - 1);
}
/**
 * 让index位置的元素上滤
 */
private void siftUp(int index) {

	E element = elements[index];
	while (index > 0) {
		// 父节点索引 = (子节点索引-1) / 2
		int parentIndex = (index - 1) >> 1;
		E parent = elements[parentIndex];
		if (compare(element, parent) <= 0) break;
		
		// 将父元素存储在index位置
		elements[index] = parent;
		
		// 重新赋值index
		index = parentIndex;
	}
	elements[index] = element;
}

最大堆 — 删除

思路是:

  • 用最后一个节点覆盖根结点
  • 删除最后一个节点
  • 循环执行下滤操作(Sift Down),即:
    • 如果 node < 最大的子节点
      与最大的子节点交换位置
    • 如果 node >= 最大的子节点,或者 node 没有子节点
      退出循环
  • 时间复杂度: O ( l o g n ) O(logn)
  • 交换位置的操作可以像添加那样进行优化

在这里插入图片描述
在这里插入图片描述

/**
 * 最大堆—删除堆顶元素
 */
public E remove() {
	// 检测数组不能为空
	emptyCheck(); 
	// 获取数组最后元素的索引
	int lastIndex = --size;
	// 获取要删除元素的节点
	E root = elements[0];
	elements[0] = elements[lastIndex];
	elements[lastIndex] = null;
	
	siftDown(0); // 下滤操作
	return root;
}
/**
 * 让index位置的元素下滤
 */
private void siftDown(int index) {
	E element = elements[index];
	int half = size >> 1; // 非叶子节点的数量
	// 第一个叶子节点的索引 == 非叶子节点的数量
	// index < 第一个叶子节点的索引
	// 必须保证index位置是非叶子节点
	while (index < half) { 
		// index的节点有2种情况
		// 1.只有左子节点
		// 2.同时有左右子节点
		
		// 默认为左子节点跟它进行比较
		int childIndex = (index << 1) + 1;
		E child = elements[childIndex];
		
		// 右子节点
		int rightIndex = childIndex + 1;
		
		// 选出左右子节点最大的那个
		if (rightIndex < size && compare(elements[rightIndex], child) > 0) {
			child = elements[childIndex = rightIndex];
		}
		
		if (compare(element, child) >= 0) break;

		// 将子节点存放到index位置
		elements[index] = child;
		// 重新设置index
		index = childIndex;
	}
	elements[index] = element;
}

replace

public E replace(E element) {
	elementNotNullCheck(element);
	
	E root = null;
	if (size == 0) {
		elements[0] = element;
		size++;
	} else {
		root = elements[0];
		elements[0] = element;
		siftDown(0);
	}
	return root;
}

最大堆 — 批量建堆(Heapify)

  • 自上而下的上滤
  • 自下而上的下滤

自上而下的上滤

在这里插入图片描述

自下而上的下滤

在这里插入图片描述

效率对比

  • 自上而下的上滤时间复杂度: O ( n l o g n ) O(nlogn)
  • 自下而上的下滤时间复杂度: O ( n l o g k ) O(nlogk)
    在这里插入图片描述

二叉堆源码

堆的基本接口 Heap.java

public interface Heap<E> {
	int size();	// 元素的数量
	boolean isEmpty();	// 是否为空
	void clear();	// 清空
	void add(E element);	 // 添加元素
	E get();	// 获得堆顶元素
	E remove(); // 删除堆顶元素
	E replace(E element); // 删除堆顶元素的同时插入一个新元素
}

抽象类 AbstractHeap.java

@SuppressWarnings("unchecked")
public abstract class AbstractHeap<E> implements Heap<E> {
	protected int size;
	protected Comparator<E> comparator;
	
	public AbstractHeap(Comparator<E> comparator) {
		this.comparator = comparator;
	}
	
	public AbstractHeap() {
		this(null);
	}
	
	@Override
	public int size() {
		return size;
	}

	@Override
	public boolean isEmpty() {
		return size == 0;
	}
	
	protected int compare(E e1, E e2) {
		return comparator != null ? comparator.compare(e1, e2) 
				: ((Comparable<E>)e1).compareTo(e2);
	}
}

二叉堆 BinaryHeap.java

/**
 * 二叉堆(最大堆)
 */
@SuppressWarnings("unchecked")
public class BinaryHeap<E> extends AbstractHeap<E> {
	private E[] elements;
	private static final int DEFAULT_CAPACITY = 10;
	
	public BinaryHeap(E[] elements, Comparator<E> comparator)  {
		super(comparator);
		
		if (elements == null || elements.length == 0) {
			this.elements = (E[]) new Object[DEFAULT_CAPACITY];
		} else {
			size = elements.length;
			int capacity = Math.max(elements.length, DEFAULT_CAPACITY);
			this.elements = (E[]) new Object[capacity];
			for (int i = 0; i < elements.length; i++) {
				this.elements[i] = elements[i];
			}
			heapify();
		}
	}
	
	public BinaryHeap(E[] elements)  {
		this(elements, null);
	}
	
	public BinaryHeap(Comparator<E> comparator) {
		this(null, comparator);
	}
	
	public BinaryHeap() {
		this(null, null);
	}

	@Override
	public void clear() {
		for (int i = 0; i < size; i++) {
			elements[i] = null;
		}
		size = 0;
	}

	@Override
	public void add(E element) {
		elementNotNullCheck(element);
		ensureCapacity(size + 1);
		elements[size++] = element;
		siftUp(size - 1);
	}

	@Override
	public E get() {
		emptyCheck();
		return elements[0];
	}

	/**
	 * 删除堆顶元素
	 */
	public E remove() {
		emptyCheck();
		
		int lastIndex = --size;
		E root = elements[0];
		elements[0] = elements[lastIndex];
		elements[lastIndex] = null;
		
		siftDown(0);
		return root;
	}

	@Override
	public E replace(E element) {
		elementNotNullCheck(element);
		
		E root = null;
		if (size == 0) {
			elements[0] = element;
			size++;
		} else {
			root = elements[0];
			elements[0] = element;
			siftDown(0);
		}
		return root;
	}
	
	/**
	 * 批量建堆
	 */
	private void heapify() {
		// 自上而下的上滤
//		for (int i = 1; i < size; i++) {
//			siftUp(i);
//		}
		
		// 自下而上的下滤
		for (int i = (size >> 1) - 1; i >= 0; i--) {
			siftDown(i);
		}
	}
	
	/**
	 * 让index位置的元素下滤
	 * @param index
	 */
	private void siftDown(int index) {
		E element = elements[index];
		int half = size >> 1; // 非叶子节点的数量
		// 第一个叶子节点的索引 == 非叶子节点的数量
		// index < 第一个叶子节点的索引
		// 必须保证index位置是非叶子节点
		while (index < half) { 
			// index的节点有2种情况
			// 1.只有左子节点
			// 2.同时有左右子节点
			
			// 默认为左子节点跟它进行比较
			int childIndex = (index << 1) + 1;
			E child = elements[childIndex];
			
			// 右子节点
			int rightIndex = childIndex + 1;
			
			// 选出左右子节点最大的那个
			if (rightIndex < size && compare(elements[rightIndex], child) > 0) {
				child = elements[childIndex = rightIndex];
			}
			
			if (compare(element, child) >= 0) break;

			// 将子节点存放到index位置
			elements[index] = child;
			// 重新设置index
			index = childIndex;
		}
		elements[index] = element;
	}
	
	/**
	 * 让index位置的元素上滤
	 */
	private void siftUp(int index) {
//		E e = elements[index];
//		while (index > 0) {
//			int pindex = (index - 1) >> 1;
//			E p = elements[pindex];
//			if (compare(e, p) <= 0) return;
//			
//			// 交换index、pindex位置的内容
//			E tmp = elements[index];
//			elements[index] = elements[pindex];
//			elements[pindex] = tmp;
//			
//			// 重新赋值index
//			index = pindex;
//		}
		E element = elements[index];
		while (index > 0) {
			// 父节点索引 = (子节点索引-1) / 2
			int parentIndex = (index - 1) >> 1;
			E parent = elements[parentIndex];
			if (compare(element, parent) <= 0) break;
			
			// 将父元素存储在index位置
			elements[index] = parent;
			
			// 重新赋值index
			index = parentIndex;
		}
		elements[index] = element;
	}
	
	private void ensureCapacity(int capacity) {
		int oldCapacity = elements.length;
		if (oldCapacity >= capacity) return;
		
		// 新容量为旧容量的1.5倍
		int newCapacity = oldCapacity + (oldCapacity >> 1);
		E[] newElements = (E[]) new Object[newCapacity];
		for (int i = 0; i < size; i++) {
			newElements[i] = elements[i];
		}
		elements = newElements;
	}
	
	private void emptyCheck() {
		if (size == 0) {
			throw new IndexOutOfBoundsException("Heap is empty");
		}
	}
	
	private void elementNotNullCheck(E element) {
		if (element == null) {
			throw new IllegalArgumentException("element must not be null");
		}
	}
}

构建一个最小堆

在写完最大堆以后,实现最小堆不需要修改源代码,只需要在创建堆时,传入与最大堆比较方式相反的比较器即可。

public static void main(String[] args) {
	Integer[] data = {88, 44, 53, 41, 16, 6, 70, 18, 85, 98, 81, 23, 36, 43, 37};
	BinaryHeap<Integer> heap = new BinaryHeap<>(data, new Comparator<Integer>() {
		public int compare(Integer o1, Integer o2) {
			return o2 - o1; // 与最大堆比较方式相反
		}
	});
}

TOP K问题

什么是 TopK 问题

  • 从 n 个整数中,找出最大的前 k 个数(k << n)
  • 例如:从100万个整数中找出最大的100个整数

TopK 问题的解法之一:可以用数据结构 “” 来解决

如果使用排序算法进行全排序,需要 O ( n l o g n ) O(nlogn) 的时间复杂度

如果使用二叉堆来解决,可以使用 O ( n l o g k ) O(nlogk) 的时间复杂度来解决

  • 新建一个小顶堆
  • 扫描 n 个整数
    • 先将遍历到的前 k 个数放入堆中
    • 从第 k+1 个数开始,如果大于堆顶元素,就使用 replace 操作
      (删除堆顶元素,将第k+1个数添加到堆中)
  • 扫描完毕后,堆中剩下的就是最大的前 k 个数
public static void main(String[] args) {
	// 新建一个小顶堆
	BinaryHeap<Integer> heap = new BinaryHeap<>(new Comparator<Integer>() {
		public int compare(Integer o1, Integer o2) {
			return o2 - o1;
		}
	});
	
	// 找出最大的前k个数
	int k = 3;
	Integer[] data = {51, 30, 39, 92, 74, 25, 16, 93, 
			91, 19, 54, 47, 73, 62, 76, 63, 35, 18, 
			90, 6, 65, 49, 3, 26, 61, 21, 48};
	for (int i = 0; i < data.length; i++) {
		if (heap.size() < k) { // 前k个数添加到小顶堆
			heap.add(data[i]); // logk
		} else if (data[i] > heap.get()) { // 如果是第k + 1个数,并且大于堆顶元素
			heap.replace(data[i]); // logk
		}
	}
	// O(nlogk)
}

如果是找出最小的前k个数呢?

  • 用大顶堆
  • 如果小于堆顶元素,就使用 replace 操作
发布了170 篇原创文章 · 获赞 47 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/weixin_43734095/article/details/104866058