Hey all,

While optimizing performance of an app of mine, I ran across a huge performance bottleneck in a few lines of (Python) code.

I have N tokens. Each token has a value assigned to it. Some of the tokens contradict each other (e.g. tokens 8 and 12 cannot "live together"). My job is to find the k best token-groups. The value of a group of tokens is simply the sum of the values of the tokens in it.

Naïve algorithm (which I have implemented...):

  1. find all 2^N token-groups (subsets) of the tokens
  2. Eliminate the token-groups that have contradictions in them
  3. Calculate the value of all remaining token-groups
  4. Sort token-groups by value
  5. Choose top K token-groups

Real world numbers: I need the top 10 token-groups from a set of 20 tokens (for which I generated all 2^20, roughly 1,000,000, subsets (!)), narrowed down to 3500 non-contradicting token-groups. This took 5 seconds on my laptop...

I'm sure I can optimize steps 1+2 somehow by generating just the non-contradicting token-groups.

I'm also pretty sure I can somehow magically find the best token-group in a single search and find a way to traverse the token-groups by diminishing value, thus finding just the 10 best I am looking for...
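One way to get the "traverse by diminishing value" behaviour described above is a best-first search over include/exclude decisions, keyed on an optimistic bound. This is a standard branch-and-bound technique swapped in for illustration, not code from this thread. It is a sketch that assumes all values are non-negative and that contradictions are listed in both directions. `k_best_groups` and its arguments are hypothetical names.

```python
import heapq
import itertools

def k_best_groups(values, contradictions, k):
    """Return up to k (group, value) pairs, best first (best-first search).

    Assumes every value is >= 0 and contradictions are symmetric
    (if b is in contradictions[a], then a is in contradictions[b]).
    """
    tokens = sorted(values, key=values.get, reverse=True)
    n = len(tokens)
    tie = itertools.count()  # tie-breaker so the heap never compares token tuples

    def bound(value, chosen, i):
        # Optimistic estimate: current value plus every remaining token that no
        # chosen token forbids (clashes among the remaining tokens are ignored).
        blocked = set()
        for t in chosen:
            blocked |= contradictions.get(t, set())
        return value + sum(values[t] for t in tokens[i:] if t not in blocked)

    heap = [(-bound(0, (), 0), next(tie), 0, (), 0)]
    results = []
    while heap and len(results) < k:
        _, _, value, chosen, i = heapq.heappop(heap)
        if i == n:
            # Every token decided: nothing still queued can beat this group.
            results.append((chosen, value))
            continue
        t = tokens[i]
        # Branch 1: leave token t out.
        heapq.heappush(heap, (-bound(value, chosen, i + 1), next(tie), value, chosen, i + 1))
        # Branch 2: take token t, if no chosen token contradicts it.
        if all(t not in contradictions.get(c, set()) for c in chosen):
            new_value, new_chosen = value + values[t], chosen + (t,)
            heapq.heappush(heap, (-bound(new_value, new_chosen, i + 1), next(tie),
                                  new_value, new_chosen, i + 1))
    return results
```

The bound of a partial group is always at least the value of any group it can grow into. So a finished group popped from the heap cannot be beaten by anything still queued. Groups therefore come out in non-increasing order, and the search can stop after k of them.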

My actual code:

import itertools
# steps 1+2: all 2^N subsets of token_list, minus the contradictory ones
all_possibilities = sum((list(itertools.combinations(token_list, i)) for i in range(len(token_list) + 1)), [])
all_possibilities = [list(option) for option in all_possibilities if self._no_contradiction(option)]
# step 3: pair each remaining group with its value
all_possibilities = [(option, self._probability(option)) for option in all_possibilities]
all_possibilities.sort(key=lambda result: -result[1])  # sort by descending probability
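For reference, here is a self-contained version of the same brute-force steps that runs without the surrounding class. It is a sketch. `naive_top_k` and `values` are hypothetical stand-ins for the class's `self._probability` and `self._no_contradiction`. The sketch also assumes `contradictions` is a dict mapping each token to the set of tokens it cannot live with.

```python
import itertools

def naive_top_k(values, contradictions, k):
    """Brute-force steps 1-5: return the k best (group, value) pairs.

    values: hypothetical dict token -> value (stands in for self._probability)
    contradictions: dict token -> set of tokens it cannot "live together" with
    """
    tokens = list(values)
    feasible = []
    # Steps 1+2: all 2^N subsets, skipping any that contain a contradiction
    for size in range(len(tokens) + 1):
        for group in itertools.combinations(tokens, size):
            members = set(group)
            if not any(contradictions.get(t, set()) & members for t in group):
                feasible.append(group)
    # Step 3: the value of a group is the sum of its tokens' values
    scored = [(group, sum(values[t] for t in group)) for group in feasible]
    # Steps 4+5: sort by descending value and keep the top k
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]
```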

Please help?

Tal.

+3  A: 

A simple approach to steps 1+2 could look like this. First, define a list of tokens and a dictionary of contradictions, where each key is a token and each value is the set of tokens it contradicts. Then, for each token, explore two options:

  • add it to the result if it does not conflict with the tokens already added, and extend the conflicting set with the tokens that contradict it
  • don't add it to the result (ignore it) and move on to the next token.

Here's some sample code:

token_list = ['a', 'b', 'c']

contradictions = {
    'a': set(['b']),
    'b': set(['a']),
    'c': set()
}

class Generator(object):
    """Yields every non-contradicting subset of token_list, as a list."""

    def __init__(self, token_list, contradictions):
        self.list = token_list
        self.contradictions = contradictions
        self.max_start = len(self.list) - 1

    def add_no(self, start, result, conflicting):
        # Move on to the next token, or emit a copy of the finished result
        if start < self.max_start:
            for g in self.gen(start + 1, result, conflicting):
                yield g
        else:
            yield result[:]

    def add_yes(self, token, start, result, conflicting):
        # Include token and rule out everything it contradicts; undo the
        # append once all groups containing it have been produced
        result.append(token)
        new_conflicting = conflicting | self.contradictions[token]
        for g in self.add_no(start, result, new_conflicting):
            yield g
        result.pop()

    def gen(self, start, result, conflicting):
        # Branch on self.list[start]: first with it (if allowed), then without it
        token = self.list[start]
        if token not in conflicting:
            for g in self.add_yes(token, start, result, conflicting):
                yield g
        for g in self.add_no(start, result, conflicting):
            yield g

    def go(self):
        # Start with no tokens chosen and nothing ruled out
        return self.gen(0, [], set())

Sample usage:

g = Generator(token_list, contradictions)
for x in g.go():
    print(x)

This is a recursive algorithm, so it won't work for very long token lists: the recursion depth grows with the number of tokens, and Python limits recursion depth. You could easily write a non-recursive version, though.
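For instance, here is a minimal non-recursive sketch of the same include/exclude search. It assumes the same `token_list` / `contradictions` format, and `generate_groups` is a hypothetical name. The recursion is replaced by an explicit stack. The "skip" branch is pushed before the "take" branch, so groups come out in the same order as `Generator` produces them.

```python
def generate_groups(token_list, contradictions):
    """Yield every non-contradicting subset of token_list (empty one included).

    Each stack entry is (index of the next token to decide,
    tokens chosen so far, tokens ruled out by the chosen ones).
    """
    stack = [(0, (), frozenset())]
    while stack:
        start, result, conflicting = stack.pop()
        if start == len(token_list):
            yield list(result)
            continue
        token = token_list[start]
        # Pushed first, so explored after the "add" branch below
        stack.append((start + 1, result, conflicting))
        if token not in conflicting:
            stack.append((start + 1, result + (token,),
                          conflicting | contradictions[token]))
```

Unlike the recursive version, this also handles an empty `token_list` (it yields a single empty group).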

DzinX
Python does allow you to raise the recursion limit using `sys.setrecursionlimit`, but yeah, recursion is considered to be not quite as "Pythonic" as iteration, iirc.
JAB
I don't understand it yet (it is midnight...), but it works and it is 100X faster than my code! Thanks! Reading on after a coffee break.
Tal Weiss
OK, now I understand it. It works and is fast but it is much less elegant than the other solutions suggested. You did suggest the very cool dictionary of sets for contradictions which everyone copied. Thanks for your contribution! +1
Tal Weiss
+2  A: 

Here's a possible "heuristically optimized" approach and a small sample:

import itertools

# tokens in decreasing order of value (must all be > 0)
toks = 12, 11, 8, 7, 6, 2, 1

# contradictions (dict highestvaltok -> set of incompatible ones)
cont = {12: set([11, 8, 7, 2]),
        11: set([8, 7, 6]),
        7: set([2]),
        2: set([1]),
        }

rec_calls = 0

def bestgroup(toks, contdict, arein=(), contset=()):
  """Recursively compute the highest-valued non-contradictory subset of toks."""
  global rec_calls
  toks = list(toks)
  while toks:
    # find the top token compatible w/the ones in `arein`
    toptok = toks.pop(0)
    if toptok in contset:
      continue
    # try to extend with and without this toptok
    without_top = bestgroup(toks, contdict, arein, contset)
    contset = set(contset).union(c for c in contdict.get(toptok, ()))
    newarein = arein + (toptok,)
    with_top = bestgroup(toks, contdict, newarein, contset)
    rec_calls += 1
    if sum(with_top) > sum(without_top):
      return with_top
    else:
      return without_top
  return arein

def noncongroups(toks, contdict):
  """Count possible, non-contradictory subsets of toks."""
  tot = 0
  for l in range(1, len(toks) + 1):
    for c in itertools.combinations(toks, l):
      if any(contdict[k].intersection(c) for k in c if k in contdict): continue
      tot += 1
  return tot


print(bestgroup(toks, cont))
print('calls: %d (vs %d of %d)' % (rec_calls, noncongroups(toks, cont), 2**len(toks)))

I believe this always makes as many recursive calls as there are feasible (non-contradictory) subsets, but I haven't proven it (so I'm just counting both; noncongroups of course has nothing to do with the solution, it's there just to check that behavioral property;-).

If this produces an acceptable speedup on your "actual use cases" benchmarks, then further optimization may introduce alpha-pruning (so you can stop recursion along paths that you know to be unproductive -- that's the motivation for the descending order in the tokens;-) and recursion elimination (using an explicit stack within the function instead). But I wanted to keep this first version simple, so it can easily be understood and verified (also, the further optimizations I have in mind are only going to help marginally, I suspect -- say, at best, halving the typical runtime, if even that much).
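Here is a rough sketch of the two further optimizations mentioned above: bound-based pruning, and an explicit stack instead of recursion. It keeps this answer's conventions. Tokens are their own values, sorted in decreasing order and all > 0, and `contdict` maps a higher-valued token to the lower-valued ones it excludes. The bound used here (current sum plus the sum of all remaining tokens) is just one simple choice, not necessarily the pruning the answer had in mind. `bestgroup_pruned` is a hypothetical name.

```python
def bestgroup_pruned(toks, contdict):
    """Highest-valued non-contradictory subset of toks, using an explicit
    stack and a simple upper-bound prune (a branch-and-bound sketch)."""
    toks = list(toks)
    n = len(toks)
    # suffix[i] = sum(toks[i:]): the most that positions i.. could still add
    suffix = [0] * (n + 1)
    for i in range(n - 1, -1, -1):
        suffix[i] = suffix[i + 1] + toks[i]
    best, best_sum = (), 0
    stack = [(0, (), frozenset())]  # (next index, chosen tokens, excluded tokens)
    while stack:
        i, chosen, excluded = stack.pop()
        current = sum(chosen)
        if current > best_sum:  # every partial choice is itself a valid subset
            best, best_sum = chosen, current
        if i == n or current + suffix[i] <= best_sum:
            continue  # nothing left to decide, or this branch cannot win
        tok = toks[i]
        stack.append((i + 1, chosen, excluded))  # skip tok
        if tok not in excluded:  # take tok and exclude its incompatible tokens
            stack.append((i + 1, chosen + (tok,), excluded | contdict.get(tok, set())))
    return best
```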

Alex Martelli
Thank you for the detailed reply! It works and generates the best answer (sadly, not the k-best answers which I need, but I can complete that myself). +1
Tal Weiss
+3  A: 

An O(n log n) or O(n + m) solution for n tokens and string length m

What differentiates your problem from the NP-complete clique problem is the fact that your "conflict" graph has structure - namely that it can be projected onto 1 dimension (it can be sorted).

That means you can divide and conquer; after all, non-overlapping ranges have no effect on each other, so there is no need to explore the complete state-space. In particular, a dynamic programming solution will work.

The outline of an algorithm

  1. Assume a token's position is represented as [start, end) (i.e. inclusive start, exclusive end). Sort the token-list by token end; we'll be iterating over the tokens in that order.
  2. You will be extending subsets of these tokens. These sets of tokens will have an end (no token can be added to the subset if it starts before the subset's end), and a cumulative value. The end of a subset of tokens is the maximum of the ends of all tokens in the subset.
  3. You're going to maintain a mapping (e.g. via a hashtable or array) from the index into the sorted array of tokens, up to which everything has been processed, to the resulting best-yet subset of non-conflicting tokens. That means that the best-yet subset stored in the mapping for index J can only include tokens of index less than or equal to J.
  4. At each step, you'll be computing the best subset for some position J, and then one of three things can occur: you may have already cached this computation in the mapping (easy), or the best subset includes item J, or the best subset excludes item J. If you haven't cached it, you can only find out whether the best subset includes or excludes J by trying both options.

Now, the trick is in the cache - you need to try both options, and that looks like a recursive (exponential) search, but it needn't be.

  • If the best subset for index J includes token[J] then it can't include any tokens that overlap that token - and in particular, since we sorted by token.end, there is a last token K in that list such that K < J and token[K].end <= token[J].start: and for that token K we can compute the best subset too (or maybe we already have it cached).
  • On the other hand, it may exclude token[J], but then the best subset for J is simply the best subset for J-1.
  • In either case, a special case token[-1] with token[-1].end = 0 and subset value 0 can form the base case.
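In symbols, write best(J) for the value of the best subset using tokens up to index J. Write K(J) for the last index K < J with token[K].end <= token[J].start, or -1 if there is none. The bullets above then amount to this recurrence, with token[-1] as the base case:

$$
\mathrm{best}(J) = \max\Bigl(\mathrm{best}(J-1),\ \mathrm{value}(\mathrm{token}[J]) + \mathrm{best}\bigl(K(J)\bigr)\Bigr), \qquad \mathrm{best}(-1) = 0
$$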

Since you only need to do this computation once for each token index, this part is actually linear in the number of tokens. However, sorting the tokens naively (which I'd recommend) is O(n log(n)), and finding the last token index given a string position is O(log(n)), repeated n times; so the overall running time is O(n log(n)). You can reduce this to O(n) by observing that you don't need to sort an arbitrary list: the maximal string position is limited and small, so you can do the sorting by indexing into the string, but it's almost certainly not worth it. Similarly, although finding one token by binary search is O(log n), you can do this by aligning two lists instead (one sorted on token end, the other on token start), thus permitting an O(n + m) implementation. Unless n can get really huge, it's not worth it.

If you iterate from the front of the string to the end, since all lookups look "back", you can remove the recursion entirely and simply directly lookup the result for a given index since it must have been calculated already anyhow.
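Here is a minimal iterative sketch of the outline above. It assumes each token is a hypothetical `(start, end, value)` tuple spanning [start, end), and uses the standard `bisect` module for the O(log n) "last compatible token" lookup. Table indices are shifted by one so that `best[0]` plays the role of the token[-1] base case. `best_interval_subset` is a hypothetical name.

```python
import bisect

def best_interval_subset(tokens):
    """Best total value over sets of mutually non-overlapping tokens.

    Each token is a hypothetical (start, end, value) tuple spanning [start, end).
    Returns (best_value, chosen_tokens).
    """
    toks = sorted(tokens, key=lambda t: t[1])  # sort by token end
    ends = [t[1] for t in toks]
    n = len(toks)
    best = [0] * (n + 1)      # best[j]: best value using toks[:j]; best[0] = base case
    take = [False] * (n + 1)  # did best[j] include toks[j - 1]?
    prev = [0] * (n + 1)      # prefix length to jump back to when it did
    for j in range(1, n + 1):
        start, _, value = toks[j - 1]
        # toks[:k] are exactly the earlier tokens that end at or before `start`
        k = bisect.bisect_right(ends, start, 0, j - 1)
        prev[j] = k
        if best[k] + value > best[j - 1]:
            best[j], take[j] = best[k] + value, True
        else:
            best[j] = best[j - 1]
    # Walk the table backwards to recover which tokens were used
    chosen, j = [], n
    while j > 0:
        if take[j]:
            chosen.append(toks[j - 1])
            j = prev[j]
        else:
            j -= 1
    return best[n], chosen[::-1]
```

Every lookup only looks back at already-filled table entries, which is why a single forward loop replaces the recursion.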

Does this rather vague explanation help? It's a basic application of dynamic programming, which is just a fancy word for caching; so if you're confused, that's what you should read up on.

Extending this to the top k-best solutions

If you want to find the top-K best solutions, you'll need a messy but doable extension that maps each token index not to the single best subset, but to the best K subsets so far, obviously at increased computational cost and with a bit of extra code. Essentially, rather than choosing whether to include token[J], you'll take the union of both options and trim it down to the k best at each token index. That's O(n log(n) + n k log(k)) if implemented straightforwardly.
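Here is a sketch of that k-best extension, using the same hypothetical `(start, end, value)` representation. Each table entry keeps up to k `(value, group)` pairs. Each step merges the "exclude J" list with the "include J" list and trims the result with `heapq.nlargest`. `k_best_interval_subsets` is a hypothetical name.

```python
import bisect
import heapq

def k_best_interval_subsets(tokens, k):
    """Up to k best (value, group) pairs over all sets of mutually
    non-overlapping tokens, best first; the empty group (value 0) counts too.
    Each token is a hypothetical (start, end, value) tuple spanning [start, end)."""
    toks = sorted(tokens, key=lambda t: t[1])  # sort by end, as in the outline
    ends = [t[1] for t in toks]
    # top[j] holds the (at most) k best (value, group) pairs using toks[:j];
    # top[0] is the base case: only the empty group.
    top = [[(0, ())]]
    for j, (start, _, value) in enumerate(toks, 1):
        # toks[:p] are exactly the earlier tokens that end before this one starts
        p = bisect.bisect_right(ends, start, 0, j - 1)
        excluded = top[j - 1]  # best groups that leave this token out
        included = [(v + value, group + (toks[j - 1],)) for v, group in top[p]]
        # The two families are disjoint, so the k best overall are among these
        top.append(heapq.nlargest(k, excluded + included, key=lambda item: item[0]))
    return top[-1]
```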

Eamon Nerbonne
Wow - thanks for the detailed reply! +1 for also noticing that it is a clique problem (thanks to you I now know what that is...).
Tal Weiss
A: 

The following solution generates all maximal non-contradicting subsets, taking advantage of the fact that there's no point omitting an element from the solution unless it contradicts another element in the solution.

The simple optimization to avoid the second recursion in the case that the element t doesn't contradict any of the remaining elements should help make this solution efficient if the number of contradictions is small.

def solve(tokens, contradictions):
   if not tokens:
      yield set()
   else:
      tokens = set(tokens)
      t = tokens.pop()
      for solution in solve(tokens - contradictions[t], contradictions):
         yield solution | set([t])
      if contradictions[t] & tokens:
         for solution in solve(tokens, contradictions):
            if contradictions[t] & solution:
               yield solution

This solution also suggests that dynamic programming (e.g. memoization) may help improve performance further for some types of inputs.
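To illustrate the memoization idea, here is a sketch that caches the recursion on frozensets with `functools.lru_cache`, so that repeated subproblems are computed only once. `maximal_groups` is a hypothetical name. Like the original, it assumes `contradictions[t]` exists for every token and is symmetric. It also assumes tokens can be compared, since `min()` is used to pick the pivot deterministically.

```python
from functools import lru_cache

def maximal_groups(tokens, contradictions):
    """Every maximal non-contradicting subset of tokens, as frozensets."""

    @lru_cache(maxsize=None)
    def solve(remaining):
        # remaining is a frozenset, so identical subproblems hit the cache
        if not remaining:
            return (frozenset(),)
        t = min(remaining)
        rest = remaining - {t}
        # Groups containing t: t plus a maximal group of what t allows
        with_t = [s | {t} for s in solve(rest - contradictions[t])]
        # Groups without t are only maximal if they contain something t clashes with
        without_t = []
        if contradictions[t] & rest:
            without_t = [s for s in solve(rest) if contradictions[t] & s]
        return tuple(with_t + without_t)

    return list(solve(frozenset(tokens)))
```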

jchl
Looks nice (less so than the non-recursive solution, though), but it did not work correctly for me... See: http://pastebin.com/QAF0UbEX Maybe I did something wrong...
Tal Weiss
Can you explain in what way it didn't work? That example prints [set(['a', 'c']), set(['c', 'b'])], which is correct. Note that, for efficiency, my solution deliberately only includes _maximal_ non-contradicting subsets, i.e. those that cannot be extended with another non-contradicting element. Under the assumption that all your values are non-negative, the maximal subsets are the only ones you need in order to find the subset with the highest value. Though now I see that you need the K best subsets; if you're interested in non-maximal subsets too, then I'll update my solution.
jchl
+2  A: 

A really simple way to get all the non-contradicting token-groups:

#!/usr/bin/env python

token_list = ['a', 'b', 'c']

contradictions = {
    'a': set(['b']),
    'b': set(['a']),
    'c': set()
}

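result = []

# Grow the list one token at a time: each token forms a group on its own
# and also extends every existing group it doesn't contradict.
while token_list:
    token = token_list.pop()
    new = [set([token])]
    for r in result:
        if token not in contradictions or not r & contradictions[token]:
            new.append(r | set([token]))
    result.extend(new)

print(result)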
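Here is a sketch of how this enumeration could feed the question's top-K step. It uses the variant suggested in the comments below, which seeds `result` with the empty set. `values` (token to value) and `k_best_groups_iterative` are hypothetical names. As in the example dictionary, `contradictions` must list each conflict in both directions.

```python
import heapq

def k_best_groups_iterative(token_list, contradictions, values, k):
    """Enumerate every non-contradicting group (empty group included),
    then keep the k highest-valued ones."""
    result = [set()]  # seed with the empty group: no single-token special case
    for token in token_list:  # iterate instead of pop(), leaving token_list intact
        clashes = contradictions.get(token, set())
        # The list is built before extend() runs, so each existing group
        # is considered exactly once for this token
        result.extend([r | {token} for r in result if not r & clashes])
    # Same as sorting by value (descending) and slicing, without a full sort
    return heapq.nlargest(k, result, key=lambda group: sum(values[t] for t in group))
```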
Florian Diesch
A. It is beautiful! Thank you!
B. No recursion, nice!
C. Minor comment: you left out the empty result, easily added to the result = [] line.
D. No need for "if token not in contradictions", since we are generating the dictionary ourselves (the result is faster).
E. This is 145X faster than my code!!!
Tal Weiss
Err... add the empty set at the end of the algo. Adding it as the seed generates duplicate results.
Tal Weiss
I thought the empty set isn't needed as it's always there with value 0
Florian Diesch
Nice! Significantly less complicated than my solution :) +1
DzinX
A more elegant approach is to add the empty set at the beginning of the algorithm (define result = [set()]). Then you don't need to special-case the single-element sets, so you can just define new = [].
jchl