Topk问题

yuanheci 2024年08月16日 35次浏览

100亿个数中找topK个数?

方法一:堆
n个结点建堆的时间复杂度是O(n)。

如果用堆来解决,直观来说,应该建立大根堆。时间复杂度为O(k + nlogk)。
但是正确的做法是建立小根堆,思路如下:
首先用k个数建立小根堆,然后用剩下的数一个个与堆顶比较。比堆顶大的就和堆顶交换,再调整。最后剩下来的就是top。
时间复杂度:O(k + n
logk)

如果是求topK个最大值,就用小根堆;topK个最小值,就用大根堆。

#include <iostream>
#include <stdlib.h>
#include <time.h>
using namespace std;
const int N = 10;

void down(int u, int* arr, int cnt) {
    int t = u;
    if (2 * u  + 1<= cnt && arr[2 * u + 1] < arr[t]) t = 2 * u + 1;
    if (2 * u  + 2<= cnt && arr[2 * u  + 2] < arr[t]) t = 2 * u + 2;
    if (u != t) {
        swap(arr[t], arr[u]);
        down(t, arr, cnt);
    }
}
void topK(int arr[], int &k, int &n)
{   
    if (k > n) return;
    //建立k个数的小根堆
    for (int i = k / 2; i > 0; i--) { // 下标从1开始,最后一个分支节点为 k/2
        down(i, arr, k);
    }
    //测试堆是否正确
    //for (int i = 1; i <= 10; i++)
    //{
    //    cout << arr[i] << ' ';
    //}
    //cout << endl;
    //然后对底k+1个数开始比较 大于堆顶则交换,再调整
    for (int j = k + 1; j <= n; j++) {
        if (arr[1] < arr[j]) {
            swap(arr[1], arr[j]);;
            down(1, arr, k);
        }
    }

    //最后,数组前k个数就是topK,当然也可以用快排
    for (int i = 1; i <= k; i++) {
        cout << arr[i] << ' ';
    }
    cout << "\n";
}

int main()
{
    int arr[N + 1];
    srand((unsigned int)time(NULL));
    for (int i = 1; i <= N; i++) {

        int a = rand();
        int b = rand();
        arr[i] = a - b;

    }
    for (int i = 1; i <= 10; i++)
    {
        cout << arr[i] << ' ';
    }
    cout << "\n";
    int n = N;
    int k = 3;
    topK(arr, k, n);

}

方法二:快排

虽然是找第k大/小的数,但是前k个数就是topK,只是还没排好序~

/*
分成两个部分:
1. <= 基准值,数量为 cnt
2. >= 基准值
cnt >= k,在左半边找第 k 小数
cnt < k,在右半边找第 k - cnt 小数
*/
#include <iostream>
using namespace std;
const int N = 1e5 + 10;

int quick_sort(int q[], int l, int r, int k) {
    if (l >= r) return q[l];
    int i = l - 1, j = r + 1, x = q[l + r >> 1];
    while (i < j) {
        do i++; while (q[i] < x);
        do j--; while (q[j] > x);
        // 如果是求topK最大值,这里就改成:
        // do i++; while (q[i] > x);
        // do j--; while (q[j] < x);
        if (i < j) swap(q[i], q[j]);
    }
    int cnt = j - l + 1;
    if (cnt >= k)
        return quick_sort(q, l, j, k);
    else
        return quick_sort(q, j + 1, r, k - cnt);
}

int n, k, q[N];
int main() {
    scanf("%d%d", &n, &k);
    for (int i = 0; i < n; i++) scanf("%d", &q[i]);
    printf("%d\n", quick_sort(q, 0, n - 1, k));
    for(int i = 0; i < k; i++) printf("%d ", q[i]);
}

输入:
10 3
4 4 1 5 3 9 22 44 11 6 3

输出:
4
3 1 4   (注意以上代码逻辑不能处理重复的元素)