On a Django project, I was hoping to flatten a shallow list with a nested list comprehension, like this:

[image for image in menuitem.image_set.all() for menuitem in list_of_menuitems]

But I run into trouble of the NameError variety there, because the name 'menuitem' is not defined. After googling and looking around on Stack Overflow, I got the desired result with a reduce statement:

reduce(list.__add__, map(lambda x: list(x), [mi.image_set.all() for mi in list_of_menuitems]))

(Note, I need that list(x) call there because x is a Django QuerySet object.)

But the reduce method is fairly unreadable. So my question is:

Is there a simple way to flatten this list with a list comprehension, or, failing that, what would you all consider to be the best way to flatten a shallow list like this, balancing performance and readability?

Update: Thanks to everyone who contributed to this question. Here is a summary of what I learned. I'm also making this a community wiki in case others want to add to or correct these observations.

My original reduce statement is redundant and is better written this way:

>>> reduce(list.__add__, (list(mi.image_set.all()) for mi in list_of_menuitems))

This is the correct syntax for a nested list comprehension (brilliant summary, dF!):

>>> [image for mi in list_of_menuitems for image in mi.image_set.all()]

But neither of these methods is as efficient as using itertools.chain:

>>> from itertools import chain
>>> list(chain(*[mi.image_set.all() for mi in h.get_image_menu()]))
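
For readers without a Django model handy, here is the same comparison on a plain list of lists (toy data standing in for the QuerySets):

>>> list_of_menuitems = [['image00', 'image01'], ['image10'], []]
>>> [image for mi in list_of_menuitems for image in mi]
['image00', 'image01', 'image10']
>>> from itertools import chain
>>> list(chain(*list_of_menuitems))
['image00', 'image01', 'image10']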
+3  A: 

Off the top of my head, you can eliminate the lambda:

reduce(list.__add__, map(list, [mi.image_set.all() for mi in list_of_menuitems]))

Or even eliminate the map, since you've already got a list-comp:

reduce(list.__add__, [list(mi.image_set.all()) for mi in list_of_menuitems])

You can also just express this as a sum of lists:

sum([list(mi.image_set.all()) for mi in list_of_menuitems], [])
recursive
You could just use add, and I believe the second argument to sum is redundant.
daniel
It's not redundant. The default is zero, yielding TypeError: unsupported operand type(s) for +: 'int' and 'list'. IMO sum() is more direct than reduce(add, ...)
recursive
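
For reference, a quick demonstration of recursive's point about the start value:

>>> sum([[1, 2], [3]], [])
[1, 2, 3]
>>> sum([[1, 2], [3]])
Traceback (most recent call last):
  ...
TypeError: unsupported operand type(s) for +: 'int' and 'list'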
+1  A: 

What about:

from operator import add
reduce(add, map(lambda x: list(x.image_set.all()), [mi for mi in list_of_menuitems]))

But Guido recommends against doing too much in a single line of code, since it reduces readability. There is minimal, if any, performance gain from doing what you want in a single line vs. multiple lines.

daniel
YES. This is exactly what I'm getting at.
Jason Baker
It's incredibly satisfying performing some crazy amount of work in a single line... but it's really just syntactic sugar
daniel
If I remember correctly, Guido is actually recommending against the use of reduce and list comprehensions as well... I disagree though, they are incredibly useful.
daniel
Check the performance of this little nugget versus a multi-line function. I think you'll find that this one-liner is a real dog.
S.Lott
Probably; mapping with lambdas is horrible. The overhead incurred for each function call sucks the life out of your code. I never said that particular line was as fast as a multi-line solution... ;)
daniel
+20  A: 

If you're just looking to iterate over a flattened version of the data structure and don't need an indexable sequence, consider itertools.chain and company.

>>> list_of_menuitems = [['image00', 'image01'], ['image10'], []]
>>> import itertools
>>> chain = itertools.chain(*list_of_menuitems)
>>> print(list(chain))
['image00', 'image01', 'image10']

It will work on anything that's iterable, which should include Django's iterable QuerySets, which it appears you're using in the question.

Edit: This is probably as good as a reduce anyway, because reduce will have the same overhead copying the items into the list that's being extended. chain will only incur this (same) overhead if you run list(chain) at the end.

Meta-Edit: Actually, it's less overhead than the question's proposed solution, because you throw away the temporary lists you create when you extend the original with the temporary.

Edit: As J.F. Sebastian says, itertools.chain.from_iterable avoids the unpacking, and you should use that to avoid * magic, but the timeit app shows negligible performance difference.
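
For illustration, the from_iterable form with the same toy data as above:

>>> list(itertools.chain.from_iterable(list_of_menuitems))
['image00', 'image01', 'image10']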

cdleary
+4  A: 

Here is the correct solution using list comprehensions (the for clauses are backwards in the question):

>>> join = lambda it: (y for x in it for y in x)
>>> list(join([[1,2],[3,4,5],[]]))
[1, 2, 3, 4, 5]

In your case it would be

[image for menuitem in list_of_menuitems for image in menuitem.image_set.all()]

or you could use join and say

join(menuitem.image_set.all() for menuitem in list_of_menuitems)

In either case, the gotcha was the nesting of the for loops.

jleedev
+17  A: 

You almost have it! The way to do nested list comprehensions is to put the for statements in the same order as they would go in regular nested for statements.

Thus, this

for inner_list in outer_list:
    for item in inner_list:
        ...

corresponds to

[... for inner_list in outer_list for item in inner_list]

So you want

[image for menuitem in list_of_menuitems for image in menuitem.image_set.all()]
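
For example, with a plain nested list:

>>> outer_list = [[1, 2], [3, 4, 5], []]
>>> [item for inner_list in outer_list for item in inner_list]
[1, 2, 3, 4, 5]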
dF
+6  A: 

Performance Results. Revised.

import itertools
def itertools_flatten( aList ):
    return list( itertools.chain(*aList) )

from operator import add
def reduce_flatten1( aList ):
    return reduce(add, map(lambda x: list(x), [mi for mi in aList]))

def reduce_flatten2( aList ):
    return reduce(list.__add__, map(list, aList))

def comprehension_flatten( aList ):
    return list(y for x in aList for y in x)

I flattened a 2-level list of 30 items 1000 times:

itertools_flatten     0.00554
comprehension_flatten 0.00815
reduce_flatten2       0.01103
reduce_flatten1       0.01404

Reduce is always a poor choice here; each `list.__add__` builds a brand-new list, so earlier items get copied over and over again.

S.Lott
`map(lambda x: list(x), [mi for mi in aList])` is just a `map(list, aList)`.
J.F. Sebastian
`reduce_flatten = lambda list_of_iters: reduce(list.__add__, map(list, list_of_iters))`
J.F. Sebastian
`itertools_flatten2 = lambda aList: list(itertools.chain.from_iterable(aList))`
J.F. Sebastian
Don't have chain.from_iterable in 2.5.2 -- sorry -- can't compare with other solutions.
S.Lott
@recursive's version: `sum_flatten = lambda aList: sum(map(list, aList), [])`
J.F. Sebastian
+3  A: 

This solution works for arbitrary nesting depths - not just the "list of lists" depth that some (all?) of the other solutions are limited to:

def flatten(x):
    result = []
    for el in x:
        if hasattr(el, "__iter__") and not isinstance(el, basestring):
            result.extend(flatten(el))
        else:
            result.append(el)
    return result

It's the recursion which allows for arbitrary depth nesting - until you hit the maximum recursion depth, of course...
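
A quick check of the arbitrary-depth behaviour (the basestring test keeps strings intact):

>>> flatten([1, [2, (3, 4)], [[5]], 'abc'])
[1, 2, 3, 4, 5, 'abc']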

Alabaster Codify
It might be worth adding `hasattr(el, '__getitem__')` for compatibility with the `iter()` function and the built-in for-in loop (though all Python sequences, i.e. objects with `__getitem__`, are also iterable, i.e. have `__iter__`).
J.F. Sebastian
+12  A: 

@S.Lott: You inspired me to write a timeit app.

I figured it would also vary based on the number of partitions (the number of iterators within the container list) -- your comment didn't mention how many partitions there were of the thirty items. This plot flattens a thousand items in every run, with a varying number of partitions. The items are evenly distributed among the partitions.

(Plot: Flattening Comparison)

Code (Python 2.6):

#!/usr/bin/env python2.6

"""Usage: %prog item_count"""

from __future__ import print_function

import collections
import itertools
import operator
from timeit import Timer
import sys

import matplotlib.pyplot as pyplot

def itertools_flatten(iter_lst):
    return list(itertools.chain(*iter_lst))

def itertools_iterable_flatten(iter_iter):
    return list(itertools.chain.from_iterable(iter_iter))

def reduce_flatten(iter_lst):
    return reduce(operator.add, map(list, iter_lst))

def reduce_lambda_flatten(iter_lst):
    return reduce(operator.add, map(lambda x: list(x), [i for i in iter_lst]))

def comprehension_flatten(iter_lst):
    return list(item for iter_ in iter_lst for item in iter_)

METHODS = ['itertools', 'itertools_iterable', 'reduce', 'reduce_lambda',
           'comprehension']

def _time_test_assert(iter_lst):
    """Make sure all methods produce an equivalent value.
    :raise AssertionError: On any non-equivalent value."""
    callables = (globals()[method + '_flatten'] for method in METHODS)
    results = [callable(iter_lst) for callable in callables]
    if not all(result == results[0] for result in results[1:]):
        raise AssertionError

def time_test(partition_count, item_count_per_partition, test_count=10000):
    """Run flatten methods on a list of :param:`partition_count` iterables.
    Normalize results over :param:`test_count` runs.
    :return: Mapping from method to (normalized) microseconds per pass.
    """
    iter_lst = [[dict()] * item_count_per_partition] * partition_count
    print('Partition count:    ', partition_count)
    print('Items per partition:', item_count_per_partition)
    _time_test_assert(iter_lst)
    test_str = 'flatten(%r)' % iter_lst
    result_by_method = {}
    for method in METHODS:
        setup_str = 'from test import %s_flatten as flatten' % method
        t = Timer(test_str, setup_str)
        # timeit returns total seconds over test_count runs; convert to usec per pass
        per_pass = 1000000 * t.timeit(number=test_count) / test_count
        print('%20s: %.2f usec/pass' % (method, per_pass))
        result_by_method[method] = per_pass
    return result_by_method

if __name__ == '__main__':
    if len(sys.argv) != 2:
        raise ValueError('Need a number of items to flatten')
    item_count = int(sys.argv[1])
    partition_counts = []
    pass_times_by_method = collections.defaultdict(list)
    for partition_count in xrange(1, item_count):
        if item_count % partition_count != 0:
            continue
        items_per_partition = item_count / partition_count
        result_by_method = time_test(partition_count, items_per_partition)
        partition_counts.append(partition_count)
        for method, result in result_by_method.iteritems():
            pass_times_by_method[method].append(result)
    for method, pass_times in pass_times_by_method.iteritems():
        pyplot.plot(partition_counts, pass_times, label=method)
    pyplot.legend()
    pyplot.title('Flattening Comparison for %d Items' % item_count)
    pyplot.xlabel('Number of Partitions')
    pyplot.ylabel('Microseconds')
    pyplot.show()

Edit: Decided to make it community wiki.

Note: METHODS should probably be accumulated with a decorator, but I figure it'd be easier for people to read this way.
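
To reproduce the plot, save the script as `test.py` (the setup_str imports from a module named test) and run it with an item count, e.g. `python test.py 1000`; every divisor of that count is then swept as a partition count.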

cdleary
Try `sum_flatten = lambda iter_lst: sum(map(list, iter_lst), [])`
J.F. Sebastian
or just sum(list, [])
Plumo
@EnTerr suggested `reduce(operator.iadd` http://stackoverflow.com/questions/3040335/finding-elements-in-python-association-lists-efficiently/3041450#3041450 that is the fastest so far (code: http://ideone.com/NWThp picture: http://i403.photobucket.com/albums/pp111/uber_ulrich/p1000.png )
J.F. Sebastian
`chain.from_iterable()` is slightly faster if there are many partitions http://i403.photobucket.com/albums/pp111/uber_ulrich/p10000.png
J.F. Sebastian
+4  A: 

In Python 2.6, using chain.from_iterable():

>>> from itertools import chain
>>> list(chain.from_iterable(mi.image_set.all() for mi in h.get_image_menu()))

It avoids creating an intermediate list.

J.F. Sebastian