I currently have a long list which is being sorted using a lambda function f. I then choose a random element from the first five elements. Something like:

from random import choice

f = lambda x: some_function_of(x, local_variable)
my_list.sort(key=f)
foo = choice(my_list[:5])  # random pick from the five smallest by key

This is a bottleneck in my program, according to the profiler. How can I speed things up? Is there a fast, built-in way to retrieve just the elements I want? In theory I shouldn't need to sort the whole list. Thanks.

A:

Use heapq.nlargest or heapq.nsmallest.

For example:

import heapq
from random import choice

elements = heapq.nsmallest(5, my_list, key=f)  # the 5 smallest, ordered by key
foo = choice(elements)

This takes O(N log K) time (where K is the number of elements returned and N is the list size), which is faster than the O(N log N) of a full sort when K is small relative to N.
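To see where the O(N log K) bound comes from, here is a minimal sketch of the bounded-heap idea: keep only the K best items seen so far in a max-heap, so each of the N items costs at most O(log K). `k_smallest` is a hypothetical helper written for illustration, not heapq's actual source, and it assumes the key returns numbers (it negates them to turn Python's min-heap into a max-heap).

```python
import heapq
import itertools


def k_smallest(k, iterable, key):
    """Return the k items with the smallest keys, in ascending key order.

    Hypothetical illustration of a bounded max-heap (not CPython's source).
    Assumes key() returns numbers, because keys are negated below.
    """
    if k <= 0:
        return []
    heap = []                    # max-heap via negation: (-key, -index, item)
    counter = itertools.count()  # unique index, so items are never compared
    for item in iterable:
        item_key = key(item)     # key is computed once per item
        entry = (-item_key, -next(counter), item)
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif -item_key > heap[0][0]:
            # Strictly smaller than the worst item kept so far: evict it.
            heapq.heapreplace(heap, entry)
    # Order survivors by key, breaking ties by original position (stable).
    survivors = sorted(heap, key=lambda e: (-e[0], -e[1]))
    return [item for _, _, item in survivors]
```

In practice `heapq.nsmallest` is the one to call; the sketch only shows why the cost grows with log K rather than log N.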

interjay
Hmm. So far this is in fact marginally slower. N is 8000 and K is 5.
Sort Me Out Please
It's possible that the bottleneck is the N calls to some_function_of and the sort is much faster in comparison, in which case there isn't much you can do except improve that function. Another possibility is that the data is nearly sorted already, in which case Python's sort will be very fast.
interjay
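One way to check whether the key function is the real cost is to count its calls. This sketch uses a hypothetical stand-in `slow_key` and a list of 8000 integers (matching the N mentioned above). `sorted` is documented to compute each key once; in CPython `heapq.nsmallest` also calls the key once per element, so switching between them cannot reduce that part of the work.

```python
import heapq


def counting(func):
    """Wrap func and record how many times it has been called."""
    def wrapper(x):
        wrapper.calls += 1
        return func(x)
    wrapper.calls = 0
    return wrapper


def slow_key(x):
    # Hypothetical stand-in for some_function_of(x, local_variable).
    return (x * 37) % 101


my_list = list(range(8000))

sort_key = counting(slow_key)
sorted(my_list, key=sort_key)
sort_calls = sort_key.calls

heap_key = counting(slow_key)
heapq.nsmallest(5, my_list, key=heap_key)
heap_calls = heap_key.calls

print(sort_calls, heap_calls)  # 8000 8000 in CPython
```

If the two counts match and the profiler still points at this line, the time is going into the key function itself.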
You're probably right. Will stick with heapq.nsmallest for now as it conveys intent. Thanks.
Sort Me Out Please
A:

It's actually possible in linear time on average (O(N)), using quickselect; the worst case is O(N²).

You need a partition function:

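def partition(seq, pred, start=0, end=None):
    """Rearrange seq[start:end] in place so that items satisfying pred come
    first; return the index of the first item that does not satisfy it."""
    if end is None:
        end = len(seq)
    while True:
        # Advance start past items already on the pred-true side.
        while True:
            if start == end: return start
            if not pred(seq[start]): break
            start += 1
        # Move end back past items already on the pred-false side.
        while True:
            if pred(seq[end-1]): break
            end -= 1
            if start == end: return start
        # seq[start] fails pred and seq[end-1] satisfies it: swap them.
        seq[start], seq[end-1] = seq[end-1], seq[start]
        start += 1
        end -= 1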

which an nth_element (quickselect) function can then use:

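def nth_element(seq_in, n, key=lambda x: x):
    """Reorder seq_in in place so seq_in[n] holds the element a full sort by
    key would put there; every element before it has a key <= its key and
    every element after it has a key >= its key."""
    start, end = 0, len(seq_in)
    seq = [(x, key(x)) for x in seq_in]  # decorate: compute each key once

    # True if the item's key is below the pivot's, which sits at seq[end-1].
    def partition_pred(x): return x[1] < seq[end-1][1]

    while start != end:
        pivot = (end + start) // 2
        # Move the middle element to the end to act as the pivot.
        seq[pivot], seq[end - 1] = seq[end - 1], seq[pivot]
        pivot = partition(seq, partition_pred, start, end)
        # Put the pivot between the smaller and the larger-or-equal items.
        seq[pivot], seq[end - 1] = seq[end - 1], seq[pivot]
        if pivot == n: break
        # Keep searching only the side that contains index n.
        if pivot < n: start = pivot + 1
        else: end = pivot

    seq_in[:] = (x for x, k in seq)  # undecorate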

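Given these, just replace your second (sort) line with the call below. Afterwards my_list[:5] holds the five smallest elements by key, in no particular order, so choice(my_list[:5]) works as before:

nth_element(my_list, 4, key=f)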
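For comparison, here is a shorter, non-in-place sketch of the same quickselect idea that returns the K smallest items directly. `k_smallest_quickselect` is a hypothetical name; it uses a random pivot instead of the middle element and copies lists rather than swapping in place, so it trades memory for simplicity. Each key is still computed only once.

```python
import random


def k_smallest_quickselect(items, k, key):
    """Return the k items with the smallest keys, in no particular order.

    Non-in-place quickselect with a random pivot: O(N) expected time,
    O(N^2) in the worst case. Keys are compared, items never are.
    """
    if k <= 0:
        return []
    decorated = [(key(x), x) for x in items]  # decorate once
    result = []
    while decorated and k > 0:
        pivot = random.choice(decorated)[0]
        lower = [p for p in decorated if p[0] < pivot]
        equal = [p for p in decorated if p[0] == pivot]
        if k <= len(lower):
            # All k answers lie strictly below the pivot.
            decorated = lower
        elif k <= len(lower) + len(equal):
            # Everything below the pivot plus enough pivot-equal items.
            result.extend(lower)
            result.extend(equal[:k - len(lower)])
            k = 0
        else:
            # Take everything <= pivot and keep looking above it.
            result.extend(lower)
            result.extend(equal)
            k -= len(lower) + len(equal)
            decorated = [p for p in decorated if p[0] > pivot]
    return [x for _, x in result]
```

Usage would be `foo = choice(k_smallest_quickselect(my_list, 5, f))`, which leaves my_list untouched.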
James Hopkin
The way I understand the key argument that has been added to the sort functions is that it is used to implement DSU (decorate-sort-undecorate) internally, so that the potentially expensive key function is called only once for each element of the list. I think your method will call the key function many times for the same list element.
Paul McGuire
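A different way to avoid repeated key calls, assuming the list elements are hashable, is to memoize the key with `functools.lru_cache` instead of restructuring the algorithm around DSU. This is a swapped-in technique, not what the answer above does; `some_function_of` and the sample data are hypothetical stand-ins.

```python
from functools import lru_cache

call_count = {"n": 0}


def some_function_of(x, local_variable):
    """Hypothetical stand-in for the question's expensive key function."""
    call_count["n"] += 1
    return abs(x - local_variable)


local_variable = 50
my_list = [3, 97, 50, 50, 12, 88, 3]


# Memoize the key: each distinct (hashable) element is computed only once,
# no matter how many times an algorithm asks for its key.
@lru_cache(maxsize=None)
def f(x):
    return some_function_of(x, local_variable)


for _ in range(3):  # stand-in for an algorithm that re-keys the same items
    smallest = min(my_list, key=f)

print(smallest, call_count["n"])  # 50 5
```

The cache ignores local_variable, so if that value changes between selections, build a fresh cached f each time rather than reusing a stale one.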
@Paul Good point - I've edited it to use DSU now.
James Hopkin