
A friend gave me this code snippet in Clojure:

(defn sum [coll acc] (if (empty? coll) acc (recur (rest coll) (+ (first coll) acc))))
(time (sum (range 1 9999999) 0))

and asked me how it fares against a similar Scala implementation.

The Scala code I've written looks like this:

def from(n: Int): Stream[Int] = Stream.cons(n, from(n+1))
val ints = from(1).take(9999998)

def add(a: Stream[Int], b: Long): Long = {
    if (a.isEmpty) b else add(a.tail, b + a.head)
}

val t1 = System.currentTimeMillis()
println(add(ints, 0))
val t2 = System.currentTimeMillis()
println((t2 - t1).asInstanceOf[Float] + " msecs")

Bottom line: the Clojure code runs in about 1.8 seconds on my machine and uses less than 5 MB of heap, while the Scala code runs in about 12 seconds and 512 MB of heap isn't enough (it finishes the computation if I set the heap to 1 GB).

So I'm wondering why Clojure is so much faster and slimmer in this particular case. Do you have a Scala implementation with similar speed and memory usage?

Please refrain from religious remarks; my interest lies primarily in finding out what makes Clojure so fast in this case and whether there's a faster implementation of the algorithm in Scala. Thanks.

+6  A: 

I would suspect it's due to how Clojure handles tail-call optimization. Since the JVM doesn't natively perform this optimization (and both Clojure and Scala run on it), Clojure optimizes tail recursion through the recur special form. From the Clojure site:

In functional languages looping and iteration are replaced/implemented via recursive function calls. Many such languages guarantee that function calls made in tail position do not consume stack space, and thus recursive loops utilize constant space. Since Clojure uses the Java calling conventions, it cannot, and does not, make the same tail call optimization guarantees. Instead, it provides the recur special operator, which does constant-space recursive looping by rebinding and jumping to the nearest enclosing loop or function frame. While not as general as tail-call-optimization, it allows most of the same elegant constructs, and offers the advantage of checking that calls to recur can only happen in a tail position.
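The closest Scala analogue to that tail-position check is the `@tailrec` annotation (Scala 2.8 and later). It does not change what the compiler optimizes; it only makes compilation fail if the annotated call cannot be turned into a loop. A minimal sketch, assuming Scala 2.13; the object and method names are illustrative, and it walks a `List` rather than the question's `Stream`:

```scala
import scala.annotation.tailrec

object TailrecSketch {
  // @tailrec makes the compiler reject this method if the recursive call
  // cannot be rewritten into a loop, similar in spirit to Clojure refusing
  // `recur` outside tail position.
  @tailrec
  def sumList(xs: List[Int], acc: Long): Long =
    if (xs.isEmpty) acc
    else sumList(xs.tail, acc + xs.head)

  def main(args: Array[String]): Unit =
    println(sumList(List(1, 2, 3, 4), 0L)) // prints 10
}
```

Because the call is compiled into a jump, deep inputs (hundreds of thousands of elements) do not grow the JVM stack.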

EDIT: Scala also optimizes tail calls, as long as they're in a certain form. However, as the previous link shows, Scala can only do this for very simple cases:

In fact, this is a feature of the Scala compiler called tail call optimization. It optimizes away the recursive call. This feature works only in simple cases as above, though. If the recursion is indirect, for example, Scala cannot optimize tail calls, because of the limited JVM instruction set.
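The limitation quoted above concerns indirect recursion, such as two functions calling each other. One library-level workaround, swapped in here for illustration and not something scalac does on its own, is trampolining with `scala.util.control.TailCalls` (in the standard library since Scala 2.8). A minimal sketch, assuming Scala 2.13; `isEven` and `isOdd` are illustrative names:

```scala
import scala.util.control.TailCalls._

object MutualRecursionSketch {
  // isEven and isOdd call each other, so the compiler cannot rewrite them
  // into a single loop. TailCalls returns a description of the next call
  // (tailcall) or the final value (done) instead of calling directly.
  def isEven(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(isOdd(n - 1))

  def isOdd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(isEven(n - 1))

  def main(args: Array[String]): Unit =
    // .result runs the trampoline iteratively, so the stack stays flat
    println(isEven(1000000).result) // prints true
}
```

The price is one small heap allocation per step, so this trades stack safety for some speed.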

Without actually compiling and decompiling your code to see what JVM instructions are produced, I suspect it's just not one of those simple cases (as Michael put it, due to having to fetch a.tail on each recursive step) and thus Scala just can't optimize it.

Chris Bunch
I'm using Scala 2.7.5, and I think it's supposed to do TCO in the scenario I'm using.
Li Lo
I guess you better make sure it is, then :-)
Vinko Vrsalovic
Based on the decompiled bytecode below, it looks like TCO is being done:

public long add(scala.Stream, long);
  Code:
   0: aload_1
   1: invokeinterface #103, 1; //InterfaceMethod scala/Seq.isEmpty:()Z
   6: ifeq 11
   9: lload_2
  10: lreturn
  11: aload_1
  12: invokevirtual #106; //Method scala/Stream.tail:()Lscala/Stream;
  15: lload_2
  16: aload_1
  17: invokevirtual #110; //Method scala/Stream.head
  20: invokestatic #114; //Method scala/runtime/BoxesRunTime.unboxToInt
  23: i2l
  24: ladd
  25: lstore_2
  26: astore_1
  27: goto 0
Li Lo
+6  A: 

I profiled this example of yours, and it seems that the class Stream (well... some anonymous function related to it; I forgot its name because VisualVM crashed on me) occupies most of the heap. It's related to the fact that Streams in Scala do leak memory; see Scala Trac #692. Fixes are due in Scala 2.8. EDIT: Daniel's comment rightly pointed out that it is not related to this bug. It's because "val ints points to the Stream head, the garbage collector can't collect anything" [Daniel]. I found the comments in this bug report worth reading, though, in relation to this question.

In your add function, you are holding a reference to a.head, therefore the garbage collector cannot collect the head, leading to a stream that holds 9999998 elements in the end, which cannot be GC-ed.
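Following Daniel's correction, the practical fix is to avoid keeping a `val` bound to the head of the stream while it is being consumed. A minimal sketch, assuming Scala 2.13, where `LazyList` replaces the deprecated `Stream`; turning `ints` into a `def` is an assumption of this sketch, not code from the question. Each call builds a fresh stream that no field references, and the tail-recursive `add` rebinds its parameter on every step. Whether cells are actually reclaimed mid-loop is up to the JVM; this is an illustration, not a measured result.

```scala
import scala.annotation.tailrec

object NoHeadRetention {
  // A def, not a val: no long-lived field keeps a pointer to the first cell.
  def ints(n: Int): LazyList[Int] = LazyList.from(1).take(n)

  // Compiled to a loop; `a` is reassigned each iteration, so cells already
  // visited become unreachable (assuming nothing else holds them).
  @tailrec
  def add(a: LazyList[Int], b: Long): Long =
    if (a.isEmpty) b else add(a.tail, b + a.head)

  def main(args: Array[String]): Unit =
    println(add(ints(9999998), 0L))
}
```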

[A little interlude]

You may also keep copies of the tails you keep passing; I am not sure how Streams deal with that. If you used a list, tails would not be copied. For example:

val xs =  List(1,2,3)
val ys = 1 :: xs
val zs = 2 :: xs

Here, both ys and zs 'share' the same tail, at least heap-wise (ys.tail eq zs.tail, i.e. a reference-equality check, yields true).

[This little interlude was to make the point that passing a lot of tails is not really a bad thing in principle :); they are not copied, at least for lists]
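The sharing described in the interlude can be checked directly with reference equality. A small runnable version of the three-line example, assuming Scala 2.13; the wrapping object name is illustrative:

```scala
object SharedTails {
  val xs = List(1, 2, 3)
  val ys = 1 :: xs // allocates one new cell whose tail is xs itself
  val zs = 2 :: xs // another new cell, pointing at the same xs

  def main(args: Array[String]): Unit = {
    println(ys.tail eq zs.tail) // prints true: the very same object
    println(ys.tail eq xs)      // prints true: nothing was copied
  }
}
```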

An alternative implementation (which runs quite fast, and which I think is clearer than the purely functional one) is to use an imperative approach:

def addTo(n: Int, init: Int): Long = {
  var sum = init.toLong
  for(i <- 1 to n) sum += i
  sum
}

scala> addTo(9999998, 0)

In Scala it is quite OK to use an imperative approach, for performance and clarity (at least to me, this version of add is clearer in its intent). For even more conciseness, you could write

(1 to 9999998).reduceLeft(_ + _)

(runs a bit slower, but still reasonably fast and doesn't blow up the memory; note, however, that it accumulates in Int, so for this range the result overflows)

I believe that Clojure might be faster because it is fully functional, so more optimisations are possible than with Scala (which blends functional, OO and imperative styles). I am not very familiar with Clojure, though.

Hope this helps :)


Flaviu Cipcigan
It's not related to the bug. Because `val ints` points to the `Stream` head, the garbage collector can't collect anything.
Daniel
+20  A: 

First, Scala only optimises tail calls if you invoke it with -optimise. Edit: It seems Scala will always optimise tail-call recursions if it can, even without -optimise.

Second, Stream and Range are two very different things. A Range has a beginning and an end, and its projection has just a counter and the end. A Stream is a list that is computed on demand. Since you are adding up all of ints, you'll compute, and therefore allocate, the whole Stream.
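To make the Range side of that contrast concrete: a Range is described by its bounds and step, so its length and any element can be computed arithmetically without materializing the elements. A short sketch, assuming Scala 2.13; the object name is illustrative:

```scala
object RangeShape {
  // Only start, end and step are stored; no per-element allocation happens here.
  val r: Range = 1 to 9999999

  def main(args: Array[String]): Unit = {
    println(r.start)    // 1
    println(r.end)      // 9999999
    println(r.length)   // 9999999, computed from the bounds
    println(r(4999999)) // 5000000, i.e. start + index * step, no traversal
  }
}
```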

A closer code would be:

import scala.annotation.tailrec

def add(r: Range): Long = {
  // Fails to compile if f cannot be turned into a loop
  @tailrec
  def f(i: Iterator[Int], acc: Long): Long =
    if (i.hasNext) f(i, acc + i.next()) else acc

  f(r.iterator, 0)
}

def time(f: => Unit): Unit = {
  val t1 = System.currentTimeMillis()
  f
  val t2 = System.currentTimeMillis()
  println((t2 - t1).toFloat + " msecs")
}

Normal run:

scala> time(println(add(1 to 9999999)))
49999995000000
563.0 msecs

On Scala 2.7 you need "elements" instead of "iterator", and there's no "tailrec" annotation (that annotation only makes the compiler complain if a definition can't be optimized with tail recursion), so you'll need to strip "@tailrec" as well as "import scala.annotation.tailrec" from the code.

Also, some considerations on alternate implementations. The simplest:

scala> time(println(1 to 9999999 reduceLeft (_+_)))
-2014260032
640.0 msecs

On average, over multiple runs here, it is slower. It's also incorrect, because it accumulates in Int and overflows. A correct one:

scala> time(println((1 to 9999999 foldLeft 0L)(_+_)))
49999995000000
797.0 msecs

That's slower still, running here. I honestly wouldn't have expected it to run slower, but each iteration calls the function being passed. Once you consider that, it's a pretty good time compared to the recursive version.
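Why the reduceLeft version printed -2014260032: the true sum of 1 to 9999999 is 9999999 × 10000000 / 2 = 49999995000000, far above Int.MaxValue (2147483647), and JVM Int addition wraps modulo 2^32. So the Int result equals the low 32 bits of the exact Long sum. A small check, assuming Scala 2.13; the object name is illustrative:

```scala
object OverflowCheck {
  // Exact sum of 1..9999999 via n * (n + 1) / 2, computed in Long.
  val exact: Long = 9999999L * 10000000L / 2

  // Keeping only the low 32 bits is what an Int accumulator ends up with.
  val wrapped: Int = exact.toInt

  def main(args: Array[String]): Unit = {
    println(exact)                                       // 49999995000000
    println(wrapped)                                     // -2014260032
    println((1 to 9999999).reduceLeft(_ + _) == wrapped) // true
  }
}
```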

Daniel
Granted, this accounts for the increased memory usage. What about the increased computation time?
Li Lo
The increased computation time is spent allocating memory, and fruitlessly trying to garbage collect it.
Daniel
If you used a pool of recycled objects, would it speed up a lot? The JVM handles short-lived heap objects with an efficiency closer to that of a stack, so it would surprise me if GC were really taking a lot of time.
Bill K
The problem with Stream is that it's allocating the objects but not freeing them. Range will use short-lived heap objects, which runs much faster.
Jorge Ortiz
@Bill K: beware of such claims. Java's handling of short lived heap objects is nowhere close to stack efficiency, it's just better than long-lived objects. Stack deallocation is O(1), while short lived heap is O(n), where n is the number of objects. At any rate, yes, it would have better performance than the millions of unrecyclable objects as the original solution resulted in, but it would still lose to the tail-recursive range-iterator solution.
Daniel
@Daniel Short-lived heap deallocation (in Java) is O(1), which is why I said it's pretty much as fast as a stack. Look for a white paper on Java allocation and a structure called "Eden".
Bill K
@Bill K: I'm familiar with garbage collectors. The "constant time" garbage collectors all do a fixed amount of work, which is proportional to the amount of memory they'll process. In Java's case, the linear factor comes from the effort of identifying which Eden objects must be preserved, and copying them. Even short-lived objects might be live when Eden gets full. So we have time linear in the size of roots and live objects. Things like card marking optimize the GC, but do not change its linearity. Meanwhile, stack deallocation is truly O(1): `SP = BP; BP = POP SP`.
Daniel
+17  A: 

Clojure's range does not memoize; Scala's Stream does. They are totally different data structures with totally different results. Scala does have a non-memoizing Range structure, but it's currently kind of awkward to work with in this simple recursive way. Here's my take on the whole thing.
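The memoizing/non-memoizing difference can be observed by counting how often the element-producing function runs. A sketch assuming Scala 2.13, with `LazyList` standing in for `Stream` and `Iterator` used as the non-memoizing counterpart (a swapped-in comparison, not Clojure's range itself); the names are illustrative:

```scala
object MemoizationSketch {
  // Traverse the first three elements twice and report how many times
  // the mapping function ran.
  def lazyListEvaluations(): Int = {
    var count = 0
    // LazyList caches each element once computed, so its cells stay in
    // memory for as long as something references the head.
    val xs = LazyList.from(1).map { i => count += 1; i }
    xs.take(3).toList
    xs.take(3).toList // second pass reuses the cached cells
    count
  }

  def iteratorEvaluations(): Int = {
    var count = 0
    // An Iterator keeps no cache; each traversal needs a fresh iterator.
    def fresh: Iterator[Int] = Iterator.from(1).map { i => count += 1; i }
    fresh.take(3).toList
    fresh.take(3).toList // recomputes all three elements
    count
  }

  def main(args: Array[String]): Unit = {
    println(lazyListEvaluations()) // 3
    println(iteratorEvaluations()) // 6
  }
}
```

The cache is what makes a retained Stream head expensive: every computed cell stays reachable.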

Using Clojure 1.0 on an older, slow box, I get 3.6 seconds:

user=> (defn sum [coll acc] (if (empty? coll) acc (recur (rest coll) (+ (first coll) acc))))
#'user/sum
user=> (time (sum (range 1 9999999) 0))
"Elapsed time: 3651.751139 msecs"
49999985000001

A literal translation to Scala requires some supporting code first, starting with a timing function:

def time[T](x : => T) =  {
  val start = System.nanoTime : Double
  val result = x
  val duration = (System.nanoTime : Double) - start
  println("Elapsed time " + duration / 1000000.0 + " msecs")
  result
}

It's good to make sure that it's right:

scala> time (Thread sleep 1000)
Elapsed time 1000.277967 msecs

Now we need an unmemoized Range with semantics similar to Clojure's:

case class MyRange(start : Int, end : Int) {
  def isEmpty = start >= end
  // sys.error is Scala 2.9+; Scala 2.7/2.8 used Predef's error(...)
  def first = if (!isEmpty) start else sys.error("empty range")
  def rest = new MyRange(start + 1, end)
}

From that, "add" follows directly:

def add(a: MyRange, b: Long): Long = {
    if (a.isEmpty) b else add(a.rest, b + a.first)
}

And it runs much faster than the Clojure version on the same box:

scala> time(add(MyRange(1, 9999999), 0))
Elapsed time 252.526784 msecs
res1: Long = 49999985000001

Using Scala's standard library Range, you can do a fold. It's not as fast as simple primitive recursion, but it's less code and still faster than the Clojure recursive version (at least on my box):

scala> time((1 until 9999999 foldLeft 0L)(_ + _))
Elapsed time 1995.566127 msecs
res2: Long = 49999985000001

Contrast that with a fold over a memoized Stream:

scala> time((Stream from 1 take 9999998 foldLeft 0L)(_ + _))
Elapsed time 3879.991318 msecs
res3: Long = 49999985000001
James Iry