Defunctionalize the Continuation

Tags: haskell
May 15 2020
By Li-yao Xia

This post details a little example of refactoring a program using defunctionalization.

Defunctionalization

Defunctionalization is a programming technique that emulates higher-order functions using only first-order language features.

A higher-order function (HOF) can be defunctionalized in two steps:

  1. for every location where the HOF is used and applied to a function, replace that function with a value of a dedicated data type to represent that function—we call this a “defunctionalization symbol”; and
  2. in the HOF’s definition, wherever a function parameter is applied, replace that application with a call to a dedicated first-order function which will interpret the defunctionalization symbol that this parameter now stands for—this function is often called apply, sometimes eval.
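To make the two steps concrete, here is a small Haskell sketch. The higher-order function twice, its call sites addTwice and square4, the data type Fun with its constructors, and the defunctionalized twiceD are all hypothetical names chosen for illustration. They do not come from the talk or from any library.

```haskell
-- Hypothetical example: a higher-order function and two call sites.
twice :: (Int -> Int) -> Int -> Int
twice f x = f (f x)

addTwice :: Int -> Int -> Int
addTwice n x = twice (\y -> y + n) x  -- this lambda captures n

square4 :: Int -> Int
square4 x = twice (\y -> y * y) x     -- this lambda captures nothing

-- Step 1: one defunctionalization symbol per lambda; captured
-- variables become constructor fields.
data Fun
  = Add Int  -- stands for \y -> y + n
  | Square   -- stands for \y -> y * y

-- Step 2: apply interprets a symbol as the function it stands for.
apply :: Fun -> Int -> Int
apply (Add n) y = y + n
apply Square  y = y * y

-- twice, defunctionalized: each application f v becomes apply f v.
twiceD :: Fun -> Int -> Int
twiceD f x = apply f (apply f x)

-- The call sites now pass symbols instead of lambdas.
addTwiceD :: Int -> Int -> Int
addTwiceD n x = twiceD (Add n) x

square4D :: Int -> Int
square4D x = twiceD Square x
```

For instance, addTwiceD 5 1 and addTwice 5 1 both evaluate to 11. A lambda's free variables (here n) are what the symbol has to store.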

For a more fulfilling introduction, check out “The best refactoring you’ve never heard of”, by James Koppel; the title of this post comes from that talk. We are about to see another example of that refactoring technique.

Defunctionalization makes functional programming ideas—from a world where higher-order functions are the norm—applicable even in first-order languages. However, defunctionalization also has a lot of value in languages that already have higher-order functions, i.e., functional languages. This is what we wish to demonstrate.

The Problem

Here is the well-known sum function which computes the sum of a list of numbers. The code is in Haskell for concreteness, but this could be written in pretty much any programming language.

import Prelude hiding (sum)  -- avoid clashing with the Prelude's own sum
sum :: [Int] -> Int
sum []       = 0
sum (x : xs) = x + sum xs

This basic example of a recursive function also illustrates a common pitfall of recursion.

If we evaluate it, say, on a list 1 : 2 : 3 : [], we can see two distinct phases in the computation: first, the recursive calls are unfolded, and only after all of that unfolding do we start adding the numbers in the list, starting from the end. Hence, during the first phase, the function sum will consume additional space that grows linearly with the length of the input list.

  sum (1 : 2 : 3 : [])
= 1 + sum (2 : 3 : [])
= 1 + (2 + sum (3 : []))
= 1 + (2 + (3 + sum []))
= 1 + (2 + (3 + 0))         -- end of the first phase
= 1 + (2 + 3)
= 1 + 5
= 6
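One way to see the growth during the first phase is to model the pending work explicitly. The sketch below is only a model built on a hypothetical Expr type. Roughly speaking, GHC does not build such a tree; it keeps the pending additions on its stack, but the shape of the pending work is the same.

```haskell
-- Hypothetical model of the pending work in sum: Add x e stands for x + e.
data Expr = Lit Int | Add Int Expr

-- Phase one: unfold the recursive calls of sum without adding anything.
unfold :: [Int] -> Expr
unfold []       = Lit 0
unfold (x : xs) = Add x (unfold xs)

-- Phase two: perform the additions, starting from the innermost one.
eval :: Expr -> Int
eval (Lit n)   = n
eval (Add x e) = x + eval e

-- How many additions are pending at the end of phase one.
pending :: Expr -> Int
pending (Lit _)   = 0
pending (Add _ e) = 1 + pending e

-- Print the expression with the same parenthesization as the trace above.
render :: Expr -> String
render (Lit n)             = show n
render (Add x e@(Add _ _)) = show x ++ " + (" ++ render e ++ ")"
render (Add x e)           = show x ++ " + " ++ render e
```

Here render (unfold [1, 2, 3]) is "1 + (2 + (3 + 0))", the end of the first phase, and pending (unfold xs) equals length xs.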

As you may already know, the usual solution is to rewrite the sum function to be tail-recursive. This is a common technique that lets functional programs compile down to machine code as efficient as a regular loop in imperative programming.

In fact, this simple transformation can be decomposed into two even more elementary and (at least partially) mechanical steps. We will illustrate them using the very small example of sum, but the underlying ideas are readily applicable to bigger refactoring and optimization problems.

If we really only cared about making the function tail-recursive, we might do so in a fully mechanical manner:

  1. Rewrite the function in continuation-passing style (CPS).

As we will soon see, tail recursion alone is technically not sufficient to make sum run more efficiently. Another transformation is necessary. What makes CPS valuable is that it enables further code transformations.

The rest of this post is going to illustrate that defunctionalization provides a reasonable second step:

  2. Defunctionalize the continuation.

On problem-solving

In general, problems have solutions made up of two kinds of ideas: “standard strategies” are general methods to decompose problems, and “creative tactics” are domain-specific insights to push a solution to completion.

The two steps presented here thus constitute a “standard strategy” to make functions tail-recursive—and efficient. As we will soon see, only the second step calls for “creative tactics” (using knowledge specific to “adding numbers”).


As a reminder, here is the sum function again:

-- Reminder
sum :: [Int] -> Int
sum []       = 0
sum (x : xs) = x + sum xs

Continuation-passing style

This first step is entirely mechanical, and those familiar with it are free to skip this section. The following explanation is geared to those who are new to the idea of continuation-passing style.

Instead of producing the result directly, we construct a function whose parameter k is to be called with the result. k is commonly named “continuation”, because it stands for what the “rest of the world” is going to do with the result of our function.[1]

We can already give the type of the transformed function, where a continuation has type Int -> r for an abstract type parameter r. Because r is abstract, the only way for the function to produce a result of type r is to call the continuation:

sum' :: [Int] -> (Int -> r) -> r

In the base case [], of course, we call the continuation to “return”[2] the known result 0:

sum' [] k = k 0

In the inductive step (x : xs), we first compute the sum of the tail xs recursively, via sum' xs. Now that we’re in CPS, sum' xs is not literally the sum of xs: it is another function expecting a continuation. This continuation is defined here as a lambda: provided a result y, adding x to it yields the result we expect from sum' (x : xs), so we pass that to its own continuation k:

sum' (x : xs) k = sum' xs (\y -> k (x + y))

To use the “rest of the world” analogy again, the expression k (x + y) is how we tell the rest of the world that the result is x + y. Moreover, from the perspective of the recursive call sum' xs, what the surrounding call sum' (x : xs) wants to do (add x to the result) is also part of “the rest of the world”, so it makes sense that this logic goes into the continuation of sum' xs.

Put those three lines together:

sum' :: [Int] -> (Int -> r) -> r
sum' []       k = k 0
sum' (x : xs) k = sum' xs (\y -> k (x + y))

In continuation-passing style, all function calls are tail calls: when the recursive call sum' xs ... returns (some value of type r), it can return directly to the caller of sum' (x : xs) k, since nothing is left to do in between. Tail calls can therefore be compiled to jumps that do not grow the call stack, which is one of the reasons continuation-passing style is valuable, for both compiling and programming.
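Here is sum' as a complete, loadable snippet, together with a few continuations to pass to it. The helpers total, describe and sumIsEven are hypothetical; they show that, because r is abstract, the continuation chosen by the caller alone decides what the final result is.

```haskell
-- sum' as defined above.
sum' :: [Int] -> (Int -> r) -> r
sum' []       k = k 0
sum' (x : xs) k = sum' xs (\y -> k (x + y))

-- Hypothetical helper: with the identity continuation, sum' computes the plain sum.
total :: [Int] -> Int
total xs = sum' xs id

-- Hypothetical helpers: since r is polymorphic, the continuation chooses
-- the result type, here a String ...
describe :: [Int] -> String
describe xs = sum' xs (\s -> "the sum is " ++ show s)

-- ... or a Bool.
sumIsEven :: [Int] -> Bool
sumIsEven xs = sum' xs even
```

In GHCi, total [1, 2, 3] gives 6 and describe [1, 2, 3] gives "the sum is 6".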

Defunctionalize the continuation

This second step cannot be fully automated: we will need to think creatively in order to simplify the structure of the continuation as much as possible.

We’ve just rewritten the original and naive sum function into sum' by CPS transformation. Let’s take a look at how sum' evaluates on an example. At every step, sum' adds the next element of the list (its first argument) to the continuation (its second argument):

  sum' (1 : 2 : 3 : []) k0
= sum'     (2 : 3 : []) (\y -> k0 (1 + y))
= sum'         (3 : []) (\y -> k0 (1 + (2 + y)))
= sum'              []  (\y -> k0 (1 + (2 + (3 + y))))
= k0 (1 + (2 + (3 + 0)))
= k0 6

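Evaluation still blows up: the continuation spells out the whole sum, one pending addition per list element, so it grows linearly with the input just as the evaluation of sum did. That’s how we can tell that we are not done yet, and further effort is necessary.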

But notice also that the continuation always takes the same form at every step:

(\y -> k0 (1 + y))
(\y -> k0 (1 + (2 + y)))
(\y -> k0 (1 + (2 + (3 + y))))

Now, these continuations can be simplified:

(\y -> k0 (1 + y))
(\y -> k0 (3 + y))
(\y -> k0 (6 + y))

In other words, each continuation actually used by this program consists only of:

  1. an initial continuation k0;
  2. an integer n to add to the final result y.

The continuation k0 is fixed throughout one top-level invocation of sum', so we will treat it as a global constant from the point of view of sum'.

That leaves the integer n, which is really the only data needed to describe each continuation that occurs during the evaluation of sum'.
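The observation that every continuation has the form \y -> k0 (n + y) can be spot-checked in code. In this sketch the helpers cont, step, sameOn and claimHolds are hypothetical. Since functions cannot be compared for equality, the check only samples a finite range of inputs with k0 = id; it illustrates the claim rather than proving it.

```haskell
-- Hypothetical helper: cont k0 n is the continuation \y -> k0 (n + y).
cont :: (Int -> r) -> Int -> (Int -> r)
cont k0 n = \y -> k0 (n + y)

-- One step of sum' turns a continuation k into \y -> k (x + y).
-- Here k is assumed to be cont k0 n.
step :: (Int -> r) -> Int -> Int -> (Int -> r)
step k0 n x = \y -> cont k0 n (x + y)

-- Compare two functions on a finite list of sample inputs.
sameOn :: [Int] -> (Int -> Int) -> (Int -> Int) -> Bool
sameOn ys f g = all (\y -> f y == g y) ys

-- Claim: after one step the continuation still has the form cont k0 n',
-- with n' = n + x. Spot-checked with k0 = id and small inputs.
claimHolds :: Bool
claimHolds = and
  [ sameOn [-5 .. 5] (step id n x) (cont id (n + x))
  | n <- [-3 .. 3]
  , x <- [-3 .. 3]
  ]
```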

Defunctionalization replaces continuations by the corresponding integers. Skipping some details in the transformation, where we manage to remove any reference to k0, we are left with the following:

sum_ :: [Int] -> Int -> Int
sum_ []       n = apply n 0
sum_ (x : xs) n = sum_ xs (n + x)

apply :: Int -> Int -> Int
apply n y = n + y

Even though the initial continuation k0 disappeared, n still stands for some sort of continuation; in that sense, sum_ is still in CPS and we didn’t actually undo the previous step.
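The derivation of sum_ skips the step where k0 disappears. One plausible way to fill it in, offered as a reconstruction rather than the author's exact derivation, is to first keep k0 as an extra parameter (the hypothetical sumK below) and then notice that it is applied only once, at the very end. The name sumVia is also hypothetical.

```haskell
-- apply as defined in the post.
apply :: Int -> Int -> Int
apply n y = n + y

-- Hypothetical intermediate step: the integer n stands for the continuation
-- \y -> k0 (n + y), and k0 is still passed along unchanged.
sumK :: (Int -> r) -> [Int] -> Int -> r
sumK k0 []       n = k0 (apply n 0)      -- was: k 0
sumK k0 (x : xs) n = sumK k0 xs (n + x)  -- was: sum' xs (\y -> k (x + y))

-- sum_ from the post: the same recursion with k0 removed.
sum_ :: [Int] -> Int -> Int
sum_ []       n = apply n 0
sum_ (x : xs) n = sum_ xs (n + x)

-- k0 is applied exactly once, at the very end, so it can be pulled out:
-- sumK k0 xs n == k0 (sum_ xs n). Starting from n = 0 corresponds to the
-- initial continuation itself, because \y -> k0 (0 + y) behaves like k0.
sumVia :: (Int -> r) -> [Int] -> r
sumVia k0 xs = k0 (sum_ xs 0)
```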

Side-by-side comparison

If we look carefully enough, sum' and sum_ should really appear as the same thing “modulo defunctionalization”. For comparison, here is sum' again:

-- Reminder
sum' :: [Int] -> (Int -> r) -> r
sum' []       k = k 0
sum' (x : xs) k = sum' xs (\y -> k (x + y))

Where we had a continuation k in sum', we now have a number n to represent it in sum_, assuming that k has the form \y -> k0 (n + y).

Where we had an application k 0 in sum', we now have apply n 0 in sum_. The function apply is defined so that k is equal to \y -> k0 (apply n y), or equivalently, k0 . apply n, so apply = (+). Technically, if we wanted to replace k 0 with something literally equal, we would write k0 (apply n 0), but we secretly dropped k0 at some point in the (skipped) derivation of sum_. Even so, apply formally relates n to the k it defunctionalizes.

Where we had a continuation \y -> k (x + y) in the second case of sum', we now have (n + x) in sum_. They are also related by apply: the continuation \y -> k (x + y) is equal to k0 . apply (n + x), under the assumption that k is equal to k0 . apply n.
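The relationship described in this section can also be tested directly: whenever k is k0 . apply n, running sum' with k should agree with running sum_ from n and then applying k0. The predicate related, the value checks, and the sample inputs are hypothetical, and a finite check like this illustrates the relation without proving it.

```haskell
-- sum', sum_ and apply as defined in the post.
sum' :: [Int] -> (Int -> r) -> r
sum' []       k = k 0
sum' (x : xs) k = sum' xs (\y -> k (x + y))

sum_ :: [Int] -> Int -> Int
sum_ []       n = apply n 0
sum_ (x : xs) n = sum_ xs (n + x)

apply :: Int -> Int -> Int
apply n y = n + y

-- Hypothetical predicate: if k is k0 . apply n, then running sum' with k
-- agrees with running sum_ from n and applying k0 to its result.
related :: (Int -> Int) -> [Int] -> Int -> Bool
related k0 xs n = sum' xs (k0 . apply n) == k0 (sum_ xs n)

-- A finite spot check over arbitrary sample inputs (not a proof).
checks :: Bool
checks = and
  [ related k0 xs n
  | k0 <- [id, (* 2), subtract 7]
  , xs <- [[], [1, 2, 3], [5, -4, 10, 0]]
  , n  <- [0, 1, -3]
  ]
```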

Finishing touch

All that remains is to clean up sum_ in the base case: inline apply and simplify n + 0 to n. We obtain the final version, the optimized sum'':

sum'' :: [Int] -> Int -> Int
sum'' []       n = n
sum'' (x : xs) n = sum'' xs (n + x)

Nice and tidy.

  sum'' (1 : 2 : 3 : []) 0
= sum''     (2 : 3 : []) 1
= sum''         (3 : []) 3
= sum''              []  6
= 6
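One Haskell-specific caveat, not discussed above: because Haskell is lazy, the accumulator n + x in sum'' may be left unevaluated, so the chain of pending additions can reappear as a chain of heap thunks. GHC's strictness analysis often removes this problem when compiling with optimizations, but a common explicit precaution is a bang pattern, or the strict left fold foldl' from Data.List. The names sumStrict and sumFold are hypothetical.

```haskell
{-# LANGUAGE BangPatterns #-}
import Data.List (foldl')

-- Hypothetical variant of sum'' with a strict accumulator: the bang pattern
-- forces n at each call, so no chain of unevaluated (n + x) thunks builds up.
sumStrict :: [Int] -> Int -> Int
sumStrict []       !n = n
sumStrict (x : xs) !n = sumStrict xs (n + x)

-- The same loop as a strict left fold.
sumFold :: [Int] -> Int
sumFold = foldl' (+) 0
```

Both sumStrict [1 .. 100] 0 and sumFold [1 .. 100] evaluate to 5050.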

Exercises for the reader:

  1. Some facts about arithmetic are crucial to allow this optimization to take place. Where were they used above?

  2. What would happen if sum' were defunctionalized naively, without simplifying the continuation beforehand?

Continuations as evaluation contexts

A fair question to ask is whether all of this is not a bit indirect. Indeed, to go from the naive sum to the tail-recursive sum'', CPS is not necessarily the first solution that comes to mind. A more direct way to think about the problem is to look again at how sum evaluates (instead of sum'):

  sum (1 : 2 : 3 : [])
= 1 + sum (2 : 3 : [])
= 1 + (2 + sum (3 : []))
= 1 + (2 + (3 + sum []))

and notice that the evaluation contexts around sum have a common shape:

1 + _
1 + (2 + _)
1 + (2 + (3 + _))

It is hopefully evident here that evaluation contexts in sum play the same role as continuations in sum'. Nevertheless, to go from recognizing the shape of the evaluation context to coming up with the optimized sum'', there remains a small gap. How do we logically finish the story from this point?

  1. We must compress such an evaluation context to a plain number.

  2. We must explicitly carry the compressed evaluation context, as an accumulator, to achieve tail recursion.

These two steps correspond exactly to defunctionalization and CPS transformation, just in the reverse order of what we detailed previously.
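The two steps can be sketched by giving evaluation contexts a concrete representation. The type Ctx and the functions plug, extend, compress and sumAcc are hypothetical names for illustration; Hole plays the role of the blank _ in the contexts listed above.

```haskell
-- Hypothetical representation of an evaluation context n1 + (n2 + (... + _)):
-- PlusL n c stands for n + c, and Hole stands for the blank _.
data Ctx = Hole | PlusL Int Ctx

-- Fill the blank of a context with a value.
plug :: Ctx -> Int -> Int
plug Hole        y = y
plug (PlusL n c) y = n + plug c y

-- One recursive call of sum puts x + _ into the blank of the current context.
extend :: Ctx -> Int -> Ctx
extend Hole        x = PlusL x Hole
extend (PlusL n c) x = PlusL n (extend c x)

-- Step 1: compress a context to one number; plug c y == compress c + y.
compress :: Ctx -> Int
compress Hole        = 0
compress (PlusL n c) = n + compress c

-- Step 2: carry the compressed context n as an explicit accumulator.
-- compress (extend c x) == compress c + x gives the recursive case, and
-- plugging sum [] = 0 into the context gives n + 0 = n in the base case.
sumAcc :: [Int] -> Int -> Int
sumAcc []       n = n
sumAcc (x : xs) n = sumAcc xs (n + x)
```

Here extend (extend (extend Hole 1) 2) 3 builds the context 1 + (2 + (3 + _)), which compresses to 6; sumAcc is sum'' under another name.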

It might be that spelling out these remaining steps is more trouble than necessary for such a trivial example, but it’s nice to know that we can, and that this technique now has a name.

Besides, “defunctionalize the evaluation context” doesn’t sound nearly as catchy as “defunctionalize the continuation”.


References


  1. “The rest of the world” may sound grand and daunting. The idea of continuations turns out to be fittingly grand and daunting.

  2. return is another cute way to name a continuation.