Continuations for normalising sum types

Normalising the lambda calculus gets difficult once sum types are introduced. For normalisation by evaluation, it means that a single "probe-through" of the program no longer suffices. Continuation-passing style provides a way around this problem.

Papers:

  1. Olivier Danvy, "Type-Directed Partial Evaluation"
  2. Danko Ilik, "Continuation-passing style models complete for intuitionistic logic"
  3. Danvy, Keller, Puech, "Typeful Normalization by Evaluation"

Informal description

The idea is to run a program and then turn the result back into a program: a representation of the normal form of the program we just ran.

Let's start with basic arithmetic again, to illustrate both the idea and the problem being solved.

1 + 2

Normalising this by evaluation is simple, right? The program evaluates 1 + 2 and the result is 3. Turn that result back into a representation of a program and you get 3.

The next example likely doesn't cause any problems either: it's two numbers in a pair.

(5-1, 0+2)

If you were to evaluate this, the result is some pair of numbers. If you reify it back into a program, you get (4, 2).
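A minimal sketch of this evaluate-then-reify round trip for the first-order examples, in Haskell. The Expr and Val types and all names below are hypothetical, made up for illustration; at this stage nothing needs continuations yet.

```haskell
-- Hypothetical first-order source language: numbers, +, - and pairs.
data Expr = Num Int | Add Expr Expr | Sub Expr Expr | MkPair Expr Expr
  deriving (Show, Eq)

-- Semantic values produced by running an Expr.
data Val = VNum Int | VPair Val Val

-- Evaluation: actually run the program.
eval :: Expr -> Val
eval (Num n) = VNum n
eval (Add x y) = arith (+) (eval x) (eval y)
eval (Sub x y) = arith (-) (eval x) (eval y)
eval (MkPair x y) = VPair (eval x) (eval y)

arith :: (Int -> Int -> Int) -> Val -> Val -> Val
arith op (VNum a) (VNum b) = VNum (op a b)
arith _ _ _ = error "arith: expected numbers"

-- Reification: turn a value back into syntax, which is the normal form.
reify :: Val -> Expr
reify (VNum n) = Num n
reify (VPair a b) = MkPair (reify a) (reify b)

normalise :: Expr -> Expr
normalise = reify . eval
```

Here normalise (Add (Num 1) (Num 2)) gives Num 3, and the pair example gives MkPair (Num 4) (Num 2).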

Now, how about we step it up and introduce functions.

λx.λf.(f (fst x), snd x)

The type of this program is (a,c) → (a → b) → (b,c). If you evaluate it, you get back a function. How do you reify a function back into a program? Here's the clever trick: look at the type variables a, b and c. The program has to pass values of these types around without knowing what they are, so you can put a piece of syntax in their place and the program can't do anything with it. Let's write such an opaque term as 't.

Now that you know what to pass in for a and c, you can fill the (a,c) argument: it becomes ('fst x, 'snd x), where x is the variable bound by the outer lambda. These are called reflections.

How do you reflect a function argument of type (a → b)? We've got some variable f, so we can produce λh.'f $h, where $h is the reification of the argument h given to the function.

Now look at the program again and imagine passing these reflections in as its arguments. fst ('fst x, 'snd x) evaluates to 'fst x, which reifies to fst x. The program can't tear an abstract variable apart, so reifying the result gives you back the syntactic term.
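To make reflection concrete, here is a small direct-style sketch (no continuations) of reify and reflect for atoms, products and functions, the fragment that doesn't need the CPS trick. It borrows the later idea of passing the current de Bruijn level to generated syntax. The constructor names, the Arr dom cod argument order and the helpers app, first and second are assumptions of this sketch. The example program is written directly as a Haskell function over semantic values rather than evaluated from syntax.

```haskell
-- Syntax with de Bruijn indices (names differ from the post's Tm).
data Tm = Var Int | Lam Tm | App Tm Tm | Pair Tm Tm | Fst Tm | Snd Tm
  deriving (Show, Eq)

-- Atoms stand for type variables; Arr dom cod is dom -> cod.
data Ty = Atom Char | Arr Ty Ty | Prod Ty Ty

-- Syntax waiting for the current binding depth (level).
type TmV = Int -> Tm

-- Semantic values: opaque syntax ('t), functions and pairs.
data Sem = Syn TmV | Fun (Sem -> Sem) | Pr Sem Sem

-- Variable bound at level lvl, referenced at level cur, as an index.
var :: Int -> TmV
var lvl cur = Var (cur - lvl - 1)

-- Reify: semantic value to syntax, guided by the type.
reify :: Ty -> Sem -> TmV
reify (Atom _) (Syn t) = t
reify (Prod a b) (Pr x y) = \n -> Pair (reify a x n) (reify b y n)
reify (Arr a b) (Fun f) = \n -> Lam (reify b (f (reflect a (var n))) (n + 1))
reify _ _ = error "reify: value does not match its type"

-- Reflect: syntax to semantic value. A pair becomes ('fst t, 'snd t);
-- a function becomes \h -> 't $h, where $h is the reified argument.
reflect :: Ty -> TmV -> Sem
reflect (Atom _) t = Syn t
reflect (Prod a b) t = Pr (reflect a (Fst . t)) (reflect b (Snd . t))
reflect (Arr a b) t = Fun (\h -> reflect b (\n -> App (t n) (reify a h n)))

-- Semantic operations used by the example program.
app :: Sem -> Sem -> Sem
app (Fun f) v = f v
app _ _ = error "app: not a function"

first, second :: Sem -> Sem
first (Pr x _) = x
first _ = error "first: not a pair"
second (Pr _ y) = y
second _ = error "second: not a pair"

-- \x.\f.(f (fst x), snd x) as a semantic value.
example :: Sem
example = Fun (\x -> Fun (\f -> Pr (app f (first x)) (second x)))

-- (a,c) -> (a -> b) -> (b,c)
exampleTy :: Ty
exampleTy = Arr (Prod (Atom 'a') (Atom 'c'))
                (Arr (Arr (Atom 'a') (Atom 'b')) (Prod (Atom 'b') (Atom 'c')))
```

Reifying example at exampleTy from level 0 should give back λx.λf.(f (fst x), snd x) in de Bruijn form. Reifying the identity at (a,c) → (a,c) gives its eta-expanded form λx.(fst x, snd x).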

Next comes something you might find surprisingly hard: reflecting a number. For example, say you'd like to normalise this:

λx.x + 10

If you reflect a number, should it be 0, 1, or 2? You could probe the function with every number and retrieve a long case statement:

λx.case x of {0 » 10 | 1 » 11 | 2 » 12...}

But what if x is not accessed at all? Then you'd have a redundant case expression.

We can play a little trick with continuation-passing style that introduces a case statement only when the value is actually accessed. That trick is the subject of this post.

Continuation-passing style is a transformation that can be applied to functional programs. The trick is that we convert every type to its "double negation", for some answer type r:

a » (a → r) → r

It's a simple transform: if you've got 4, it turns into λk.k 4.
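The double-negation shape can be written down directly as a Haskell type synonym. A minimal sketch: the names ret, bind, addC and runCont are made up here, and only the Cont synonym corresponds to something the post defines later.

```haskell
-- A value of type a in CPS form, with answer type r.
type Cont r a = (a -> r) -> r

-- Inject a plain value: 4 becomes \k -> k 4.
ret :: a -> Cont r a
ret x = \k -> k x

-- Sequencing: run m, feed its result to f, then continue with k.
bind :: Cont r a -> (a -> Cont r b) -> Cont r b
bind m f = \k -> m (\x -> f x k)

-- 1 + 2 in CPS: every intermediate result is passed to a continuation.
addC :: Cont r Int -> Cont r Int -> Cont r Int
addC mx my = mx `bind` \x -> my `bind` \y -> ret (x + y)

-- Run a computation whose answer type is its own result type.
runCont :: Cont a a -> a
runCont m = m id
```

The answer type r is left open here; choosing a different r changes what the continuation builds, which is exactly what the reification step exploits.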

When we are reifying a program, the answer type r becomes the type of terms. Now the reflection of a term 't of a sum type can be described as:

λk.case 't {s x » k (s 'x)}

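There is one branch per constructor s of the sum, and each branch passes the reflected payload to the continuation k. For instance, for booleans we'd apply the continuation twice inside the case expression, producing the reified "false" and "true" branches.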
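A small sketch of the boolean case, assuming a hypothetical residual syntax in which If t a b stands for case t {true » a | false » b}. The program receives its boolean argument in CPS form, so a case on the input is only emitted if the program actually runs that argument. All names here are illustrative.

```haskell
-- Hypothetical residual syntax; If plays the role of the case expression.
data Tm = V String | TT | FF | If Tm Tm Tm
  deriving (Show, Eq)

type Cont r a = (a -> r) -> r

-- Reflect a boolean-typed term: emit a case on it and apply the
-- continuation once per branch (twice in total for booleans).
reflectBool :: Tm -> Cont Tm Bool
reflectBool t k = If t (k True) (k False)

-- Reify a known boolean back into syntax.
reifyBool :: Bool -> Tm
reifyBool b = if b then TT else FF

-- A program takes its argument as a suspended CPS boolean, so it only
-- triggers reflectBool (and thus a case) if it inspects the argument.
type Prog = Cont Tm Bool -> Cont Tm Bool

notP :: Prog
notP x k = x (\b -> k (not b))

constTrue :: Prog
constTrue _ k = k True

-- Normalise a one-argument program applied to the free variable x.
normalise :: Prog -> Tm
normalise p = p (reflectBool (V "x")) reifyBool
```

With this, normalise notP gives If (V "x") FF TT, while normalise constTrue gives plain TT with no case at all.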

Of course, doing this with numbers would produce infinitely long case statements. But it still works for anything that can be expressed as a sum type.

If we drop from generic arithmetic down to natural numbers, you need to introduce a fold instead of a case statement. For example, your natural number representation could look like this:

zero : nat
succ : nat → nat
natfold : nat → a → (a → a) → a

Now the reflection of a natural number would be:

λk.'natfold t (k zero) (λa. (k (succ ?a)))

This seems like it could fail miserably, but I'm not entirely certain. Let's define addition for these numbers and conclude the example.

a+b = natfold a b succ
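The nat signature and the fold-based addition can be checked with ordinary Haskell data. A sketch: fromInt and toInt are hypothetical helpers added only for testing. This covers closed numbers only, not the reflection of natfold above, which the post itself is unsure about.

```haskell
-- Unary naturals with the constructors from the text.
data Nat = Zero | Succ Nat
  deriving (Show, Eq)

-- natfold : nat -> a -> (a -> a) -> a, replacing Zero by z and each Succ by s.
natfold :: Nat -> a -> (a -> a) -> a
natfold Zero z _ = z
natfold (Succ n) z s = s (natfold n z s)

-- a + b = natfold a b succ: apply Succ to b once per Succ in a.
add :: Nat -> Nat -> Nat
add a b = natfold a b Succ

-- Hypothetical conversion helpers, only for testing.
fromInt :: Int -> Nat
fromInt n
  | n <= 0 = Zero
  | otherwise = Succ (fromInt (n - 1))

toInt :: Nat -> Int
toInt n = natfold n 0 (+ 1)
```

Because add recurses on its first argument, reflecting x in λx.x+10 means x is the number being folded over, which is where the infinite unfolding below comes from.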

In the example λx.x+10 we'd end up with:

λx.natfold x 10 (λy. natfold y 11 ...)

It's an infinite structure, so it seems you would need something extra to handle inductive definitions.

Implementation

Let's try this with the simply typed lambda calculus equipped with binary sum types.

-- Terms use de Bruijn indices; Case binds one variable in each branch.
data Tm = Var Int
        | Lam Tm
        | Ap Tm Tm
        | Pair Tm Tm
        | Fst Tm
        | Snd Tm
        | Case Tm Tm Tm
        | Inl Tm
        | Inr Tm
        deriving (Show, Eq)

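-- Exp a b is the function type b -> a (codomain first), as reify and
-- reflect below treat it. Pos and Neg are both handled as atomic types.
data Ty = Prod Ty Ty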
        | Exp Ty Ty
        | Sum Ty Ty
        | Pos Int
        | Neg Int
        deriving (Show, Eq)

The semantic data structures need to be constructed with continuations: every value that may be retrieved from them is retrieved through a continuation. This was a slightly weird detail of the whole thing. I'm not sure I got it right, but at least the program seems to produce results on simple examples.

type Cont k a = (a -> k) -> k
type TmV = (Int -> Tm)
data Sem k = Nil
         | Cons (Cont k (Sem k)) (Cont k (Sem k))
         | Closure (Cont k (Sem k) -> Cont k (Sem k))
         | Tag Bool (Cont k (Sem k))  -- False: Inl, True: Inr
         | Syn TmV

Note that a TmV takes the "current level" as an argument. Binders are numbered with de Bruijn levels, and knowing the current level is what lets those levels be converted into indices.
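A minimal sketch of how passing the current level produces closed de Bruijn terms. Only a Var/Lam subset of Tm is used, var is the post's own definition (shown further below), and lam, kComb and kiComb are hypothetical names for illustration: lam binds a variable at the current level and builds its body one level deeper, much as reify does for functions.

```haskell
-- A minimal subset of the post's term type.
data Tm = Var Int | Lam Tm
  deriving (Show, Eq)

-- A term that still needs to know the current level.
type TmV = Int -> Tm

-- The post's conversion: variable bound at level z, used at level x.
var :: Int -> TmV
var z x = Var (x - z - 1)

-- Hypothetical helper: bind a variable at the current level n and
-- build the body at level n + 1.
lam :: (TmV -> TmV) -> TmV
lam body n = Lam (body (var n) (n + 1))

-- \a.\b.a and \a.\b.b
kComb, kiComb :: TmV
kComb = lam (\a -> lam (\_ -> a))
kiComb = lam (\_ -> lam (\b -> b))
```

kComb 0 yields Lam (Lam (Var 1)), and starting from any other level gives the same term, which is why the top-level call can simply start at 0.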

Evaluation is what you'd expect, really, except that the environment also consists of continuations.

extend :: Cont k a -> [Cont k a] -> [Cont k a]
extend v env = (v:env)

peek :: Int -> [Cont k a] -> Cont k a
peek i env = (env !! i)

eval :: [Cont k (Sem k)] -> Tm -> Cont k (Sem k)
eval env (Var i) k = (peek i env) k
eval env (Lam body) k = k (Closure impl)
    where impl arg k = (eval (extend arg env) body k)
eval env (Ap x y) k = eval env x apply
    where apply (Closure f) = f (eval env y) k
eval env (Pair x y) k = k (Cons (eval env x) (eval env y))
eval env (Fst x) k = eval env x access
    where access (Cons a b) = a k
eval env (Snd x) k = eval env x access
    where access (Cons a b) = b k
eval env (Case a x y) k = eval env a select
    where select (Tag False z) = eval (extend z env) x k
          select (Tag True z)  = eval (extend z env) y k
eval env (Inl x) k = k (Tag False (eval env x))
eval env (Inr x) k = k (Tag True (eval env x))

Reification was never the problem with sum types. Be careful, though: the answer type here is fixed to TmV rather than kept polymorphic, so the type checker won't verify that you actually apply the continuation you are given.

reify :: Ty -> Sem TmV -> Cont TmV TmV
reify (Prod a b) (Cons x y) k
    = (reify_c a x (\p1 ->
       reify_c b y (\p2 ->
           k (\z -> Pair (p1 z) (p2 z)))))
reify (Exp a b) (Closure f) k
    = k (\z -> Lam (f (reflect b (var z)) (\s -> reify a s id) (z+1)))
reify (Sum a b) (Tag False x) k = reify_c a x (k . fmap Inl)
reify (Sum a b) (Tag True x) k = reify_c b x (k . fmap Inr)
reify (Pos _) (Syn t) k = k t
reify (Neg _) (Syn t) k = k t

In reflection, only the sum case differs from the usual definition, if you disregard the de Bruijn levels threaded through everywhere.

reflect :: Ty -> TmV -> Cont TmV (Sem TmV)
reflect (Prod a b) t k = k (Cons (reflect a (Fst . t))
                                 (reflect b (Snd . t)))
reflect (Exp a b) t k = k (Closure func)
    where func x j = reify_c b x (\arg -> reflect a (\z -> Ap (t z) (arg z)) j)
reflect (Sum a b) t k = (\z -> Case (t z) (k (Tag False (reflect a (var z))) (z+1))
                                          (k (Tag True (reflect b (var z))) (z+1)))
reflect (Pos _) t k = k (Syn t)
reflect (Neg _) t k = k (Syn t)

It's common to reify something that is wrapped in a continuation, so there's a helper for that.

reify_c :: Ty -> ((Sem TmV -> TmV) -> t) -> (TmV -> TmV) -> t
reify_c t v k = v (\r -> reify t r k)

This routine converts de Bruijn levels back into indices.

var :: Int -> TmV
var z x = Var (x-z-1)

And finally, the normaliser is built from these pieces.

nbe :: Ty -> Tm -> Tm
nbe ty tm = eval [] tm (\x -> reify ty x id) 0

That's all there is to it.
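For convenience, here are the pieces above collected into a single module, with extend and peek inlined as (:) and (!!) and a few comments added. As far as I can tell it otherwise matches the definitions in the post. The example terms in the checks are made up for illustration; note that Exp a b is read as b -> a, which is how reify and reflect use it.

```haskell
-- Terms with de Bruijn indices; Case binds one variable in each branch.
data Tm = Var Int | Lam Tm | Ap Tm Tm | Pair Tm Tm | Fst Tm | Snd Tm
        | Case Tm Tm Tm | Inl Tm | Inr Tm
        deriving (Show, Eq)

-- Exp a b is the function type b -> a; Pos and Neg are treated as atoms.
data Ty = Prod Ty Ty | Exp Ty Ty | Sum Ty Ty | Pos Int | Neg Int
        deriving (Show, Eq)

type Cont k a = (a -> k) -> k
type TmV = Int -> Tm  -- syntax waiting for the current de Bruijn level

data Sem k = Nil
           | Cons (Cont k (Sem k)) (Cont k (Sem k))
           | Closure (Cont k (Sem k) -> Cont k (Sem k))
           | Tag Bool (Cont k (Sem k))  -- False: Inl, True: Inr
           | Syn TmV

-- Evaluation; the environment holds continuation-wrapped values.
eval :: [Cont k (Sem k)] -> Tm -> Cont k (Sem k)
eval env (Var i) k = (env !! i) k
eval env (Lam body) k = k (Closure (\arg k' -> eval (arg : env) body k'))
eval env (Ap x y) k = eval env x apply
  where apply (Closure f) = f (eval env y) k
eval env (Pair x y) k = k (Cons (eval env x) (eval env y))
eval env (Fst x) k = eval env x (\(Cons a _) -> a k)
eval env (Snd x) k = eval env x (\(Cons _ b) -> b k)
eval env (Case a x y) k = eval env a select
  where select (Tag False z) = eval (z : env) x k
        select (Tag True z) = eval (z : env) y k
eval env (Inl x) k = k (Tag False (eval env x))
eval env (Inr x) k = k (Tag True (eval env x))

-- Reification: semantic value back to level-indexed syntax.
reify :: Ty -> Sem TmV -> Cont TmV TmV
reify (Prod a b) (Cons x y) k =
  reify_c a x (\p1 -> reify_c b y (\p2 -> k (\z -> Pair (p1 z) (p2 z))))
reify (Exp a b) (Closure f) k =
  k (\z -> Lam (f (reflect b (var z)) (\s -> reify a s id) (z + 1)))
reify (Sum a _) (Tag False x) k = reify_c a x (k . fmap Inl)
reify (Sum _ b) (Tag True x) k = reify_c b x (k . fmap Inr)
reify (Pos _) (Syn t) k = k t
reify (Neg _) (Syn t) k = k t

-- Reflection: syntax to semantic value. A sum emits a Case and runs the
-- continuation once per branch, with the branch variable bound at level z.
reflect :: Ty -> TmV -> Cont TmV (Sem TmV)
reflect (Prod a b) t k = k (Cons (reflect a (Fst . t)) (reflect b (Snd . t)))
reflect (Exp a b) t k = k (Closure func)
  where func x j = reify_c b x (\arg -> reflect a (\z -> Ap (t z) (arg z)) j)
reflect (Sum a b) t k = \z -> Case (t z) (k (Tag False (reflect a (var z))) (z + 1))
                                         (k (Tag True (reflect b (var z))) (z + 1))
reflect (Pos _) t k = k (Syn t)
reflect (Neg _) t k = k (Syn t)

-- Reify a value that is itself wrapped in a continuation.
reify_c :: Ty -> ((Sem TmV -> TmV) -> t) -> (TmV -> TmV) -> t
reify_c t v k = v (\r -> reify t r k)

-- De Bruijn level z, seen from current level x, as an index.
var :: Int -> TmV
var z x = Var (x - z - 1)

nbe :: Ty -> Tm -> Tm
nbe ty tm = eval [] tm (\x -> reify ty x id) 0
```

In this sketch the identity at a sum type normalises to its eta-expanded form λx.case x {inl y » inl y | inr y » inr y}. A function that ignores a sum-typed argument never runs reflect (Sum ...) on it, so no case expression appears in its normal form.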

Conclusion

If you were to stratify the continuations in the above program, you could open each case expression at the point where its sum is introduced. This would slightly strengthen the normal forms produced by the normaliser.

Continuations play an interesting role in this algorithm.
