I am fairly new to Scala and am still trying to develop a feel for which approaches are efficient and which might contain hidden performance costs.

If I define a (non-tail) recursive function that contains an inner function, are multiple copies of the inner function's function object instantiated for each recursive call?

For example in the following:

def sumDoubles(n: Int): Int = {
    def dbl(a: Int) = 2 * a;
    if(n > 0)
        dbl(n) + sumDoubles(n - 1)
    else
        0               
}

...how many copies of the dbl object exist on the stack for a call to sumDoubles(15)?

+1  A: 

In this particular case the compiler might optimize this away, but consider the following:

def func(n: Int): Int = {
    // nTimes closes over n, the parameter of the enclosing function
    def nTimes(a: Int) = n * a
    if (n <= 1)
        1
    else
        nTimes(func(n - 1))
}

In most cases, the inner function needs to access variables of its outer function, so it has to be re-instantiated in each call.

Dario
In your code no objects need to be "instantiated." The Scala compiler would lift it into something like:

private[this] def nTimes(a: Int, n: Int) = n * a
def func(n: Int): Int = { if (n <= 1) 1 else nTimes(func(n - 1), n) }
James Iry
+6  A: 

If you're concerned about Scala performance, it's good to be familiar with 1) how Java bytecode performs, and 2) how Scala translates to Java bytecode. If you're comfortable looking at raw bytecode or decompiling it, I suggest you do so for areas where you might be concerned about performance. You'll pretty quickly get a feel for how Scala translates to bytecode. If not, you can use the scalac -print flag, which prints a "fully desugared" version of your Scala code. It's basically a version of your code as close to Java as possible, right before it gets turned into bytecode.

I've made a file Performance.scala with the code you posted:

jorge-ortizs-macbook-pro:sandbox jeortiz$ cat Performance.scala 
object Performance {
  def sumDoubles(n: Int): Int = {
      def dbl(a: Int) = 2 * a;
      if(n > 0)
          dbl(n) + sumDoubles(n - 1)
      else
          0               
  }
}

When I run scalac -print on it I can see the desugared Scala:

jorge-ortizs-macbook-pro:sandbox jeortiz$ scalac Performance.scala -print
[[syntax trees at end of cleanup]]// Scala source: Performance.scala
package <empty> {
  final class Performance extends java.lang.Object with ScalaObject {
    @remote def $tag(): Int = scala.ScalaObject$class.$tag(Performance.this);
    def sumDoubles(n: Int): Int = {
      if (n.>(0))
        Performance.this.dbl$1(n).+(Performance.this.sumDoubles(n.-(1)))
      else
        0
    };
    final private[this] def dbl$1(a: Int): Int = 2.*(a);
    def this(): object Performance = {
      Performance.super.this();
      ()
    }
  }
}

Then you'll notice that dbl has been "lifted" into a final private[this] method of the same object that sumDoubles belongs to. Both dbl and sumDoubles are actually methods on their containing object, not functions. Calling them non-tail-recursively might grow your stack, but it won't instantiate objects on your heap.
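If you'd rather check this from inside a running program than from compiler output, a minimal sketch using plain Java reflection is below. The object name ListLiftedMethods and its helper are made up for illustration, and the exact name of the lifted method (dbl$1 in the output above) is a compiler implementation detail that may differ between Scala versions.

```scala
object Performance {
  def sumDoubles(n: Int): Int = {
    def dbl(a: Int) = 2 * a
    if (n > 0) dbl(n) + sumDoubles(n - 1) else 0
  }
}

// Hypothetical helper, not part of the original answer.
object ListLiftedMethods {
  // Names of all methods declared on the class that backs `Performance`.
  // The lifted local method is expected to appear here next to sumDoubles,
  // under a compiler-generated name.
  def declaredMethodNames: Seq[String] =
    Performance.getClass.getDeclaredMethods.map(_.getName).toSeq.sorted

  def main(args: Array[String]): Unit =
    declaredMethodNames.foreach(println)
}
```

If the inner function were being turned into an object, you would expect to see an extra class file for it rather than an extra method in this list.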

Jorge Ortiz
+5  A: 

At the bytecode level

def sumDoubles(n: Int): Int = {
  def dbl(a: Int) = 2 * a;
  if(n > 0)
    dbl(n) + sumDoubles(n - 1)
  else
    0               
}

is exactly the same as

private[this] def dbl(a: Int) = 2 * a;

def sumDoubles(n: Int): Int = {
  if(n > 0)
    dbl(n) + sumDoubles(n - 1)
  else
    0               
}

But don't take my word for it

~/test$ javap -private -c Foo
Compiled from "test.scala"
public class Foo extends java.lang.Object implements scala.ScalaObject{
public Foo();
  Code:
   0:   aload_0
   1:   invokespecial #10; //Method java/lang/Object."<init>":()V
   4:   return

private final int dbl$1(int);
  Code:
   0:   iconst_2
   1:   iload_1
   2:   imul
   3:   ireturn

public int sumDoubles(int);
  Code:
   0:   iload_1
   1:   iconst_0
   2:   if_icmple 21
   5:   aload_0
   6:   iload_1
   7:   invokespecial #22; //Method dbl$1:(I)I
   10:  aload_0
   11:  iload_1
   12:  iconst_1
   13:  isub
   14:  invokevirtual #24; //Method sumDoubles:(I)I
   17:  iadd
   18:  goto 22
   21:  iconst_0
   22:  ireturn

}

If an inner function captures an immutable variable then there's a translation. This code

def foo(n: Int): Int = {
  def dbl(a: Int) = a * n;
  if(n > 0)
    dbl(n) + foo(n - 1)
  else
    0
}

Gets translated into

private[this] def dbl(a: Int, n: Int) = a * n;

def foo(n: Int): Int = {
  if(n > 0)
    dbl(n, n) + foo(n - 1)
  else
    0               
}

Again, the tools are there for you

~/test$ javap -private -c Foo
Compiled from "test.scala"
public class Foo extends java.lang.Object implements scala.ScalaObject{
public Foo();
  Code:
   0:   aload_0
   1:   invokespecial #10; //Method java/lang/Object."<init>":()V
   4:   return

private final int dbl$1(int, int);
  Code:
   0:   iload_1
   1:   iload_2
   2:   imul
   3:   ireturn

public int foo(int);
  Code:
   0:   iload_1
   1:   iconst_0
   2:   if_icmple 22
   5:   aload_0
   6:   iload_1
   7:   iload_1
   8:   invokespecial #23; //Method dbl$1:(II)I
   11:  aload_0
   12:  iload_1
   13:  iconst_1
   14:  isub
   15:  invokevirtual #25; //Method foo:(I)I
   18:  iadd
   19:  goto 23
   22:  iconst_0
   23:  ireturn

}

If a mutable variable is captured, then it has to be boxed, which can be more expensive.

def bar(_n : Int) : Int = {
   var n = _n
   def subtract() = n = n - 1

   if (n > 0) {
      subtract()
      n
   }
   else
      0
}

Gets translated into something like

private[this] def subtract(n : IntRef) = n.elem = n.elem - 1

def bar(_n : Int) : Int = {
   val n = new IntRef(_n)
   if (n.elem > 0) {
      subtract(n)
      n.elem
   }
   else
      0
}

~/test$ javap -private -c Foo
Compiled from "test.scala"
public class Foo extends java.lang.Object implements scala.ScalaObject{
public Foo();
  Code:
   0:   aload_0
   1:   invokespecial #10; //Method java/lang/Object."<init>":()V
   4:   return

private final void subtract$1(scala.runtime.IntRef);
  Code:
   0:   aload_1
   1:   aload_1
   2:   getfield #18; //Field scala/runtime/IntRef.elem:I
   5:   iconst_1
   6:   isub
   7:   putfield #18; //Field scala/runtime/IntRef.elem:I
   10:  return

public int bar(int);
  Code:
   0:   new #14; //class scala/runtime/IntRef
   3:   dup
   4:   iload_1
   5:   invokespecial #23; //Method scala/runtime/IntRef."<init>":(I)V
   8:   astore_2
   9:   aload_2
   10:  getfield #18; //Field scala/runtime/IntRef.elem:I
   13:  iconst_0
   14:  if_icmple 29
   17:  aload_0
   18:  aload_2
   19:  invokespecial #27; //Method subtract$1:(Lscala/runtime/IntRef;)V
   22:  aload_2
   23:  getfield #18; //Field scala/runtime/IntRef.elem:I
   26:  goto 30
   29:  iconst_0
   30:  ireturn

}
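The IntRef allocation only appears because the inner function mutates a captured var. As a sketch of one way around it (the names AvoidCapturedVar, barCaptured and barPure are invented here, and this is a suggestion rather than something the compiler does for you), the inner function can take the value as a parameter and return the new value instead:

```scala
object AvoidCapturedVar {
  // Same shape as bar above: `n` is a var captured and mutated by the
  // inner method, so the compiler is expected to box it in a ref object.
  def barCaptured(start: Int): Int = {
    var n = start
    def subtract(): Unit = { n = n - 1 }
    if (n > 0) { subtract(); n } else 0
  }

  // Alternative: nothing mutable is captured. The inner method receives
  // the value and returns the result, so it can be lifted to a plain
  // method that works on primitive ints.
  def barPure(start: Int): Int = {
    def subtract(x: Int): Int = x - 1
    if (start > 0) subtract(start) else 0
  }
}
```

Both versions return the same results; running javap on each is the way to confirm whether the second one really avoids the allocation on your compiler version.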

Edit: adding first class functions

To get object allocations you need to use functions in a more first class manner

def sumWithFunction(n : Int, f : Int => Int) : Int = {
  if(n > 0)
    f(n) + sumWithFunction(n - 1, f)
  else
    0               
}  

def sumDoubles(n: Int) : Int = {
  def dbl(a: Int) = 2 * a
  sumWithFunction(n, dbl)
}

That desugars into something a bit like

def sumWithFunction(n : Int, f : Int => Int) : Int = {
  if(n > 0)
    f(n) + sumWithFunction(n - 1, f)
  else
    0               
}  

private[this] def dbl(a: Int) = 2 * a

def sumDoubles(n: Int) : Int = {
  sumWithFunction(n, new Function1[Int,Int] {
    def apply(x : Int) = dbl(x)
  })
}

Here's the byte code

~/test$ javap -private -c Foo
Compiled from "test.scala"
public class Foo extends java.lang.Object implements scala.ScalaObject{
public Foo();
  Code:
   0:   aload_0
   1:   invokespecial #10; //Method java/lang/Object."<init>":()V
   4:   return

public final int dbl$1(int);
  Code:
   0:   iconst_2
   1:   iload_1
   2:   imul
   3:   ireturn

public int sumDoubles(int);
  Code:
   0:   aload_0
   1:   iload_1
   2:   new #20; //class Foo$$anonfun$sumDoubles$1
   5:   dup
   6:   aload_0
   7:   invokespecial #23; //Method Foo$$anonfun$sumDoubles$1."<init>":(LFoo;)V
   10:  invokevirtual #29; //Method sumWithFunction:(ILscala/Function1;)I
   13:  ireturn

public int sumWithFunction(int, scala.Function1);
  Code:
   0:   iload_1
   1:   iconst_0
   2:   if_icmple 30
   5:   aload_2
   6:   iload_1
   7:   invokestatic #36; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
   10:  invokeinterface #42,  2; //InterfaceMethod scala/Function1.apply:(Ljava/lang/Object;)Ljava/lang/Object;
   15:  invokestatic #46; //Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I
   18:  aload_0
   19:  iload_1
   20:  iconst_1
   21:  isub
   22:  aload_2
   23:  invokevirtual #29; //Method sumWithFunction:(ILscala/Function1;)I
   26:  iadd
   27:  goto 31
   30:  iconst_0
   31:  ireturn

}

~/test$ javap -private -c "Foo\$\$anonfun\$sumDoubles\$1"
Compiled from "test.scala"
public final class Foo$$anonfun$sumDoubles$1 extends java.lang.Object implements scala.Function1,scala.ScalaObject,java.io.Serializable{
private final Foo $outer;

public Foo$$anonfun$sumDoubles$1(Foo);
  Code:
   0:   aload_1
   1:   ifnonnull 12
   4:   new #10; //class java/lang/NullPointerException
   7:   dup
   8:   invokespecial #13; //Method java/lang/NullPointerException."<init>":()V
   11:  athrow
   12:  aload_0
   13:  aload_1
   14:  putfield #17; //Field $outer:LFoo;
   17:  aload_0
   18:  invokespecial #20; //Method java/lang/Object."<init>":()V
   21:  aload_0
   22:  invokestatic #26; //Method scala/Function1$class.$init$:(Lscala/Function1;)V
   25:  return

public final java.lang.Object apply(java.lang.Object);
  Code:
   0:   aload_0
   1:   getfield #17; //Field $outer:LFoo;
   4:   astore_2
   5:   aload_0
   6:   aload_1
   7:   invokestatic #37; //Method scala/runtime/BoxesRunTime.unboxToInt:(Ljava/lang/Object;)I
   10:  invokevirtual #40; //Method apply:(I)I
   13:  invokestatic #44; //Method scala/runtime/BoxesRunTime.boxToInteger:(I)Ljava/lang/Integer;
   16:  areturn

public final int apply(int);
  Code:
   0:   aload_0
   1:   getfield #17; //Field $outer:LFoo;
   4:   astore_2
   5:   aload_0
   6:   getfield #17; //Field $outer:LFoo;
   9:   iload_1
   10:  invokevirtual #51; //Method Foo.dbl$1:(I)I
   13:  ireturn

public scala.Function1 andThen(scala.Function1);
  Code:
   0:   aload_0
   1:   aload_1
   2:   invokestatic #56; //Method scala/Function1$class.andThen:(Lscala/Function1;Lscala/Function1;)Lscala/Function1;
   5:   areturn

public scala.Function1 compose(scala.Function1);
  Code:
   0:   aload_0
   1:   aload_1
   2:   invokestatic #60; //Method scala/Function1$class.compose:(Lscala/Function1;Lscala/Function1;)Lscala/Function1;
   5:   areturn

public java.lang.String toString();
  Code:
   0:   aload_0
   1:   invokestatic #65; //Method scala/Function1$class.toString:(Lscala/Function1;)Ljava/lang/String;
   4:   areturn

}

The anonymous class gets a lot of code copied in from the Function1 trait. That does have a cost in terms of class loading overhead, but doesn't affect the cost of allocating the object or executing the code. The other cost is the boxing and unboxing of the integer. There's hope that that cost will go away with 2.8's @specialized annotation.
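If the boxing matters in a hot path and you can't rely on specialization, one workaround is to pass a non-generic function type whose apply works directly on Int. The sketch below is only an illustration: the trait IntToInt and the object PrimitiveFunction are hypothetical names, not part of the Scala library.

```scala
object PrimitiveFunction {
  // Hypothetical non-generic function type. Because apply is declared
  // on Int rather than on a type parameter, callers have no generic
  // apply(Object) signature to go through.
  trait IntToInt {
    def apply(x: Int): Int
  }

  def sumWithFunction(n: Int, f: IntToInt): Int =
    if (n > 0) f(n) + sumWithFunction(n - 1, f) else 0

  def sumDoubles(n: Int): Int = {
    def dbl(a: Int) = 2 * a
    // Still one object allocation per call to sumDoubles (the anonymous
    // IntToInt instance), but not one per recursive step.
    sumWithFunction(n, new IntToInt {
      def apply(x: Int): Int = dbl(x)
    })
  }
}
```

The trade-off is that IntToInt does not interoperate with code expecting Int => Int, so this is probably only worth it where profiling shows the boxing to be a real cost.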

James Iry
Is your aim to turn into the Jon Skeet of Scala? :)
skaffman
You're right! I should have reached for my Java toolkit earlier. To be honest I really expected more of a dramatic transformation of the code on the part of scalac. For some reason I assumed that there would be more object wrappers for everything, especially for the functions. The resulting Java bytecode is actually quite readable... not beautiful... but readable. Thank you for taking the time to consider all three cases (enclosing immutable variables, mutable variables and no variables). I should probably have mentioned that in the original question.
DuncanACoulter
After first class function update: Ah OK, class loading overhead I can live with. Can't say that the thought of an annotation based solution fills me with excitement, but that's probably me talking through my ignorance.
DuncanACoulter