Is it possible to combine memoization and tail-recursion somehow? I'm learning F# at the moment and understand both concepts but can't seem to combine them.

Suppose I have the following memoize function (from Real-World Functional Programming):

open System.Collections.Generic

let memoize f =
  let cache = new Dictionary<_, _>()
  (fun x ->
    match cache.TryGetValue(x) with
    | true, y -> y          // cache hit
    | _ ->
        let v = f(x)        // cache miss: compute and remember
        cache.Add(x, v)
        v)

and the following factorial function:

let rec factorial(x) = if (x = 0) then 1 else x * factorial(x - 1)

Memoizing factorial isn't too difficult and making it tail-recursive isn't either:

let rec memoizedFactorial =
  memoize (fun x -> if (x = 0) then 1 else x * memoizedFactorial(x - 1))

let tailRecursiveFactorial(x) =
  let rec factorialUtil(x, res) = if (x = 0)
                                  then res
                                  else let newRes = x * res
                                       factorialUtil(x - 1, newRes)
  factorialUtil(x, 1)

But can you combine memoization and tail-recursion? I made some attempts but can't seem to get it working. Or is this simply not possible?

+3  A: 

I'm not sure if there's a simpler way to do this, but one approach would be to create a memoizing y-combinator:

let memoY f =
  let cache = Dictionary<_,_>()
  let rec fn x =
    match cache.TryGetValue(x) with
    | true,y -> y
    | _ -> let v = f fn x
           cache.Add(x,v)
           v
  fn

Then, you can use this combinator in lieu of "let rec", with the first argument representing the function to call recursively:

let tailRecFact =
  let factHelper fact (x, res) = 
    printfn "%i,%i" x res
    if x = 0 then res 
    else fact (x-1, x*res)
  let memoized = memoY factHelper
  fun x -> memoized (x,1)
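As a quick check that nothing here is specific to factorial, here's a sketch (an illustration, not part of the original answer) that applies the same memoY, unchanged, to the naive doubly recursive Fibonacci definition. The `fib` name is just for illustration, and like this first memoY it is still not stack-safe for deep inputs.

```fsharp
open System.Collections.Generic

// First memoY from above: caches every argument it is called with
let memoY f =
  let cache = Dictionary<_,_>()
  let rec fn x =
    match cache.TryGetValue(x) with
    | true, y -> y
    | _ ->
        let v = f fn x
        cache.Add(x, v)
        v
  fn

// Naive doubly recursive definition; memoization makes it linear in n
let fib = memoY (fun fib n -> if n <= 1 then n else fib (n - 1) + fib (n - 2))
```

Because `fib (n - 1)` is evaluated first, `fib (n - 2)` is always a cache hit by the time it is requested.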

EDIT

As Mitya pointed out, memoY doesn't preserve the tail recursive properties of the memoee. Here's a revised combinator which uses exceptions and mutable state to memoize any recursive function without overflowing the stack (even if the original function is not itself tail recursive!):

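open System.Collections.Generic

// Dedicated signal, so that genuine exceptions raised by f are not swallowed
exception NeedToRecurse

let memoY f =
  let cache = Dictionary<_,_>()
  fun x ->
    // explicit stack of arguments still waiting for a result
    let l = ResizeArray([x])
    while l.Count <> 0 do
      let v = l.[l.Count - 1]
      if cache.ContainsKey(v) then l.RemoveAt(l.Count - 1)
      else
        try
          cache.[v] <- f (fun y ->
            if cache.ContainsKey(y) then cache.[y]
            else
              // schedule y and abandon this attempt at v; v is retried later
              l.Add(y)
              raise NeedToRecurse) v
        with NeedToRecurse -> ()
    cache.[x]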

Unfortunately, the machinery which is inserted into each recursive call is somewhat heavy, so performance on un-memoized inputs requiring deep recursion can be a bit slow. However, compared to some other solutions, this has the benefit that it requires fairly minimal changes to the natural expression of recursive functions:

let fib = memoY (fun fib n -> 
  printfn "%i" n; 
  if n <= 1 then n 
  else (fib (n-1)) + (fib (n-2)))

let _ = fib 5000

EDIT

I'll expand a bit on how this compares to other solutions. This technique takes advantage of the fact that exceptions provide a side channel: a function of type 'a -> 'b doesn't actually need to return a value of type 'b, but can instead exit via an exception. We wouldn't need to use exceptions if the return type explicitly contained an additional value indicating failure. Of course, we could use the 'b option as the return type of the function for this purpose. This would lead to the following memoizing combinator:

let memoO f =
  let cache = Dictionary<_,_>()
  fun x ->
    let l = ResizeArray([x])
    while l.Count <> 0 do
      let v = l.[l.Count - 1]
      if cache.ContainsKey v then l.RemoveAt(l.Count - 1)
      else
        // returns None (and schedules y) when y is not cached yet
        let lookup y =
          if cache.ContainsKey y then Some(cache.[y])
          else
            l.Add(y)
            None
        match f lookup v with
        | Some(r) -> cache.[v] <- r
        | None -> ()
    cache.[x]

Previously, our memoization process looked like:

fun fib n -> 
  printfn "%i" n; 
  if n <= 1 then n 
  else (fib (n-1)) + (fib (n-2))
|> memoY

Now, we need to incorporate the fact that fib should return an int option instead of an int. Given a suitable workflow for option types, this could be written as follows:

fun fib n -> option {
  printfn "%i" n
  if n <= 1 then return n
  else
    let! x = fib (n-1)
    let! y = fib (n-2)
    return x + y
} |> memoO
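The snippet above assumes an option workflow; FSharp.Core does not ship one, so here's a minimal sketch. `OptionBuilder` and the `option` instance are assumed names, and only `Bind` and `Return` are defined because the expression uses nothing else. It is paired with memoO, with the lookup callback spelled out so the "schedule and return None" step is unambiguous.

```fsharp
open System.Collections.Generic

// Hypothetical minimal option workflow: short-circuits at the first None
type OptionBuilder() =
  member this.Bind(m, f) = Option.bind f m
  member this.Return(x) = Some x

let option = OptionBuilder()

// memoO: f receives a lookup that yields None (and schedules y) on a cache miss
let memoO f =
  let cache = Dictionary<_,_>()
  fun x ->
    let l = ResizeArray([x])
    while l.Count <> 0 do
      let v = l.[l.Count - 1]
      if cache.ContainsKey v then l.RemoveAt(l.Count - 1)
      else
        let lookup y =
          if cache.ContainsKey y then Some(cache.[y])
          else
            l.Add(y)
            None
        match f lookup v with
        | Some(r) -> cache.[v] <- r
        | None -> ()
    cache.[x]

let fib =
  memoO (fun fib n -> option {
    if n <= 1 then return n
    else
      let! x = fib (n - 1)
      let! y = fib (n - 2)
      return x + y })
```

A miss on `fib (n - 1)` makes the whole `option` block return None, so `fib (n - 2)` is not even requested until `n - 1` has been cached.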

However, if we're willing to change the return type of the first parameter (from int to int option in this case), we may as well go all the way and just use continuations in the return type instead, as in Brian's solution. Here's a variation on his definitions:

let memoC f =
  let cache = Dictionary<_,_>()
  let rec fn n k =
    match cache.TryGetValue(n) with
    | true, r -> k r
    | _ -> 
        f fn n (fun r ->
          cache.Add(n,r)
          k r)
  fun n -> fn n id

And again, if we have a suitable computation expression for building CPS functions, we can define our recursive function like this:

fun fib n -> cps {
  printfn "%i" n
  if n <= 1 then return n
  else
    let! x = fib (n-1)
    let! y = fib (n-2)
    return x + y
} |> memoC

This is exactly the same as what Brian has done, but I find the syntax here is easier to follow. To make this work, all we need are the following two definitions:

type CpsBuilder() =
  member this.Return x k = k x
  member this.Bind(m,f) k = m (fun a -> f a k)

let cps = CpsBuilder()
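Putting the pieces together, here's a sketch of a complete script combining memoC with the CpsBuilder above. The only change is writing the cache with the indexer (`cache.[n] <- r`) instead of `cache.Add`, a defensive assumption so that a repeated key would overwrite rather than throw. Whether the continuation calls actually become .NET tail calls depends on compiler settings (typically emitted in optimized/Release builds), so very deep inputs may still overflow in Debug.

```fsharp
open System.Collections.Generic

// A CPS computation is a function waiting for its continuation k
type CpsBuilder() =
  member this.Return x k = k x
  member this.Bind(m, f) k = m (fun a -> f a k)

let cps = CpsBuilder()

// memoC: the cache is filled inside the continuation, once the result is known
let memoC f =
  let cache = Dictionary<_,_>()
  let rec fn n k =
    match cache.TryGetValue(n) with
    | true, r -> k r
    | _ ->
        f fn n (fun r ->
          cache.[n] <- r
          k r)
  fun n -> fn n id

let fib =
  memoC (fun fib n -> cps {
    if n <= 1 then return n
    else
      let! x = fib (n - 1)
      let! y = fib (n - 2)
      return x + y })
```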
kvb
-1, I don't consider this memoization since it is only storing the final result and not all of the factorials along the way.
gradbot
@gradbot - It stores all of the results along the way; however, because of the nature of the tail-recursive implementation these results are not reused across different top-level computations. That is, when calculating 3!, we store the results for (3,1), (2,3), and (1,6), but when calculating 4! we store the results for (4,1), (3,4), (2,12), and (1,24), none of which already appeared in the 3! calculation.
kvb
@kvb interesting, can you do this with the Fibonacci series then? In the current case it provides no performance advantage.
gradbot
-1. This implementation is NOT tail-recursive. Try it yourself: `tailRecFact 100000` fails with a stack overflow. Specifically, the call from `fn` to `f` is not a tail call.
Mitya
@gradbot - Yes, it works for any recursive function, including the Fibonacci series. However, as Mitya pointed out, my original solution did not preserve the tail-recursive behavior. I've added an edit which provides another solution.
kvb
@Mitya - You're absolutely right. However, see my edit for a solution which can memoize any pure recursive function without a stack overflow, even if the original is not tail recursive.
kvb
I guess by 'pure' here you mean 'function that does not contain any exception-handling logic'? Is there a way to use continuations rather than exceptions, without the user's "function guts" knowing? I need to keep thinking about it...
Brian
@Brian - One requirement is that it must not catch the exception that I'm throwing. In addition, the function should be idempotent, since I may need to call it multiple times with the same argument. Still, this was the least invasive implementation I could come up with. I like your CPS solution, but the code for the recursive computation becomes much less readable (to my eyes, at least).
kvb
Note that you do not actually need a CPS workflow builder: F# already has one, and it is called "async" :)
Mitya
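Picking up the suggestion above, here's a sketch of the same memoizing combinator written against F#'s built-in `async` workflow instead of a hand-rolled CPS builder. The name `memoA` is made up for this sketch, and the expectation that async keeps deep `let!` chains off the stack (it keeps continuations on the heap and trampolines periodically) has not been measured here; `Async.RunSynchronously` also adds some overhead per top-level call.

```fsharp
open System.Collections.Generic

// Hypothetical async-based memoizer: async plays the role of the cps builder
let memoA (f: ('a -> Async<'b>) -> 'a -> Async<'b>) : 'a -> 'b =
  let cache = Dictionary<'a, 'b>()
  let rec fn x =
    async {
      match cache.TryGetValue(x) with
      | true, r -> return r
      | _ ->
          let! r = f fn x
          cache.[x] <- r
          return r }
  fun x -> Async.RunSynchronously (fn x)

let fib =
  memoA (fun fib n -> async {
    if n <= 1 then return n
    else
      let! x = fib (n - 1)
      let! y = fib (n - 2)
      return x + y })
```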
+2  A: 

I wrote a test to visualize the memoization. Each dot is a recursive call.

......720 // factorial 6
......720 // factorial 6
.....120  // factorial 5

......720 // memoizedFactorial 6
720       // memoizedFactorial 6
120       // memoizedFactorial 5

......720 // tailRecFact 6
720       // tailRecFact 6
.....120  // tailRecFact 5

......720 // tailRecursiveMemoizedFactorial 6
720       // tailRecursiveMemoizedFactorial 6
.....120  // tailRecursiveMemoizedFactorial 5

kvb's solution returns the same results as straight memoization of a tail-recursive function like this one:

let tailRecursiveMemoizedFactorial = 
    memoize 
        (fun x ->
            let rec factorialUtil x res = 
                if x = 0 then 
                    res
                else 
                    printf "." 
                    let newRes = x * res
                    factorialUtil (x - 1) newRes

            factorialUtil x 1
        )

Test source code.

open System.Collections.Generic

let memoize f = 
    let cache = new Dictionary<_, _>()
    (fun x -> 
        match cache.TryGetValue(x) with
        | true, y -> y
        | _ -> 
            let v = f(x)
            cache.Add(x, v)
            v)

let rec factorial(x) = 
    if (x = 0) then 
        1 
    else
        printf "." 
        x * factorial(x - 1)

let rec memoizedFactorial =
    memoize (
        fun x -> 
            if (x = 0) then 
                1 
            else 
                printf "."
                x * memoizedFactorial(x - 1))

let memoY f =
  let cache = Dictionary<_,_>()
  let rec fn x =
    match cache.TryGetValue(x) with
    | true,y -> y
    | _ -> let v = f fn x
           cache.Add(x,v)
           v
  fn

let tailRecFact =
  let factHelper fact (x, res) = 
    if x = 0 then 
        res 
    else
        printf "." 
        fact (x-1, x*res)
  let memoized = memoY factHelper
  fun x -> memoized (x,1)

let tailRecursiveMemoizedFactorial = 
    memoize 
        (fun x ->
            let rec factorialUtil x res = 
                if x = 0 then 
                    res
                else 
                    printf "." 
                    let newRes = x * res
                    factorialUtil (x - 1) newRes

            factorialUtil x 1
        )

factorial 6 |> printfn "%A"
factorial 6 |> printfn "%A"
factorial 5 |> printfn "%A\n"

memoizedFactorial 6 |> printfn "%A"
memoizedFactorial 6 |> printfn "%A"
memoizedFactorial 5 |> printfn "%A\n"

tailRecFact 6 |> printfn "%A"
tailRecFact 6 |> printfn "%A"
tailRecFact 5 |> printfn "%A\n"

tailRecursiveMemoizedFactorial 6 |> printfn "%A"
tailRecursiveMemoizedFactorial 6 |> printfn "%A"
tailRecursiveMemoizedFactorial 5 |> printfn "%A\n"

System.Console.ReadLine() |> ignore
gradbot
Thanks for putting so much effort into this answer. But is it really impossible to combine the memoization of 'memoizedFactorial' with tail recursion? In 'memoizedFactorial' intermediate results are memoized and can be reused later. When tail recursion is added, suddenly only the end result is memoized. This leads to a complete recalculation for 'tailRecursiveMemoizedFactorial 5' (which we have already calculated).
Ronald Wildenberg
@Ronald you're welcome. I figured it was only a matter of time before Brian or someone else posted a correct answer. I had to work otherwise I would have taken another stab at it. This is also a great algorithm question.
gradbot
+9  A: 

The predicament of memoizing tail-recursive functions is, of course, that when a tail-recursive function

let rec f x = 
   // ...
   f x1

calls itself, it is not allowed to do anything with the result of the recursive call, including putting it into the cache. Tricky; so what can we do?

The critical insight here is that since the recursive function is not allowed to do anything with the result of a recursive call, the result for all arguments to recursive calls will be the same! Therefore, if the recursive call trace is this

f x0 -> f x1 -> f x2 -> f x3 -> ... -> f xN -> res

then for all x in x0, x1, ..., xN the result of f x will be the same, namely res. So the last invocation of the recursive function, the non-recursive call, knows the results for all the previous values and is in a position to cache them. The only thing you need to do is pass it a list of the visited values. Here is what it might look like for factorial:

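open System.Collections.Generic

let cache = Dictionary<int * int, int>()

// l holds every (n, res) argument visited so far in this chain of tail calls
let rec fact0 l ((n,res) as arg) = 
    // every visited argument shares the same final result r
    let commitToCache r = 
        l |> List.iter (fun a -> cache.Add(a,r))
    match cache.TryGetValue(arg) with
    |   true, cachedResult -> commitToCache cachedResult; cachedResult
    |   false, _ ->
            // n <= 1 (rather than n = 1) also terminates for fact 0
            if n <= 1 then
                commitToCache res
                cache.Add(arg, res)
                res
            else
                fact0 (arg::l) (n-1, n*res)

let fact n = fact0 [] (n,1)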

But wait! Look - l parameter of fact0 contains all the arguments to recursive calls to fact0 - just like the stack would in a non-tail-recursive version! That is exactly right. Any non-tail recursive algorithm can be converted to a tail-recursive one by moving the "list of stack frames" from stack to heap and converting the "postprocessing" of recursive call result into a walk over that data structure.
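To make the stack-to-heap idea concrete, here's a sketch (the names `factNonTail` and `factExplicitStack` are illustrative, not from the answer) that takes the non-tail-recursive factorial apart into two tail-recursive phases: one records the pending frames in a list, the other walks that list doing the postprocessing each recursive call would have done on return. It uses `bigint` only so that results don't overflow.

```fsharp
// Direct style: "n * _" is the postprocessing each frame leaves on the stack
let rec factNonTail (n: int) : bigint =
    if n = 0 then 1I else bigint n * factNonTail (n - 1)

// Same computation with the pending frames moved from the stack into a list
let factExplicitStack (n: int) : bigint =
    // Descent: instead of recursing, remember each frame's n (innermost ends up first)
    let rec push k frames =
        if k = 0 then frames else push (k - 1) (k :: frames)
    // Unwinding: apply each frame's postprocessing, innermost frame first
    push n [] |> List.fold (fun acc k -> bigint k * acc) 1I
```

The list built by `push` plays exactly the role of the `l` parameter of `fact0`: it is the call stack, made explicit.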

Pragmatic note: The factorial example above illustrates a general technique. It is quite useless as is - for factorial function it is quite enough to cache the top-level fact n result, because calculation of fact n for a particular n only hits a unique series of (n,res) pairs of arguments to fact0 - if (n,1) is not cached yet, then none of the pairs fact0 is going to be called on are.

Note that in this example, when we went from non-tail-recursive factorial to tail-recursive factorial, we exploited the fact that multiplication is associative and commutative: tail-recursive factorial executes a different set of multiplications than the non-tail-recursive one.
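Unrolled for n = 4, using the question's definitions (base case at 0, accumulator starting at 1), the two versions group the multiplications differently. The non-tail version passes through 2! and 3! on its way to 4!; the accumulator version never forms either, which is also why there is nothing reusable to cache along the way:

```latex
\begin{aligned}
\text{non-tail:}\quad    & 4 \cdot (3 \cdot (2 \cdot (1 \cdot 1))) && \text{partial products } 1,\ 2,\ 6,\ 24 \\
\text{accumulator:}\quad & 1 \cdot (2 \cdot (3 \cdot (4 \cdot 1))) && \text{partial products } 4,\ 12,\ 24,\ 24
\end{aligned}
```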

In fact, a general technique exists for going from a non-tail-recursive to a tail-recursive algorithm, one which yields an algorithm equivalent to a T. This technique is called the "continuation-passing transformation". Going that route, you can take a non-tail-recursive memoizing factorial and get a tail-recursive memoizing factorial by a pretty much mechanical transformation. See Brian's answer for an exposition of this method.
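For a taste of that mechanical transformation on plain, unmemoized factorial, here is a sketch (`factK` and `factCps` are illustrative names). The pending `n * _` that kept the direct-style call out of tail position moves into the continuation `k`, so every call becomes a tail call; whether .NET actually eliminates those calls depends on the compiler emitting tail calls (e.g. in Release builds).

```fsharp
// Direct style: not a tail call, "n * _" still runs after fact (n - 1) returns
let rec fact n = if n = 0 then 1 else n * fact (n - 1)

// CPS: "n * _" becomes part of the continuation, so every call is in tail position
let rec factK n (k: int -> 'r) : 'r =
    if n = 0 then k 1
    else factK (n - 1) (fun r -> k (n * r))

// Start with the identity continuation to get the plain result back
let factCps n = factK n id
```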

Mitya
I think we can do even better than this (as we discussed offline); one moment...
Brian
Thanks very much for this very thorough explanation. It's taking me some time to understand all you're saying :) If I understand correctly, you lose memoization of intermediate values with this implementation. Well, not entirely: they are cached but can't be reused. If I first call 'fact 4' and then 'fact 5', the second call has to calculate everything all over again without being able to use the result of 'fact 4'.
Ronald Wildenberg
@Ronald yes that is correct. If you try to memoize textbook tail-recursive factorial that is what you get because that implementation, when invoked for (n,1), never calls itself for (k,1) or any other pair in (k,1) trace. To get cache reuse, you have to change the way you "tail-recurse" your factorial.
Mitya
+10  A: 

As always, continuations yield an elegant tailcall solution:

open System.Collections.Generic 

let cache = Dictionary<_,_>()  // TODO move inside 
let memoizedTRFactorial =
    let rec fac n k =  // must make tailcalls to k
        match cache.TryGetValue(n) with
        | true, r -> k r
        | _ -> 
            if n=0 then
                k 1
            else
                fac (n-1) (fun r1 ->
                    printfn "multiplying by %d" n  //***
                    let r = r1 * n
                    cache.Add(n,r)
                    k r)
    fun n -> fac n id

printfn "---"
let r = memoizedTRFactorial 4
printfn "%d" r
for KeyValue(k,v) in cache do
    printfn "%d: %d" k v

printfn "---"
let r2 = memoizedTRFactorial 5
printfn "%d" r2

printfn "---"

// comment out *** line, then run this
//let r3 = memoizedTRFactorial 100000
//printfn "%d" r3

There are two kinds of tests. First, this demos that calling F(4) caches F(4), F(3), F(2), F(1) as you would like.

Then, comment out the *** printf and uncomment the final test (and compile in Release mode) to show that it does not StackOverflow (it uses tailcalls correctly).

Perhaps I'll generalize out 'memoize' and demonstrate it on 'fib' next...

EDIT

Ok, here's the next step, I think, decoupling memoization from factorial:

open System.Collections.Generic 

let cache = Dictionary<_,_>()  // TODO move inside 
let memoize fGuts n =
    let rec newFunc n k =  // must make tailcalls to k
        match cache.TryGetValue(n) with
        | true, r -> k r
        | _ -> 
            fGuts n (fun r ->
                        cache.Add(n,r)
                        k r) newFunc
    newFunc n id 
let TRFactorialGuts n k memoGuts =
    if n=0 then
        k 1
    else
        memoGuts (n-1) (fun r1 ->
            printfn "multiplying by %d" n  //***
            let r = r1 * n
            k r) 

let memoizedTRFactorial = memoize TRFactorialGuts 

printfn "---"
let r = memoizedTRFactorial 4
printfn "%d" r
for KeyValue(k,v) in cache do
    printfn "%d: %d" k v

printfn "---"
let r2 = memoizedTRFactorial 5
printfn "%d" r2

printfn "---"

// comment out *** line, then run this
//let r3 = memoizedTRFactorial 100000
//printfn "%d" r3

EDIT

Ok, here's a fully generalized version that seems to work.

open System.Collections.Generic 

let memoize fGuts =
    let cache = Dictionary<_,_>()
    let rec newFunc n k =  // must make tailcalls to k
        match cache.TryGetValue(n) with
        | true, r -> k r
        | _ -> 
            fGuts n (fun r ->
                        cache.Add(n,r)
                        k r) newFunc
    cache, (fun n -> newFunc n id)
let TRFactorialGuts n k memoGuts =
    if n=0 then
        k 1
    else
        memoGuts (n-1) (fun r1 ->
            printfn "multiplying by %d" n  //***
            let r = r1 * n
            k r) 

let facCache,memoizedTRFactorial = memoize TRFactorialGuts 

printfn "---"
let r = memoizedTRFactorial 4
printfn "%d" r
for KeyValue(k,v) in facCache do
    printfn "%d: %d" k v

printfn "---"
let r2 = memoizedTRFactorial 5
printfn "%d" r2

printfn "---"

// comment out *** line, then run this
//let r3 = memoizedTRFactorial 100000
//printfn "%d" r3

let TRFibGuts n k memoGuts =
    if n=0 || n=1 then
        k 1
    else
        memoGuts (n-1) (fun r1 ->
            memoGuts (n-2) (fun r2 ->
                printfn "adding %d+%d" r1 r2 //%%%
                let r = r1+r2
                k r)) 
let fibCache, memoizedTRFib = memoize TRFibGuts 
printfn "---"
let r5 = memoizedTRFib 4
printfn "%d" r5
for KeyValue(k,v) in fibCache do
    printfn "%d: %d" k v

printfn "---"
let r6 = memoizedTRFib 5
printfn "%d" r6

printfn "---"

// comment out %%% line, then run this
//let r7 = memoizedTRFib 100000
//printfn "%d" r7
Brian
Too much new stuff in one day... :) Thanks. I hope I understand how this thing works. Stepping it through the debugger now.
Ronald Wildenberg
Ok, I'm officially lost :D Maybe I should do some additional reading on continuations in general and F# in particular. Three weeks in Functional Programming Land just isn't enough. But since your code does exactly what I asked for, this must be the answer I was looking for.
Ronald Wildenberg
Yeah, I will maybe try to blog it. I don't really understand it either, I have just done enough of it that my fingers know how to type code that works :) I should try to blog it, as that will force my brain to understand it well enough to articulate what the heck I am doing.
Brian
Ah, that's reassuring :)
Ronald Wildenberg
I've seen a similar technique used to fold across trees.
gradbot
Does this replace stack frames with heap-allocated closures? If so, it seems similar to Mitya's solution, which passes the list of arguments forward until the result is known. Stack is more precious than heap, but still, this could use a lot of memory and could possibly run out for large values... right?
Jason
@Jason: yes, this allocates closures at each lambda (`fun`), however (unlike @Mitya) they are allocated and freed linearly, so the memory use here is constant, not linear (modulo the ever-growing cache). The use pattern is also nearly optimal for the GC, I think, in that there are lots of short-lived, small allocations. So I think heap memory pressure is a non-issue here, though I have not profiled it to verify.
Brian
I like this answer, although the inverted definition of functions like `TRFibGuts` is somewhat off-putting (not to mention the name :). I think you could make this look much nicer using a CPS workflow. You'd have something like `let fibDefn fib n = cps { if n <= 1 then return n else let! x = fib (n-1) in let! y = fib (n-2) in return x + y }`. (This would also require a minor reordering of some of the parameters to `memoize` and its argument).
kvb
I've added an example with a CPS builder to the end of my answer for illustration, but it's just a very minor variation on what you've done.
kvb
@Brian: memory use in your solution is not constant - you still allocate n closures when calculating factorial of n. CPS transform is essentially moving your stack onto heap. If your original function has a stack depth of O(n), you will get O(n) heap usage after CPS transform.
Mitya
@Jason: yes you are right - CPS transform is precisely replacing stack frames with heap-allocated closures
Mitya
@Mitya, I don't think so (but maybe I am wrong); at each stage, this allocates a single closure, and calls it, the next call is a tail call, so when that heap-allocated closure is about to call the next `k`, its own stack frame (which contains the only reference to the closure) is removed, thus making the just-finished closure eligible for GC, right? (Analogous to Haskell, where stack is not an issue, but you use tail calls to prevent space leaks, just as here.)
Brian
@Brian, no: every closure that you allocate has a reference to a previous continuation closure (all your "fun -> .." reference k). Therefore as you dive into TRFactGuts/memoize you accumulate a chain of closures each referencing the previous closure. That chain has length n. When you reach the base case and finally call the continuation (`k 1`) the chain of closures will start to unwind. Again, generally all that CPS does is move your stack onto heap - CPS never reduces your space requirements.
Mitya
Ah yes, you're right, I see it now.
Brian