A direct cut and paste of the following algorithm:

def msort[T](less: (T, T) => Boolean)
            (xs: List[T]): List[T] = {
  def merge(xs: List[T], ys: List[T]): List[T] =
    (xs, ys) match {
      case (Nil, _) => ys
      case (_, Nil) => xs
      case (x :: xs1, y :: ys1) =>
        if (less(x, y)) x :: merge(xs1, ys)
        else y :: merge(xs, ys1)
    }
  val n = xs.length / 2
  if (n == 0) xs
  else {
    val (ys, zs) = xs splitAt n
    merge(msort(less)(ys), msort(less)(zs))
  }
}

causes a StackOverflowError on lists of 5000 elements.

Is there any way to optimize this so that this doesn't occur?

+8  A: 

This happens because `merge` isn't tail-recursive, so its recursion gets as deep as the lists are long. You can fix this either by using a non-strict collection or by making `merge` tail-recursive.

The latter solution goes like this:

def msort[T](less: (T, T) => Boolean) 
            (xs: List[T]): List[T] = { 
  def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] = 
    (xs, ys) match { 
      case (Nil, _) => ys.reverse ::: acc 
      case (_, Nil) => xs.reverse ::: acc
      case (x :: xs1, y :: ys1) => 
        if (less(x, y)) merge(xs1, ys, x :: acc) 
        else merge(xs, ys1, y :: acc) 
    } 
  val n = xs.length / 2 
  if (n == 0) xs 
  else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs), Nil).reverse
  } 
} 
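
If you want the compiler to confirm that `merge` really gets turned into a loop, you can annotate it with `@tailrec` from `scala.annotation` (Scala 2.8 and later); compilation then fails if a recursive call is not in tail position. Here is a sketch of the same code with the annotation. The wrapper object `TailRecMergeSort` is just a name made up for this example:

```scala
import scala.annotation.tailrec

// Hypothetical wrapper object, used only to make the example self-contained.
object TailRecMergeSort {
  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // acc collects the merged elements in reverse order.
    // @tailrec makes the compiler reject this method if it cannot
    // rewrite the recursive calls into a loop.
    @tailrec
    def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
      (xs, ys) match {
        case (Nil, _) => ys.reverse ::: acc
        case (_, Nil) => xs.reverse ::: acc
        case (x :: xs1, y :: ys1) =>
          if (less(x, y)) merge(xs1, ys, x :: acc)
          else merge(xs, ys1, y :: acc)
      }
    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      // merge returns its result reversed, so flip it back here.
      merge(msort(less)(ys), msort(less)(zs), Nil).reverse
    }
  }
}
```

For example, `TailRecMergeSort.msort[Int](_ < _)(List(3, 1, 2))` gives `List(1, 2, 3)`. The type argument is spelled out because `T` can't be inferred from `_ < _` alone.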

Using non-strictness involves either passing parameters by-name, or using non-strict collections such as Stream. The following code uses Stream just to prevent stack overflow, and List elsewhere:

def msort[T](less: (T, T) => Boolean) 
            (xs: List[T]): List[T] = { 
  def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match {
    case (x :: xs, y :: ys) if less(x, y) => Stream.cons(x, merge(xs, right))
    case (x :: xs, y :: ys) => Stream.cons(y, merge(left, ys))
    case _ => if (left.isEmpty) right.toStream else left.toStream
  }
  val n = xs.length / 2 
  if (n == 0) xs 
  else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs)).toList
  } 
}
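
On Scala 2.13 and later, `Stream` is deprecated in favour of `LazyList`. Below is a sketch of the same idea with `LazyList`, assuming Scala 2.13+. The object name `LazyMergeSort` and the renamed pattern variables (`xs1`, `ys1`) are only for illustration:

```scala
// Hypothetical wrapper object; assumes Scala 2.13 or later (LazyList).
object LazyMergeSort {
  def msort[T](less: (T, T) => Boolean)(xs: List[T]): List[T] = {
    // #:: evaluates its right-hand side lazily, so each call to merge
    // returns immediately instead of recursing down the whole list.
    def merge(left: List[T], right: List[T]): LazyList[T] = (left, right) match {
      case (x :: xs1, y :: _) if less(x, y) => x #:: merge(xs1, right)
      case (_ :: _, y :: ys1)               => y #:: merge(left, ys1)
      case _ =>
        if (left.isEmpty) LazyList.from(right) else LazyList.from(left)
    }
    val n = xs.length / 2
    if (n == 0) xs
    else {
      val (ys, zs) = xs.splitAt(n)
      // toList walks the lazy list iteratively, forcing one cell at a time.
      merge(msort(less)(ys), msort(less)(zs)).toList
    }
  }
}
```

As with the `Stream` version, laziness is used only inside `merge`; everything the caller sees is still a `List`.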
Daniel
I thought about trying to make it tail recursive, then saw quite a lot of info claiming that the JVM isn't that amenable and doesn't always optimize tail recursion. Is there some sort of guideline for when this succeeds?
The JVM doesn't, so the Scala compiler does it for you, but only when certain requirements are met: it must be self-recursion (i.e., f calling g and g calling f won't work); it must be _tail_ recursion, of course (the recursive call _must_ always be the last thing on that code path); and, for methods, it must be either `final` or `private`. In the example, because `merge` is defined inside `msort`, instead of being defined on a class or object, it is effectively private.
Daniel
A: 

Just in case Daniel's solutions didn't make it clear enough, the problem is that merge's recursion is as deep as the length of the list, and since it's not tail-recursive, it can't be converted into iteration.

Scala can convert Daniel's tail-recursive merge solution into something approximately equivalent to this:

def merge(xs: List[T], ys: List[T]): List[T] = {
  var acc: List[T] = Nil
  var decx = xs
  var decy = ys
  while (!decx.isEmpty || !decy.isEmpty) {
    (decx, decy) match { 
      case (Nil, _) => { acc = decy.reverse ::: acc ; decy = Nil }
      case (_, Nil) => { acc = decx.reverse ::: acc ; decx = Nil }
      case (x :: xs1, y :: ys1) => 
        if (less(x, y)) { acc = x :: acc ; decx = xs1 }
        else { acc = y :: acc ; decy = ys1 }
    }
  }
  acc  // in reverse order, as in Daniel's merge; the caller still calls .reverse
}

but it keeps track of all the variables for you.

(A tail-recursive method is one where the method only calls itself to get a complete answer to pass back; it never calls itself and then does something with the result before passing it back. Also, the optimization can't be applied if the method could be overridden in a subclass, so it generally only works for methods in objects, methods marked final or private, and methods nested inside other methods.)
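
To make the distinction concrete, here is a small sketch with two hypothetical methods (`sumNaive` and `sumAcc` are names invented for this example, not from the question):

```scala
import scala.annotation.tailrec

// Hypothetical example object; methods in an object cannot be overridden.
object SumExample {
  // Not tail-recursive: the addition happens after the recursive call
  // returns, so every element needs its own stack frame.
  def sumNaive(xs: List[Int]): Long = xs match {
    case Nil    => 0L
    case h :: t => h + sumNaive(t)
  }

  // Tail-recursive: the recursive call is the last thing done on its
  // code path, and the running total travels along in acc.
  @tailrec
  def sumAcc(xs: List[Int], acc: Long = 0L): Long = xs match {
    case Nil    => acc
    case h :: t => sumAcc(t, acc + h)
  }
}
```

`sumAcc` compiles to a loop and copes with very long lists, while `sumNaive` uses stack space proportional to the list length, just like the original `merge`.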

Rex Kerr