It's actually possible in linear time (O(N)) on average.
You need a partition algorithm:
def partition(seq, pred, start=0, end=-1):
    """Reorder ``seq[start:end]`` in place so items satisfying ``pred`` come first.

    Args:
        seq: Mutable, indexable sequence; only the half-open range
            ``[start, end)`` is touched.
        pred: One-argument callable returning a truthy value for items that
            belong in the leading group.
        start: Index of the first element of the range.
        end: Index one past the last element of the range. The default
            ``-1`` is a sentinel meaning ``len(seq)``; it is not treated as
            a negative index.

    Returns:
        The index of the first element of the trailing (``pred`` is false)
        group, which equals ``end`` when every element satisfies ``pred``
        and ``start`` when none does.

    The relative order of elements within each group is not preserved.
    """
    if end == -1:
        end = len(seq)
    while True:
        # Skip the prefix that is already on the correct (left) side.
        while start != end and pred(seq[start]):
            start += 1
        if start == end:
            return start
        # seq[start] fails pred. Scan from the right for an element that
        # satisfies it; from here on ``end`` is the candidate's own index.
        end -= 1
        while start != end and not pred(seq[end]):
            end -= 1
        if start == end:
            # The scans met on the failing element: it starts group two.
            return start
        seq[start], seq[end] = seq[end], seq[start]
        # seq[end] now holds a failing element, so ``end`` is once again a
        # valid exclusive upper bound for the unexamined middle.
        start += 1
which can be used by an nth_element algorithm:
def nth_element(seq_in, n, key=lambda x: x):
    """Rearrange ``seq_in`` in place so index ``n`` holds its sorted-order element.

    This is a quickselect built on ``partition``: each round picks the middle
    element of the active range as pivot, partitions the range around it, and
    continues only in the side that contains index ``n``.

    Args:
        seq_in: Mutable sequence supporting slice assignment; it is
            overwritten with the rearranged elements.
        n: Target index. When the loop stops on ``n``, no element before it
            has a larger key and no element after it has a smaller key.
            An ``n`` outside ``0..len(seq_in) - 1`` raises nothing; the loop
            simply runs until the active range is empty and the list is
            still permuted.
        key: One-argument callable producing the comparison key. It is
            called exactly once per element.

    Returns:
        None. The result is delivered by mutating ``seq_in``.
    """
    start, end = 0, len(seq_in)
    # Decorate each item with its key so ``key`` is evaluated only once.
    seq = [(x, key(x)) for x in seq_in]
    while start != end:
        last = end - 1
        middle = (start + end) // 2
        # Park the pivot in the last slot of the active range.
        seq[middle], seq[last] = seq[last], seq[middle]
        # Capture the pivot key once instead of re-reading seq[end - 1] on
        # every comparison; the default argument binds it for this round.
        pivot_key = seq[last][1]
        split = partition(
            seq,
            lambda item, pivot_key=pivot_key: item[1] < pivot_key,
            start,
            end,
        )
        # Move the pivot to the boundary between "smaller" and "not smaller".
        seq[split], seq[last] = seq[last], seq[split]
        if split == n:
            break
        if split < n:
            start = split + 1
        else:
            end = split
    # Undecorate and write the new order back into the caller's sequence.
    seq_in[:] = [x for x, _ in seq]
Given these, just replace your second (sort) line with:
# Rearranges my_list in place so that index 4 holds the element a full sort
# by key f would put there; nothing is returned.
nth_element(my_list, 4, key=f)