Hey all,

I am having some problems implementing Automatic Differentiation in F#. I think the problem is that the evaluation is not 'lazy'.

Here is my code:

type Diff =
    {d : double; df : Diff}
    static member (+) (x : Diff, y : Diff) =
        {d = x.d + y.d; df = x.df + y.df}
    static member (-) (x : Diff, y : Diff) =
        {d = x.d - y.d; df = x.df - y.df}
    static member (*) (x : Diff, a : double) =
        {d = x.d * a; df = x.df * a}
    static member (*) (x : Diff, y : Diff) =
        {d = x.d * y.d; df = (x.df * y) + (y.df * x)}

let rec dZero = {d = 0.0; df = dZero}

let dConst x = {d = x; df = dZero}

let dId x = {d = x; df = dConst 1.0}

let test = dId 5.0

let add (x:Diff) = (x+x).d

When I try to evaluate 'add test' I get a stack overflow error, which I think is because the definition of (+) inside my type itself relies on '+'.

Is there any way I can fix this? Any help would be greatly appreciated.

Many thanks, Ash

A:

As you thought, the problem is that F# doesn't use lazy evaluation and that the data structure you're creating is "infinite" (because dZero recursively references itself). When calculating x + y, the operator calls + on the df values, which in turn invokes + on the df.df values, and so on. Because the chain of df values never ends, this recursion never terminates and eventually overflows the stack.

One way to correct this is to make the df member of the record explicitly lazy:

type Diff = 
    {d : double; df : Lazy<Diff>} 
    static member (+) (x : Diff, y : Diff) = 
        {d = x.d + y.d; df = lazy (x.df.Value + y.df.Value) } 
    static member (-) (x : Diff, y : Diff) = 
        {d = x.d - y.d; df = lazy (x.df.Value - y.df.Value) } 
    static member (*) (x : Diff, a : double) = 
        {d = x.d * a; df = lazy (x.df.Value * a) } 
    static member (*) (x : Diff, y : Diff) = 
        {d = x.d * y.d; df = lazy ((x.df.Value * y) + (y.df.Value * x)) } 

let rec dZero = {d = 0.0; df = lazy dZero} 
let dConst x = {d = x; df = lazy dZero} 
let dId x = {d = x; df = lazy (dConst 1.0)} 

This will evaluate the df value only when it is actually used, so the + operation will calculate the value of d and only provide a lazy value for df (which can be evaluated if someone needs it).
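To see the lazy version in action, here is a minimal self-contained F# script that puts the type above together with a small usage example. It assumes the definitions exactly as given, except that dId is written as lazy (dConst 1.0) with explicit parentheses; nthDerivative and cube are hypothetical helpers added only for illustration.

```fsharp
// Lazy derivative tower: d is the value, df is the (lazily computed) derivative,
// which is itself a Diff carrying the higher derivatives.
type Diff =
    { d : double; df : Lazy<Diff> }
    static member (+) (x : Diff, y : Diff) =
        { d = x.d + y.d; df = lazy (x.df.Value + y.df.Value) }
    static member (-) (x : Diff, y : Diff) =
        { d = x.d - y.d; df = lazy (x.df.Value - y.df.Value) }
    static member (*) (x : Diff, a : double) =
        { d = x.d * a; df = lazy (x.df.Value * a) }
    static member (*) (x : Diff, y : Diff) =
        // product rule: (xy)' = x'y + y'x
        { d = x.d * y.d; df = lazy ((x.df.Value * y) + (y.df.Value * x)) }

// The cycle is fine now: dZero.df is only forced on demand.
let rec dZero = { d = 0.0; df = lazy dZero }
let dConst x = { d = x; df = lazy dZero }
let dId x = { d = x; df = lazy (dConst 1.0) }

// Hypothetical helper: the n-th derivative, forcing only n lazy cells.
let rec nthDerivative n (x : Diff) =
    if n = 0 then x.d else nthDerivative (n - 1) x.df.Value

let test = dId 5.0
let add (x : Diff) = (x + x).d   // only d is computed; df stays unevaluated

// Hypothetical example: f(x) = x * x * x at x = 5
let cube = test * test * test
```

With this, 'add test' returns 10.0 instead of overflowing, and walking the tower of cube gives 125, 75, 30, 6 and then 0 for f, f', f'', f''' and f''''. Each derivative is forced (and memoized by Lazy) only when nthDerivative asks for it.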

Another alternative would be to make the Diff type a discriminated union and represent zero as a special value (rather than as a recursive record), which would work as long as you don't use recursive references for anything else. The declaration would look roughly like this:

type Diff = 
    | DiffValue of double * Diff
    | DiffZero 
    static member (+) // etc...

This would make the implementation a bit longer, because you would need to check for the DiffZero case in all the primitive operations. In exchange, you would only create finite data structures, so the operators could process them eagerly.
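One possible way to flesh out the discriminated-union alternative is sketched below. It assumes that DiffZero stands for the constant 0 together with all of its derivatives, so every value is a finite chain of DiffValue nodes ending in DiffZero. The operator bodies and the dConst, dId and nthDerivative helpers are illustrative guesses rather than the answerer's code.

```fsharp
// Finite derivative tower: DiffZero terminates the chain, so no laziness is needed.
type Diff =
    | DiffValue of double * Diff
    | DiffZero
    static member (+) (x : Diff, y : Diff) =
        match x, y with
        | DiffZero, other | other, DiffZero -> other
        | DiffValue (a, da), DiffValue (b, db) -> DiffValue (a + b, da + db)
    static member (*) (x : Diff, a : double) =
        match x with
        | DiffZero -> DiffZero
        | DiffValue (v, dv) -> DiffValue (v * a, dv * a)
    static member (*) (x : Diff, y : Diff) =
        match x, y with
        | DiffZero, _ | _, DiffZero -> DiffZero
        | DiffValue (a, da), DiffValue (b, db) ->
            // product rule: (xy)' = x'y + y'x
            DiffValue (a * b, (da * y) + (db * x))
    static member (-) (x : Diff, y : Diff) = x + (y * (-1.0))

let dConst x = DiffValue (x, DiffZero)
let dId x = DiffValue (x, dConst 1.0)

// Hypothetical helper: the n-th derivative; anything past the chain is 0.
let rec nthDerivative n d =
    match d with
    | DiffZero -> 0.0
    | DiffValue (v, dv) -> if n = 0 then v else nthDerivative (n - 1) dv

let x5 = dId 5.0
let cube = x5 * x5 * x5
```

Every recursive call in the operators works on a strictly shorter chain, so evaluation terminates eagerly. The trade-off is the extra DiffZero pattern in each operator.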

Tomas Petricek