When recursing with Lists in Java, I frequently end up allocating and copying lists many many times. For example, I want to generate a List<List<Integer>> of all possible Integer sequences where:

  • Every number is between 1 and 9
  • Every number is greater than or equal to the number before it
  • A number can appear between 0 and 3 times in a row.

For example, [1,1,1,2,2,2,3,3,3,4,4,4,5,5,5,6,6,6,7,7,7,8,8,8,9,9,9] is the largest sequence. I have a recursive method that does this:

static void recurse(List<List<Integer>> addto, List<Integer> prev, int n){
 if(n>=10)
  addto.add(prev);
 else{
  for(int i=0; i<=3; i++){
   List<Integer> newlist = new ArrayList<Integer>(prev);
   for(int k=0; k<i; k++){
    newlist.add(n);
   }
   recurse(addto, newlist, n+1);
  }
 }
}

What happens here is that I'm copying the entire prev list four times on every call (once per value of i). I need to do this in order to append to my list and pass it on to the next level of recursion. This is very slow (2 seconds). A less elegant version using 10 nested loops ran much faster because it did not have to copy so many lists. What's the 'proper' way to do this?

By the way this is not homework, but related to one of the USACO problems.

+1  A: 

The slowdown may be due to the memory used internally by the ArrayLists being reallocated. By default, an ArrayList starts with a capacity of 10. When you add the 11th element, it has to expand, but it only grows by about 50%. Moreover, when you create an ArrayList with the copy constructor, the new list ends up with little or no spare capacity: depending on the JDK version, it is sized to the number of elements in the source list, or to that plus 10%. A copy therefore often has to reallocate as soon as you add to it. (I'm guessing your 10-loop version of the algorithm used just a single "working" list, which it copied just before adding to the List<List<Integer>>.)

So you could try providing a capacity when creating the lists, and see if that speeds things up at all:

List<Integer> newlist = new ArrayList<Integer>(27);  // longest list size is 9 * 3
newlist.addAll(prev);

EDIT: By the way, you should be able to implement a non-recursive algorithm without 10 nested loops. Use a stack, similar to a depth-first tree search.
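
A minimal sketch of that stack-based idea follows. It is one possible reading of the suggestion, not a definitive implementation: the class name StackSequences, the method generate and its parameters are made up for illustration. The stack holds, for each number decided so far, how many copies of it are currently in the working list, so a single list is modified in place and copied only when a sequence is complete.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

final class StackSequences {

    // Hypothetical helper: every sequence in which each value from minNum to
    // maxNum appears between 0 and maxRepeat times, in non-decreasing order.
    static List<List<Integer>> generate(int minNum, int maxNum, int maxRepeat) {
        int levels = maxNum - minNum + 1;
        List<List<Integer>> result = new ArrayList<>();
        List<Integer> current = new ArrayList<>(levels * maxRepeat);
        // counts holds one entry per decided number; the top entry is the
        // number of copies of the deepest number currently in 'current'.
        Deque<Integer> counts = new ArrayDeque<>();

        while (true) {
            // Descend to a leaf, choosing 0 copies at every remaining level.
            while (counts.size() < levels) {
                counts.push(0);
            }
            result.add(new ArrayList<>(current));  // copy only at a leaf

            // Backtrack past levels that have used up all their choices.
            while (!counts.isEmpty() && counts.peek() == maxRepeat) {
                counts.pop();
                for (int k = 0; k < maxRepeat; k++) {
                    current.remove(current.size() - 1);
                }
            }
            if (counts.isEmpty()) {
                break;  // every branch has been explored
            }

            // Try one more copy of the number at the level now on top.
            int level = counts.size() - 1;
            counts.push(counts.pop() + 1);
            current.add(minNum + level);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(generate(1, 9, 3).size() + " sequences");
    }
}
```

For the problem in the question you would call generate(1, 9, 3), which should produce 4^9 = 262144 sequences in the same order as the recursive version.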

Todd Owen
As others have pointed out, there's also the overhead of creating Integer objects, but this doesn't explain why your non-recursive algorithm (which presumably also used lists of Integers) is faster. Actually, because Integers are immutable, you could circumvent this issue by first building a list of the nine Integer objects you need (Integer.valueOf(1), Integer.valueOf(2), etc.) and reusing these instead of creating Integers afresh. (Autoboxing already goes through Integer.valueOf, which caches values from -128 to 127, so for 1 to 9 the saving is likely small.) This may seem odd, but it would make the algorithm easy to generalize to work for any list of objects, not just numbers.
Todd Owen
A: 

Rather than copying the list, you should modify the list in-place, and only copy it when you find a solution. After you get out of the recursion, remove the last three elements of the list.

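static void recurse(List<List<Integer>> addto, List<Integer> list, int n){
        if(n>=10)
                addto.add(new ArrayList<Integer>(list));  // copy only when a sequence is complete
        else{
                int pos = list.size();  // length before any copies of n are appended
                for(int i=0; i<=3; i++){
                        if(i>0)
                                list.add(n);  // list now ends with i copies of n
                        recurse(addto, list, n+1);
                }
                // undo: drop the three copies of n, restoring the original length
                for(int i=list.size()-1; i>=pos; i--){
                        list.remove(i);  // remove(int index), always the last element
                }
        }
}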

If you want even more performance, try using an int[], rather than an ArrayList of Integer, as that will save you creation of the Integers. You could size the list array with 27 elements, and pass the index of the first free position in the recursion.
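
As a rough sketch of that int[] variant (the class name ArraySequences and the helper generate are invented here for illustration, and results are collected as a List<int[]> rather than a List<List<Integer>>), the recursion could pass the index of the first free position like this:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class ArraySequences {

    // buf is a shared scratch array; buf[0..len) holds the sequence so far.
    static void recurse(List<int[]> out, int[] buf, int len, int n) {
        if (n >= 10) {
            out.add(Arrays.copyOf(buf, len));  // copy only the filled prefix
            return;
        }
        for (int i = 0; i <= 3; i++) {
            if (i > 0) {
                buf[len + i - 1] = n;  // prefix now ends with i copies of n
            }
            recurse(out, buf, len + i, n + 1);
        }
        // No cleanup needed: callers only ever read buf[0..len).
    }

    // Hypothetical convenience wrapper for the 1..9, 0..3 repeats case.
    static List<int[]> generate() {
        List<int[]> out = new ArrayList<>(262144);  // 4^9 sequences
        recurse(out, new int[27], 0, 1);            // 27 = 9 numbers * 3 repeats
        return out;
    }

    public static void main(String[] args) {
        System.out.println(generate().size() + " sequences");
    }
}
```

Because the current length travels with the call, nothing has to be removed on the way back up: stale values beyond len are simply overwritten by later branches.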

Martin v. Löwis
A: 

Avoiding object creation and recursion is a start. Here is a non-recursive method:

/**
 * @author clint
 * 
 */
public class ListBuilder {

  public static void main(String[] args) {
    buildLists(0, 3, 1, 9);
  }

  private static int[][] buildLists( int minOccur, int maxOccur, int minNum, int maxNum ) {
   long time = System.currentTimeMillis();
   assert( minOccur >= 0 );
   assert( minOccur < maxOccur );
   assert( minNum < maxNum );
   int occurDelta = maxOccur - minOccur;
   int numDelta = maxNum - minNum;

   int[][] lists = new int[ (int)Math.pow(occurDelta + 1,numDelta +1) ][];
   int[] counters = new int[ numDelta + 1];
   for ( int i = 0; i < counters.length; i++ ) {
     counters[i] = minOccur;
   }

   int listIndex;
   int sumCounters;

   for ( int i = 0; i < lists.length; i++ ) {
     //System.out.println("c:" + Arrays.toString(counters));
     sumCounters = 0;
     for ( int j = 0; j < counters.length; j++ ) {
       sumCounters += counters[j]; 
     }
     lists[i] = new int[sumCounters];
     listIndex = 0;
     for ( int j = 0; j < counters.length; j++ ) {
       for ( int k = 0; k < counters[j]; k++ ) {
         lists[i][listIndex] = j + minNum;
         listIndex++;
       }
     }
     for ( int j = 0; j < counters.length; j++ ) {
       counters[j] += 1;
       if ( counters[j] > maxOccur ) {
         counters[j] = minOccur;
       } else {
         break;
       } 
     }  
   }

//   for ( int i = 0; i < lists.length; i++ ) {
//     System.out.println(Arrays.toString(lists[i]));
//   }
   long dt = System.currentTimeMillis() - time;

   System.out.println(lists.length + " lists in " + dt + "ms");
   return lists;
  }
}
Clint