YNZH's Blog

大名鼎鼎的BFPRT算法、基于堆的TopK算法

一、直接排序

​ 复杂度O( N * log(N)),缺点所有都排序了,然而只需求第K大的数

二、排序K个

​ 既然求topK,我们可以遍历K次每次选择一个最大的值,复杂度O(N*k),缺点:topK的K个数也被排序了

​ 实际上只需要求第K大呀!

三、堆排

​ 维护一个大小为K的小根堆,遍历一次数组,只要大于小根堆堆顶则加入堆中。topK的K个数也不要排序,时间复杂度 O(N * log(K) ).

四、随即选择法(利用快排的思想)

​ 算法步骤大概如下:

  1. 第一次快排分区之后,找到分区点p,则p左边的数小于p位置的数,p右边的数大于p位置,若p正好是第k大的位置,则直接返回;若p小于第K大的位置,则说明第K大的数在p的右半部分,这是问题转化为求p右半部分数组中第K - (左半部分个数)大的问题;若p的位置大于第K大的位置,则说明第K大在p左半部分数组中,问题转化为求p左半部分数组中第K - (p右半部分数组大小)的topK问题。
  2. 递归上述过程即可。
  3. 整体时间复杂度近似是O( N ), 对于海量数据还有一种BFPRT的优化算法。整体的时间复杂度是线性的。

上述第4个随机选择法:(快排版本TOPK(第k个大的数),代码)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
public static int topK(int[] arr, int l, int r, int k) {
if (l >= r) return l;
int p = partition(arr, l, r);// [l,p)区间小于pivot, (p,r]区间大于等于pivot
if (r - k + 1 == p) return p;
else if (r - k + 1 < p) return topK(arr, l, p - 1, k - r + p - 1);
else return topK(arr, p + 1, r, k);
}

public static int partition(int[] arr, int l, int r) {
int pivot = arr[l]; //pivot可以随机选择,这是直接默认
int i = l;
while (l < r) {
while (l < r && arr[r] >= pivot) r--;
while (l < r && arr[l] <= pivot) l++;
int tmp = arr[l];
arr[l] = arr[r];
arr[r] = tmp;
}
arr[i] = arr[l];
arr[l] = pivot;
return l;
}

五、BFPRT:

该算法由Blum、Floyd、Pratt、Rivest、Tarjan提出,时间复杂度最坏为O(N)

BFPRT算法步骤如下:

  1. 选取基准元素;

    1. 将n个元素每5个一组,分成n/5(上界)组,最后的一个组的元素个数为n%5,有效的组数为n/5。
    2. 取出每一组的中位数,最后一个组的不用计算中位数,任意排序方法,这里的数据比较少只有5个,可以用简单的冒泡排序或是插入排序。
    3. 对于第1.2中找到的所有中位数,调用BFPRT算法求出它们的中位数,作为基准元素,设为x,偶数个中位数的情况下设定为选取中间小的一个。
  2. 以1.3中选取的基准元素作为分割点,将小于基准元素的放在左边,个数为k个,大于或等于基准元素的放在右边,个数为n-k。

  3. 判断基准元素位置i与k的大小

    1. 如果i==k,返回x;
    2. 如果i<k,在小于x的元素中递归查找第i小的元素;
    3. 如果i>k,在大于等于x的元素中递归查找第i-k小的元素。

因为每个基准元素都会选择中位数,所以快排每次会对半分区间,所以时间复杂度是最好的。

代码实现:来源网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
public class BFPRT {
public static void main(String[] args) {
int[] arr = {6, 9, 1, 3, 1, 2, 2, 5, 6, 1, 3, 5, 9, 7, 2, 5, 6, 1, 9};
// sorted : { 1, 1, 1, 1, 2, 2, 2, 3, 3, 5, 5, 5, 6, 6, 6, 7, 9, 9, 9 }
printArray(getMinKNumsByHeap(arr, 10));//通过堆排方式得到Top-K元素
printArray(getMinKNumsByQuick(arr, 10));//通过快排方式得到Top-K元素
printArray(getMinKNumsByBFPRT(arr, 10));//通过BFPRT算法得到Top-K元素
}

/**
* 堆排解法,时间复杂度为O(N*logK)
*
* @param arr
* @param i
* @return
*/
private static int[] getMinKNumsByHeap(int[] arr, int k) {
if (k < 1 || k > arr.length) {
return arr;
}
int[] kHeap = new int[k];
for (int i = 0; i != k; i++) {
heapInsert(kHeap, arr[i], i);
}
for (int i = k; i != arr.length; i++) {
if (arr[i] < kHeap[0]) {
kHeap[0] = arr[i];
heapify(kHeap, 0, k);
}
}
return kHeap;
}

private static void heapInsert(int[] arr, int value, int index) {
arr[index] = value;
while (index != 0) {
int parent = (index - 1) / 2;
if (arr[parent] < arr[index]) {
swap(arr, parent, index);
index = parent;
} else {
break;
}
}
}

private static void heapify(int[] arr, int index, int heapSize) {
int left = index * 2 + 1;
int right = index * 2 + 2;
int largest = index;
while (left < heapSize) {
if (arr[left] > arr[index]) {
largest = left;
}
if (right < heapSize && arr[right] > arr[largest]) {
largest = right;
}
if (largest != index) {
swap(arr, largest, index);
} else {
break;
}
index = largest;
left = index * 2 + 1;
right = index * 2 + 2;
}
}

/**
* 通过快排的方式,时间复杂度为O(N)
*
* @param arr
* @param k
* @return
*/
private static int[] getMinKNumsByQuick(int[] arr, int k) {
if (arr != null && arr.length > 0) {
int low = 0;
int high = arr.length - 1;
int index = partition(arr, low, high);
//不断调整分治思想,直到position=k-1
while (index != k - 1) {
//大了,往前调整
if (index > k - 1) {
high = index - 1;
index = partition(arr, low, high);
}
//小了,往后调整
if (index < k - 1) {
low = index + 1;
index = partition(arr, low, high);
}
}
}
int[] res = new int[k];
for (int i = 0; i < res.length; i++) {
res[i] = arr[i];
}
return res;
}

private static int partition(int[] arr, int low, int high) {
if (arr != null && low < high) {
int flag = arr[low];
while (low < high) {
while (low < high && arr[high] >= flag) {
high--;
}
arr[low] = arr[high];
while (low < high && arr[low] <= flag) {
low++;
}
arr[high] = arr[low];
}
arr[low] = flag;
return low;
}
return 0;
}

/**
* 通过BRPRT算法获得Top-K问题的解,时间复杂度为O(N)
*
* @param arr
* @param k
*/
private static int[] getMinKNumsByBFPRT(int[] arr, int k) {
if (k < 1 || k > arr.length) {
return arr;
}
int minKth = getMinKthByBFPRT(arr, k);
int[] res = new int[k];
int index = 0;
for (int i = 0; i < arr.length; i++) {
if (arr[i] < minKth) {
res[index++] = arr[i];
}
}
for (; index < res.length; index++) {
res[index] = minKth;
}
return res;
}

private static int getMinKthByBFPRT(int[] arr, int k) {
int[] copyArr = copyArray(arr);
return select(copyArr, 0, copyArr.length - 1, k - 1);
}

private static int[] copyArray(int[] arr) {
int[] res = new int[arr.length];
for (int i = 0; i < res.length; i++) {
res[i] = arr[i];
}
return res;
}

private static int select(int[] arr, int begin, int end, int i) {
if (begin == end) {
return arr[begin];
}
int pivot = medianOfMedians(arr, begin, end);
int[] pivotRange = partition(arr, begin, end, pivot);
if (i >= pivotRange[0] && i <= pivotRange[1]) {
return arr[i];
} else if (i < pivotRange[0]) {
return select(arr, begin, pivotRange[0] - 1, i);
} else {
return select(arr, pivotRange[1] + 1, end, i);
}
}

private static int medianOfMedians(int[] arr, int begin, int end) {
int num = end - begin + 1;
int offset = num % 5 == 0 ? 0 : 1;
int[] mArr = new int[num / 5 + offset];
for (int i = 0; i < mArr.length; i++) {
int beginI = begin + i * 5;
int endI = beginI + 4;
mArr[i] = getMedian(arr, beginI, Math.min(end, endI));
}
return select(mArr, 0, mArr.length - 1, mArr.length / 2);
}

private static int[] partition(int[] arr, int begin, int end, int pivotValue) {
int small = begin - 1;
int cur = begin;
int big = end + 1;
while (cur != big) {
if (arr[cur] < pivotValue) {
swap(arr, ++small, cur++);
} else if (arr[cur] > pivotValue) {
swap(arr, cur, --big);
} else {
cur++;
}
}
int[] range = new int[2];
range[0] = small + 1;
range[1] = big - 1;
return range;
}

private static int getMedian(int[] arr, int begin, int end) {
insertionSort(arr, begin, end);
int sum = end + begin;
int mid = (sum / 2) + (sum % 2);
return arr[mid];
}

private static void insertionSort(int[] arr, int begin, int end) {
for (int i = begin + 1; i != end + 1; i++) {
for (int j = i; j != begin; j--) {
if (arr[j - 1] > arr[j]) {
swap(arr, j - 1, j);
} else {
break;
}
}
}
}

/**
* 公共方法,交换数据和打印数据
* @param arr
* @param i
* @param j
*/
private static void swap(int[] arr, int i, int j) {
int temp = arr[i];
arr[i] = arr[j];
arr[j] = temp;
}

private static void printArray(int[] arr) {
for (int i = 0; i < arr.length; i++) {
System.out.print(arr[i] + " ");
}
System.out.println();
}
}

 评论


博客内容遵循 署名-非商业性使用-相同方式共享 4.0 国际 (CC BY-NC-SA 4.0) 协议

本站使用 Material X 作为主题 , 总访问量为 次 。
载入天数...载入时分秒...