Each time a function is called, if its result for a given set of argument values is not yet memoized, I'd like to put the result into an in-memory table. One column is meant to store the result, the others to store the argument values.

How do I best implement this? Arguments are of diverse types, including some enums.

In C# I'd generally use DataTable. Is there an equivalent in Scala?

+5  A: 

You could use a mutable.Map[TupleN[A1, A2, ..., AN], R], or, if memory is a concern, a WeakHashMap[1]. The definitions below (built on the memoization code from michid's blog) let you easily memoize functions with multiple arguments. For example:

import Memoize._

def reallySlowFn(i: Int, s: String): Int = {
   Thread.sleep(3000)
   i + s.length
}

val memoizedSlowFn = memoize(reallySlowFn _)
memoizedSlowFn(1, "abc") // returns 4 after about 3 seconds
memoizedSlowFn(1, "abc") // returns 4 almost instantly

Definitions:

/**
 * A memoized unary function.
 *
 * @param f A unary function to memoize
 * @tparam T the argument type
 * @tparam R the return type
 */
class Memoize1[-T, +R](f: T => R) extends (T => R) {
   import scala.collection.mutable
   // map that stores (argument, result) pairs
   private[this] val vals = mutable.Map.empty[T, R]

   // Given an argument x, 
   //   If vals contains x return vals(x).
   //   Otherwise, update vals so that vals(x) == f(x) and return f(x).
   def apply(x: T): R = vals getOrElseUpdate (x, f(x))
}

object Memoize {
   /**
    * Memoize a unary (single-argument) function.
    *
    * @param f the unary function to memoize
    */
   def memoize[T, R](f: T => R): (T => R) = new Memoize1(f)

   /**
    * Memoize a binary (two-argument) function.
    * 
    * @param f the binary function to memoize
    * 
    * This works by turning a function that takes two arguments of type
    * T1 and T2 into a function that takes a single argument of type 
    * (T1, T2), memoizing that "tupled" function, then "untupling" the
    * memoized function.
    */
   def memoize[T1, T2, R](f: (T1, T2) => R): ((T1, T2) => R) = 
      Function.untupled(memoize(f.tupled))

   /**
    * Memoize a ternary (three-argument) function.
    *
    * @param f the ternary function to memoize
    */
   def memoize[T1, T2, T3, R](f: (T1, T2, T3) => R): ((T1, T2, T3) => R) =
      Function.untupled(memoize(f.tupled))

   // ... more memoize methods for higher-arity functions ...

   /**
    * Fixed-point combinator (for memoizing recursive functions).
    */
   def Y[T, R](f: (T => R) => T => R): (T => R) = {
      lazy val yf: (T => R) = memoize(f(yf)(_))
      yf
   }
}

The fixed-point combinator (Memoize.Y) makes it possible to memoize recursive functions, so that the recursive calls, not just the outermost one, are served from the cache:

val fib: BigInt => BigInt = {                         
   def fibRec(f: BigInt => BigInt)(n: BigInt): BigInt = {
      if (n == 0) 1 
      else if (n == 1) 1 
      else (f(n-1) + f(n-2))                           
   }                                                     
   Memoize.Y(fibRec)
}
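To illustrate what Y buys over calling memoize directly on a recursive function, here is a self-contained sketch. memo1 and fix are hypothetical stand-ins that mirror Memoize1/memoize and Memoize.Y above, and the bodyRuns counter is added purely for illustration. Memoizing an ordinary recursive function caches only the top-level call, because its inner calls go straight to the unmemoized method. With fix, every recursive call goes through the memoized function.

```scala
import scala.collection.mutable

// Hypothetical minimal memoizer, equivalent in spirit to Memoize1 / memoize above
def memo1[T, R](f: T => R): T => R = {
  val cache = mutable.Map.empty[T, R]
  (x: T) => cache.getOrElseUpdate(x, f(x))
}

// Hypothetical fixed-point combinator mirroring Memoize.Y: recursive calls go through yf
def fix[T, R](f: (T => R) => T => R): T => R = {
  lazy val yf: T => R = memo1((x: T) => f(yf)(x))
  yf
}

var bodyRuns = 0 // counts how often a fib body actually executes

// Ordinary recursion: inner calls go straight to fibPlain, bypassing any cache
def fibPlain(n: Int): BigInt = {
  bodyRuns += 1
  if (n < 2) BigInt(1) else fibPlain(n - 1) + fibPlain(n - 2)
}

// Open recursion: inner calls go through `self`, which fix() makes memoized
def fibOpen(self: Int => BigInt)(n: Int): BigInt = {
  bodyRuns += 1
  if (n < 2) BigInt(1) else self(n - 1) + self(n - 2)
}

val naive = memo1((n: Int) => fibPlain(n))
bodyRuns = 0
naive(20)
val naiveRuns = bodyRuns // 21891: only the outer call was cached

val viaFix = fix[Int, BigInt](fibOpen)
bodyRuns = 0
viaFix(20)
val fixRuns = bodyRuns // 21: once per n in 0..20
```

Both versions return the same value, but the naive one still does exponential work on the first call. The fix version does one computation per distinct argument.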

[1] WeakHashMap does not work well as a cache. See http://www.codeinstructions.com/2008/09/weakhashmap-is-not-cache-understanding.html and this related question.
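The question mentions arguments of diverse types, including enums. Here is a minimal sketch of the plain mutable.Map approach from the start of this answer, without the Memoize wrapper. The Currency enumeration and the price function are hypothetical. The argument tuple is the key (the "argument columns") and the map value is the "result column". Ints, Strings and Scala Enumeration values all have value-based equals/hashCode, so a tuple of them works as a key.

```scala
import scala.collection.mutable

// Hypothetical enum standing in for the question's enum-typed arguments
object Currency extends Enumeration {
  val USD, EUR = Value
}

// The "table": the key tuple holds the argument columns, the value is the result column
val table = mutable.Map.empty[(Currency.Value, Int, String), Double]

var computed = 0 // how many times the (hypothetical) expensive body ran

// Hypothetical expensive function with mixed argument types
def price(c: Currency.Value, qty: Int, sku: String): Double =
  table.getOrElseUpdate((c, qty, sku), {
    computed += 1
    qty * sku.length * (if (c == Currency.USD) 1.0 else 1.1)
  })

price(Currency.USD, 2, "abc") // computed and stored
price(Currency.USD, 2, "abc") // same argument tuple: read from the table
price(Currency.EUR, 2, "abc") // different enum value: a new row
```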

Aaron Novstrup
+5  A: 

The version suggested by anovstrup using a mutable Map is basically the same as in C#, and therefore easy to use.

But you can also use a more functional style if you want. It uses immutable maps, which act as a kind of accumulator. Using tuples (instead of Int as in the example) as keys works exactly as in the mutable case.

def fib(n:Int) = fibM(n, Map(0->1, 1->1))._1

def fibM(n:Int, m:Map[Int,Int]):(Int,Map[Int,Int]) = m.get(n) match {
   case Some(f) => (f, m)
   case None => val (f_1,m1) = fibM(n-1,m)
                val (f_2,m2) = fibM(n-2,m1)
                val f = f_1+f_2
                (f, m2 + (n -> f))   
}

Of course this is a little bit more complicated, but a useful technique to know (note that the code above aims for clarity, not for speed).
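Here is a sketch of the same technique with tuple keys. binom is a hypothetical binomial-coefficient function whose memo table is an immutable Map[(Int, Int), BigInt], threaded through the recursion just as fibM threads its Map[Int, Int]. It assumes 0 <= k <= n.

```scala
// Memo table keyed by the argument tuple (n, k)
type Memo = Map[(Int, Int), BigInt]

// Returns (C(n, k), updated memo); recursion threads the map like fibM does
def binomM(n: Int, k: Int, m: Memo): (BigInt, Memo) =
  if (k == 0 || k == n) (BigInt(1), m)
  else m.get((n, k)) match {
    case Some(v) => (v, m)
    case None =>
      val (a, m1) = binomM(n - 1, k - 1, m)
      val (b, m2) = binomM(n - 1, k, m1)
      val v = a + b
      (v, m2 + ((n, k) -> v))
  }

def binom(n: Int, k: Int): BigInt = {
  require(0 <= k && k <= n, "assumes 0 <= k <= n")
  binomM(n, k, Map.empty)._1
}
```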

Landei
+2  A: 

Being a newcomer to this subject, I couldn't fully understand any of the examples given (but thank you anyway). Respectfully, I'd like to present my own solution in case someone at the same level comes here with the same problem. I think my code should be clear to anybody with just very basic Scala knowledge.



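// DateTime is whatever date/time type you use (e.g. Joda-Time's org.joda.time.DateTime); import it accordingly

// The memo table: (argument tuple) -> result
val Memo = new scala.collection.mutable.HashMap[(DateTime, Int), Double]

def MyFunction(dt : DateTime, param : Int) : Double = {
  val argsTuple = (dt, param)
  // Return the stored result if present, otherwise compute and store it
  if (Memo.contains(argsTuple)) Memo(argsTuple) else Memoize(dt, param, MyRawFunction(dt, param))
}

def MyRawFunction(dt : DateTime, param : Int) : Double = {
  1.0 // A heavy calculation/querying here
}

// Stores the result under its argument tuple and returns it
def Memoize(dt : DateTime, param : Int, result : Double) : Double = {
  Memo += ((dt, param) -> result)
  result
}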


Works perfectly. I'd appreciate critique if I've missed something.
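For comparison, here is a sketch of the same cache using getOrElseUpdate, which combines the contains check, the lookup and the insertion into one call. java.time.LocalDateTime is an assumption standing in for the DateTime type above. The body of myRawFunction is a placeholder, and the heavyRuns counter is added only for illustration.

```scala
import scala.collection.mutable
import java.time.LocalDateTime // assumption: stands in for the answer's DateTime type

// Memo table keyed by the argument tuple
val memo = mutable.HashMap.empty[(LocalDateTime, Int), Double]

var heavyRuns = 0 // counts real computations, for illustration only

// Hypothetical placeholder for the heavy calculation/query
def myRawFunction(dt: LocalDateTime, param: Int): Double = {
  heavyRuns += 1
  param * 2.0
}

// contains + apply + += collapsed into a single lookup-or-insert
def myFunction(dt: LocalDateTime, param: Int): Double =
  memo.getOrElseUpdate((dt, param), myRawFunction(dt, param))

val t = LocalDateTime.of(2010, 6, 1, 12, 0)
myFunction(t, 3) // computed
myFunction(t, 3) // served from memo
```

The second argument of getOrElseUpdate is by-name, so myRawFunction only runs on a cache miss.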

Ivan
@Ivan I added some comments to my solution that will hopefully clarify it for you. The advantage of the approach I've outlined is that it allows you to memoize *any* function (ok, there are some caveats, but *many functions*). Sort of like the memoize keyword you posted about in a related question.
Aaron Novstrup
The one aspect that probably remains mystifying is the fixed-point combinator -- for that I encourage you to read michid's blog, drink lots of coffee, and maybe get friendly with some functional programming texts. The good news is that you only need it if you're memoizing a recursive function.
Aaron Novstrup