Sometimes we wish to find the k-th largest (or, equivalently, the k-th smallest) element in a list.
Brute Force
A brute-force way to do this is to sort the list and then pick the element at position k. Sorting dominates the cost, so with a comparison sort this takes Θ(N lg N) time.
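A minimal sketch of the brute-force approach. It assumes `k` is 1-based (k = 1 means the largest) and the list is non-empty. The helper name `kth_largest_by_sorting` is hypothetical.

```python
def kth_largest_by_sorting(items, k):
    """Return the k-th largest element (k = 1 is the largest) by sorting a copy."""
    if not 1 <= k <= len(items):
        raise ValueError("k must be between 1 and len(items)")
    # sorted() builds a new list, so the caller's list is left untouched.
    return sorted(items, reverse=True)[k - 1]


print(kth_largest_by_sorting([3, 1, 4, 1, 5, 9, 2, 6], 3))  # 5
```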
Quick Select
Another method is quick select, which is based on quick sort:
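- Partition the list around a pivot, which puts the pivot in its final sorted position. If that position is k, we’re done.
- If it’s not k, recurse only on the sublist that contains position k. Keep doing this until the pivot lands at position k (or the sublist has a single element).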
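The two steps above can be sketched in place and without recursion, because only one sublist is ever revisited. This is a sketch under stated assumptions, not the implementation given later: it swaps in the Lomuto partition scheme for brevity, picks a uniformly random pivot, and treats `k` as a 0-based rank (k = 0 is the smallest). The name `quickselect_inplace` is hypothetical.

```python
import random


def quickselect_inplace(items, k):
    """Return the element that would sit at index k (0-based) if items were sorted.

    Rearranges items in place. Uses Lomuto partitioning with a random pivot.
    """
    if not 0 <= k < len(items):
        raise IndexError("k out of range")
    lo, hi = 0, len(items) - 1  # invariant: lo <= k <= hi
    while True:
        if lo == hi:
            return items[lo]
        # Move a random pivot to the end of the current sublist items[lo..hi].
        p = random.randint(lo, hi)
        items[p], items[hi] = items[hi], items[p]
        pivot = items[hi]
        # Lomuto partition: afterwards items[lo..store-1] are all < pivot.
        store = lo
        for i in range(lo, hi):
            if items[i] < pivot:
                items[store], items[i] = items[i], items[store]
                store += 1
        # Put the pivot into its final sorted position.
        items[store], items[hi] = items[hi], items[store]
        if store == k:
            return items[store]
        elif store < k:
            lo = store + 1  # position k is to the right of the pivot
        else:
            hi = store - 1  # position k is to the left of the pivot


print(quickselect_inplace([7, 2, 9, 4, 1], 2))  # 4
```

To get the k-th largest with a 1-based k instead, ask for rank `len(items) - k`.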
On average, this algorithm is Θ(N), although the worst case is still quadratic.
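To see where the quadratic worst case comes from, here is a small sketch (the helper `count_comparisons_first_pivot` is hypothetical) that always uses the first element as the pivot and never shuffles. On an already sorted list, asking for the largest element shrinks the sublist by only one element per partition, so the comparison count is (N-1) + (N-2) + ... + 1 = N(N-1)/2.

```python
def count_comparisons_first_pivot(items, k):
    """Quickselect with the first element as pivot; returns (answer, comparisons).

    k is a 0-based rank (k = 0 is the smallest). Works on a copy of items.
    """
    items = list(items)
    lo, hi = 0, len(items) - 1
    comparisons = 0
    while lo < hi:
        pivot = items[lo]
        store = lo  # items[lo+1..store] will hold the elements < pivot
        for i in range(lo + 1, hi + 1):
            comparisons += 1
            if items[i] < pivot:
                store += 1
                items[store], items[i] = items[i], items[store]
        # Move the pivot between the two groups.
        items[lo], items[store] = items[store], items[lo]
        if store == k:
            return items[store], comparisons
        elif store < k:
            lo = store + 1
        else:
            hi = store - 1
    return items[lo], comparisons


n = 100
answer, comps = count_comparisons_first_pivot(range(n), n - 1)
print(answer, comps)  # 99 4950
```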
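Choosing the pivot with the median-of-medians method guarantees linear run time even in the worst case. Cheaper heuristics such as Tukey’s ninther (the median of three medians of three) give a better estimate of the median but still no worst-case guarantee.
If one finds the true median every time (Θ(N) per step), the sublist halves at each step, so the algorithm is guaranteed to be Θ(N) overall. In practice, though, the pivot calculation overhead of these options is not worth it.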
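For reference, Tukey’s ninther estimates the median by taking the median of three medians-of-three drawn from nine positions. The sketch below (the helpers `median_of_three` and `ninther_index` are hypothetical) returns the index of that estimate within `items[lo..hi]`. Using evenly spaced sample positions and falling back to median-of-three for short ranges are assumptions; real implementations vary.

```python
def median_of_three(items, i, j, k):
    """Return whichever of the indices i, j, k holds the median of the three values."""
    a, b, c = items[i], items[j], items[k]
    if a < b:
        if b < c:
            return j  # a < b < c
        return k if a < c else i  # b is the largest (or tied with c)
    if a < c:
        return i  # b <= a < c
    return k if b < c else j  # a is the largest (or tied)


def ninther_index(items, lo, hi):
    """Index of Tukey's ninther within items[lo..hi] (inclusive bounds)."""
    n = hi - lo + 1
    if n < 9:
        # Too few elements for nine samples; fall back to median-of-three.
        return median_of_three(items, lo, lo + n // 2, hi)
    step = (n - 1) // 8  # nine positions lo, lo+step, ..., lo+8*step, all <= hi
    s = [lo + t * step for t in range(9)]
    m1 = median_of_three(items, s[0], s[1], s[2])
    m2 = median_of_three(items, s[3], s[4], s[5])
    m3 = median_of_three(items, s[6], s[7], s[8])
    return median_of_three(items, m1, m2, m3)


print(ninther_index(list(range(100)), 0, 99))  # 48
```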
Derivation of Average Complexity
At each step, the work done outside the recursive call is $f(m) \le cm$ for a sublist of size m and some constant c. This is effectively the time to partition.
Let’s say the algorithm is in “phase j” if the sublist it is operating on has size between $(3/4)^{j+1} n$ and $(3/4)^j n$, where n is the size of the original list. Note that at any point, we’ll be in some phase.
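As a concrete check of the definition, the sketch below (the helper `phase` is hypothetical) computes which phase a sublist of size m is in when the original list has size n. The text does not say which endpoint is inclusive, so this assumes $(3/4)^{j+1} n < m \le (3/4)^j n$. It uses exact fractions to avoid rounding at the boundaries.

```python
from fractions import Fraction


def phase(n, m):
    """Return j such that (3/4)**(j+1) * n < m <= (3/4)**j * n, for 1 <= m <= n."""
    if not 1 <= m <= n:
        raise ValueError("need 1 <= m <= n")
    j = 0
    upper = Fraction(n)  # upper bound (3/4)**j * n of the current phase
    # Step down one phase while m still fits under the next, smaller bound.
    while m <= upper * Fraction(3, 4):
        upper *= Fraction(3, 4)
        j += 1
    return j


print(phase(1000, 1000), phase(1000, 750), phase(1000, 500))  # 0 1 2
```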
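Thus, since a partition in phase j costs at most $c(3/4)^j n$, the run time is bounded by:
$$T(n) \le \sum_{j \ge 0} X_j \, c \left(\tfrac{3}{4}\right)^j n$$
where $X_j$ is the number of times we are in “phase” j (that is, the number of partitions performed in phase j).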
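Why did we pick 3/4? Because a split at least as balanced as 1/4 to 3/4 guarantees that the sublist we continue on has at most 3/4 of the current size, so we leave the current phase. A random pivot produces such a split whenever it lands in the middle half of the sorted order, which happens with probability at least 1/2 (about 1/2 for large lists).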
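The probability can be checked by counting. With a pivot of 0-based rank r in a list of n distinct elements, the two sides have r and n - 1 - r elements. The sketch below (the helper `good_split_fraction` is hypothetical) counts the ranks for which both sides have at most 3n/4 elements. The fraction is always a little above 1/2 and tends to 1/2 as n grows.

```python
from fractions import Fraction


def good_split_fraction(n):
    """Fraction of pivot ranks r in 0..n-1 giving sides r and n-1-r both <= 3n/4."""
    limit = Fraction(3 * n, 4)
    good = sum(1 for r in range(n) if r <= limit and n - 1 - r <= limit)
    return Fraction(good, n)


for n in (10, 100, 10_000):
    print(n, float(good_split_fraction(n)))
# prints 10 0.6, then 100 0.52, then 10000 0.5002
```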
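So what is the expected value of $X_j$? It is at most the expected number of times you have to flip a fair coin to get heads, where heads stands for a split good enough to leave the phase. Calling that number X, it satisfies:
$$\mathbb{E}[X] = 1 + \tfrac{1}{2}\,\mathbb{E}[X]$$
where 1 represents the first coin flip, and $\tfrac{1}{2}$ represents the probability that you got tails, in which case the problem recurses. Solving gives $\mathbb{E}[X] = 2$, so $\mathbb{E}[X_j] \le 2$.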
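The bound on $X_j$ can also be checked empirically. The sketch below (the helpers `phase_of`, `partition_sizes` and `average_visits_per_phase` are hypothetical) runs a random-pivot quickselect on the distinct values 0..n-1 and records the size of every sublist it partitions. It then converts each size to its phase and averages the per-phase counts over many trials. It assumes a target rank chosen uniformly at random and the phase convention $(3/4)^{j+1} n < m \le (3/4)^j n$. The averages should come out at or below about 2.

```python
import random
from collections import Counter


def phase_of(n, m):
    """Phase j with (3/4)**(j+1) * n < m <= (3/4)**j * n, using exact integer math."""
    j = 0
    # m <= (3/4)**(j+1) * n  is equivalent to  m * 4**(j+1) <= n * 3**(j+1).
    while m * 4 ** (j + 1) <= n * 3 ** (j + 1):
        j += 1
    return j


def partition_sizes(n, k, rng):
    """Sizes of the sublists a random-pivot quickselect partitions to find rank k of 0..n-1."""
    items = list(range(n))  # distinct values, so rank k is initially the value k
    sizes = []
    while len(items) > 1:
        sizes.append(len(items))
        pivot = items[rng.randrange(len(items))]
        smaller = [x for x in items if x < pivot]
        if k < len(smaller):
            items = smaller  # the answer is left of the pivot
        elif k == len(smaller):
            break  # the pivot is the answer
        else:
            k -= len(smaller) + 1  # the answer is right of the pivot
            items = [x for x in items if x > pivot]
    return sizes


def average_visits_per_phase(n=1000, trials=1000, seed=1):
    """Average number of partitions performed in each phase j (an estimate of E[X_j])."""
    rng = random.Random(seed)
    totals = Counter()
    for _ in range(trials):
        k = rng.randrange(n)
        totals.update(phase_of(n, m) for m in partition_sizes(n, k, rng))
    return {j: totals[j] / trials for j in sorted(totals)}


for j, avg in average_visits_per_phase().items():
    print(j, round(avg, 2))
```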
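Using linearity of expectation:
$$\mathbb{E}[T(n)] \le \sum_{j \ge 0} \mathbb{E}[X_j] \, c \left(\tfrac{3}{4}\right)^j n \le 2cn \sum_{j \ge 0} \left(\tfrac{3}{4}\right)^j$$
Evaluating this geometric series, which sums to 4, gives 8cn, so the expected run time is Θ(N).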
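The 8cn bound is loose in practice. The sketch below (the helper `work_ratio` is hypothetical) charges c = 1 per element of each sublist partitioned and reports the total divided by N, averaged over random target ranks. Under this cost model the ratio stays well under 8 and does not grow with N.

```python
import random


def work_ratio(n, trials=100, seed=0):
    """Average total partitioning work divided by n, charging 1 unit per element."""
    rng = random.Random(seed)
    total = 0
    for _ in range(trials):
        items = list(range(n))  # distinct values, so rank k is initially value k
        k = rng.randrange(n)  # target rank chosen uniformly at random
        while len(items) > 1:
            total += len(items)  # cost c * m with c = 1 for a sublist of size m
            pivot = items[rng.randrange(len(items))]
            smaller = [x for x in items if x < pivot]
            if k < len(smaller):
                items = smaller
            elif k == len(smaller):
                break  # the pivot is the answer
            else:
                k -= len(smaller) + 1
                items = [x for x in items if x > pivot]
    return total / (trials * n)


for n in (100, 1000, 10_000):
    print(n, round(work_ratio(n), 2))
```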
Implementations
Python
from random import shuffle


def simple_compare(left, right):
    """
    Simple comparison using < and >.
    Keyword Arguments:
    left -- first value
    right -- second value
    Returns -1 if left < right, 1 if left > right, and 0 if they are equal.
    """
    if left < right:
        return -1
    elif left > right:
        return 1
    else:
        return 0


def partition(lst, comparator=simple_compare):
    """
    Partition a list around its first element (a 2-way, Hoare-style scheme).
    Keyword Arguments:
    lst -- list to partition in place; lst[0] is used as the pivot
    comparator -- three-way comparison function, as in quickselect
    Returns (position, pivot, left, right): the pivot's final index, the
    pivot, and copies of the elements before and after it.
    """
    pivot = lst[0]
    i = 1
    j = len(lst) - 1
    while i <= j:
        if comparator(lst[i], pivot) < 0:
            i += 1
        elif comparator(lst[j], pivot) > 0:
            j -= 1
        else:
            lst[i], lst[j] = lst[j], lst[i]
            # Advance both indices; otherwise elements equal to the pivot
            # would be swapped back and forth forever.
            i += 1
            j -= 1
    # lst[1..j] now holds elements <= pivot; move the pivot into slot j.
    lst[j], lst[0] = lst[0], lst[j]
    return j, pivot, lst[:j], lst[j + 1:]


def quickselect(lst, k, comparator=simple_compare, is_shuffled=False):
    """
    Quick select. It selects the kth smallest element. 0 means smallest.
    Note that this does /not/ select in place! It creates temporary
    memory. Unless is_shuffled is True, it also shuffles lst in place.
    Keyword Arguments:
    lst -- List to select from
    k -- Which element to select (0 means first)
    comparator -- Comparator function. It should take in two arguments, and
                  return -1 if the left is smaller, 1 if the right is smaller,
                  and 0 if they are equal. If none is provided, then it uses
                  simple_compare (the natural < and > ordering).
    is_shuffled -- True if lst is already in random order, so the shuffle
                   can be skipped.
    """
    if not 0 <= k < len(lst):
        raise IndexError("k is out of range")
    if len(lst) < 2:
        return lst[0]
    if not is_shuffled:
        shuffle(lst)
    position, pivot, left, right = partition(lst, comparator)
    if position == k:
        return pivot
    if position < k:
        # The answer is among the elements after the pivot.
        return quickselect(right, k - position - 1, comparator, True)
    # position > k: the answer is among the elements before the pivot.
    return quickselect(left, k, comparator, True)
C++
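#include <algorithm>
#include <cstddef>
#include <iterator>
#include <random>
#include <vector>

// Pass in the start and stop iterators to the vector.
// k is 1-based (k == 1 selects the smallest element) and must satisfy
// 1 <= k <= std::distance(begin, end).
int quickselect(unsigned int k, std::vector<int>::iterator const & begin,
                std::vector<int>::iterator const & end, bool is_shuffled=false);

// Partitions [begin, end) around *begin and returns an iterator one past
// the pivot's final position.
std::vector<int>::iterator partition(std::vector<int>::iterator const & begin,
                                     std::vector<int>::iterator const & end)
{
    int pivot = *begin;
    std::vector<int>::iterator i = begin + 1;
    std::vector<int>::iterator j = end - 1;
    while (i <= j)
    {
        if (*i < pivot)
        {
            ++i;
        }
        else if (*j > pivot)
        {
            --j;
        }
        else
        {
            std::iter_swap(i, j);
            // Advance both iterators; otherwise elements equal to the
            // pivot would be swapped back and forth forever.
            ++i;
            --j;
        }
    }
    // [begin + 1, j] now holds elements <= pivot; move the pivot into slot j.
    std::iter_swap(begin, j);
    return j + 1;
}

int quickselect(unsigned int k, std::vector<int>::iterator const & begin,
                std::vector<int>::iterator const & end, bool is_shuffled)
{
    if (std::distance(begin, end) < 2)
    {
        return *begin;
    }
    if (!is_shuffled)
    {
        // std::random_shuffle was removed in C++17; std::shuffle needs an
        // explicit random engine.
        static std::mt19937 engine(std::random_device{}());
        std::shuffle(begin, end, engine);
    }
    std::vector<int>::iterator position = partition(begin, end);
    // rank is the 1-based rank of the pivot within [begin, end).
    std::ptrdiff_t const rank = std::distance(begin, position);
    std::ptrdiff_t const target = static_cast<std::ptrdiff_t>(k);
    if (rank == target)
    {
        return *(position - 1);
    }
    if (rank < target)
    {
        // The answer lies after the pivot, with rank k - rank there.
        return quickselect(static_cast<unsigned int>(target - rank),
                           position, end, true);
    }
    // rank > k: the answer lies before the pivot.
    return quickselect(k, begin, position - 1, true);
}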