I needed to write a weighted version of random.choice (each element in the list has a different probability of being selected). This is what I came up with:

import random

def weightedChoice(choices):
    """Like random.choice, but each element can have a different chance of
    being selected.

    choices can be any iterable containing iterables with two items each.
    Technically, they can have more than two items, the rest will just be
    ignored.  The first item is the thing being chosen, the second item is
    its weight.  The weights can be any numeric values, what matters is the
    relative differences between them.
    """
    space = {}
    current = 0
    for choice, weight in choices:
        if weight > 0:
            space[current] = choice
            current += weight
    rand = random.uniform(0, current)
    for key in sorted(space.keys() + [current]):
        if rand < key:
            return choice
        choice = space[key]
    return None

This function seems overly complex to me, and ugly. I'm hoping everyone here can offer some suggestions on improving it or alternate ways of doing this. Efficiency isn't as important to me as code cleanliness and readability.

+4  A: 
def weighted_choice(choices):
   total = sum(w for c,w in choices)
   r = random.uniform(0, total)
   upto = 0
   for c, w in choices:
      if upto+w > r:
         return c
      upto += w
   assert False, "Shouldn't get here"
Ned Batchelder
I don't know why I thought I had to sort the weights and go through them in order...this is better.
Colin
+2  A: 

Crude, but may be sufficient:

import random
weighted_choice = lambda s : random.choice(sum(([v]*wt for v,wt in s),[]))

Does it work?

# define choices and relative weights
choices = [("WHITE",90), ("RED",8), ("GREEN",2)]

# initialize tally dict
tally = dict((c[0],0) for c in choices)

# tally up 1000 weighted choices
for i in xrange(1000):
    tally[weighted_choice(choices)] += 1

print tally.items()

Prints:

[('WHITE', 904), ('GREEN', 22), ('RED', 74)]

This assumes that all weights are integers. They don't have to add up to 100; I just did that to make the test results easier to interpret.

Paul McGuire
Nice, I'm not sure I can assume all weights are integers, though.
Colin
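If the weights may be fractional, the list-repetition trick above no longer applies, since a list cannot be repeated a non-integer number of times. A minimal sketch of one alternative, assuming Python 3 and a hypothetical helper name (weighted_choice_bisect is not from this thread), builds the running totals with itertools.accumulate and finds the pick with bisect:

```python
import bisect
import itertools
import random

def weighted_choice_bisect(choices, rng=random):
    """Pick an item from (item, weight) pairs; weights may be floats.

    Hypothetical helper for illustration. Assumes no negative weights
    and at least one positive weight.
    """
    items, weights = zip(*choices)
    cumulative = list(itertools.accumulate(weights))  # running totals
    r = rng.random() * cumulative[-1]                 # point in [0, total)
    # Index of the first running total that is strictly greater than r,
    # so zero-weight items are never selected.
    index = bisect.bisect_right(cumulative, r)
    return items[min(index, len(items) - 1)]          # clamp as a float-rounding guard
```

The lookup is logarithmic, but building the cumulative list is still linear, so this only pays off if the cumulative list is cached across many draws.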
+1  A: 

I'd require that the weights sum to 1, but this works anyway:

def weightedChoice(choices):
    # Safety check, you can remove it
    for c,w in choices:
        assert w >= 0

    # Draw a point in [0, total weight]; note that we sum the weights, not the choices
    tmp = random.uniform(0, sum(w for c,w in choices))
    for choice,weight in choices:
        if tmp < weight:
            return choice
        else:
            tmp -= weight
    raise ValueError('Negative values in input')
phihag
Out of curiosity, is there a reason you prefer random.random() * total instead of random.uniform(0, total)?
Colin
@Colin No, not at all. Updated.
phihag
You traverse the iterable three times. Not every iterable supports that.
liori
That's a good point. I've only been passing in lists of tuples, so I hadn't uncovered that bug yet.
Colin
@liori You're right. However, weightedChoice cannot be computed without storing all the items of the iterable in a list anyway, so the input should be a list.
phihag
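The multiple-traversal problem shows up as soon as a generator is passed in: the first pass exhausts it, and later passes see nothing. A minimal sketch of the "store it in a list" approach, assuming Python 3 and a hypothetical function name, copies the input once so the remaining passes are safe:

```python
import random

def weighted_choice_any_iterable(choices, rng=random):
    """Hypothetical variant that accepts any iterable of (item, weight) pairs."""
    choices = list(choices)  # materialize once; generators are consumed here only
    total = sum(w for _, w in choices)
    if total <= 0:
        raise ValueError("weights must sum to a positive value")
    tmp = rng.uniform(0, total)
    last = None
    for choice, weight in choices:
        if weight <= 0:
            continue          # zero-weight items can never be picked
        if tmp < weight:
            return choice
        tmp -= weight
        last = choice
    return last  # only reached on a rare floating-point edge (tmp == total)
```

The cost is the extra memory for the copy, which is what the one-pass idea discussed next avoids.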
I think it is actually possible. http://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf It is actually pretty simple... But who cares...
liori
@liori I do care, and you're right: weightedChoice *can* be computed with one iterator pass only. However, this seems to require more than 1 call to the pseudo random generator.
phihag
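One well-known way to do this in a single pass keeps a running total and replaces the current pick with probability weight / total at each step (a weighted reservoir of size one). The sketch below assumes Python 3 and uses a hypothetical function name; it is not necessarily the algorithm from the linked paper. It calls the random generator roughly once per item rather than once overall.

```python
import random

def weighted_choice_one_pass(choices, rng=random):
    """Hypothetical single-pass picker over any iterable of (item, weight) pairs."""
    total = 0
    picked = None
    found = False
    for choice, weight in choices:
        if weight <= 0:
            continue  # skip items that can never be selected
        total += weight
        # Keep the first positive-weight item unconditionally; afterwards,
        # replace the current pick with probability weight / total.
        if not found or rng.random() * total < weight:
            picked = choice
            found = True
    if not found:
        raise ValueError("no positive weights in input")
    return picked
```

Item i survives as the final pick with probability (w_i / T_i) multiplied by T_(j-1) / T_j for every later item j, which telescopes to w_i divided by the overall total, so the result has the intended distribution without storing the input.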