badd10de.dev

BDL: Part 4 (Lambdas and closures)


Introduction

A lot has changed since the last version: our interpreter can now call primitive functions, allowing us to perform computations such as arithmetic operations, checking the type of a given object, or creating new lists from previous ones. For example:

BDL REPL (Press Ctrl-D or Ctrl-C to exit)
bdl> (+ 1 2 (* 2 5) 4)
17

Neat huh? As of v0.4 the following primitive procedures are available. Note that procedures that return a boolean value are suffixed with the ? character. Similarly, procedures that can mutate values will have a ! appended to them:

Unless otherwise stated (Only for quote really), any given expression will be resolved before the parent one. For example, this expression (+ 1 (* 2 3) 4) can be traced as follows:

1. Find symbol + in current environment.
2. Call proc_add with the tail of the list (1 (* 2 3) 4).
    2.1 Eval 1 as itself.
    2.2 Eval (* 2 3):
        2.2.1 Find symbol * in current environment
        2.2.2 Call `proc_mul` with (2 3)
        2.2.3 Return the result of 2 * 3 -> 6
    2.3 Add 1 + 6 -> 7
    2.4 Eval 4 as itself
    2.5 Add 7 + 4 -> 11
3. Eval 11 as itself and return
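To make the trace above a bit more concrete, here is a minimal sketch in C of the same inner-first evaluation order. The Node type and the eval_node and eval_trace_example functions are hypothetical stand-ins invented for this illustration; they are not BDL's actual Object or eval.

```c
#include <assert.h>
#include <stddef.h>

// Hypothetical, simplified expression tree: either a number or an operator
// applied to a list of operands. This is not BDL's real Object type.
typedef enum { NODE_NUM, NODE_ADD, NODE_MUL } NodeType;

typedef struct Node {
    NodeType type;
    long value;                   // Used by NODE_NUM.
    const struct Node **operands; // Used by NODE_ADD and NODE_MUL.
    size_t n_operands;
} Node;

// Evaluate every operand first, then combine the results in the parent.
long
eval_node(const Node *node) {
    assert(node != NULL);
    if (node->type == NODE_NUM) {
        return node->value; // Numbers evaluate to themselves.
    }
    long acc = node->type == NODE_ADD ? 0 : 1;
    for (size_t i = 0; i < node->n_operands; i++) {
        long operand = eval_node(node->operands[i]);
        acc = node->type == NODE_ADD ? acc + operand : acc * operand;
    }
    return acc;
}

// Builds (+ 1 (* 2 3) 4) by hand and evaluates it.
long
eval_trace_example(void) {
    Node one = {NODE_NUM, 1, NULL, 0};
    Node two = {NODE_NUM, 2, NULL, 0};
    Node three = {NODE_NUM, 3, NULL, 0};
    Node four = {NODE_NUM, 4, NULL, 0};
    const Node *mul_operands[] = {&two, &three};
    Node mul = {NODE_MUL, 0, mul_operands, 2};
    const Node *add_operands[] = {&one, &mul, &four};
    Node add = {NODE_ADD, 0, add_operands, 3};
    return eval_node(&add);
}
```

Calling eval_trace_example() walks the tree in the same order as the numbered steps: the inner multiplication is resolved to 6 before the outer addition can finish with 11.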

This is great, but we still have no way of declaring our own variables or functions. This is what we will be addressing today, so without further ado, let’s get started.

Defining and setting variables

Let’s start by allowing the user to assign a symbol to a given object. We will call this procedure proc_define and assign it to the global scope with the symbol def. We want to be careful here, however, since we can have nested environments: def only looks at the current environment, never at its parents. If the symbol is already there we update its value, and if we can’t find it, we add a new symbol with the given value:

Object *
proc_define(Environment *env, Object *obj) {
    if (obj == obj_nil || obj->cdr == obj_nil) {
        error_push((Error){
            .type = ERR_TYPE_RUNTIME,
            .value = ERR_NOT_ENOUGH_ARGS,
        });
        return obj_err;
    }

    Object *symbol = obj->car;
    if (symbol->type != OBJ_TYPE_SYMBOL) {
        error_push((Error){
            .type = ERR_TYPE_RUNTIME,
            .value = ERR_WRONG_ARG_TYPE,
        });
        return obj_err;
    }

    Object *value = eval(env, obj->cdr->car);
    if (value == obj_err) {
        return obj_err;
    }

    env_add_or_update_current(env, symbol, value);
    return obj_nil;
}

We make use of the env_add_or_update_current function, since we may need to use it in other places of the code:

ssize_t
env_index_current(Environment *env, Object *symbol) {
    for (size_t i = 0; i < env->size; i++) {
        EnvEntry entry = env->buf[i];
        if (obj_eq(symbol, entry.symbol)) {
            return i;
        }
    }
    return -1;
}

void
env_add_or_update_current(Environment *env, Object *symbol, Object *value) {
    ssize_t index = env_index_current(env, symbol);
    if (index == -1) {
        env_add_symbol(env, obj_duplicate(symbol), obj_duplicate(value));
    } else {
        env->buf[index].value = obj_duplicate(value);
    }
}

Similarly, we add the set! procedure, which will try to update an already defined value in the current environment or, recursively, in any of its parents. If the symbol can’t be found anywhere, it reports an error instead of creating it:

Object *
proc_set(Environment *env, Object *obj) {
    if (obj == obj_nil || obj->cdr == obj_nil) {
        error_push((Error){
            .type = ERR_TYPE_RUNTIME,
            .value = ERR_NOT_ENOUGH_ARGS,
        });
        return obj_err;
    }

    Object *symbol = obj->car;
    if (symbol->type != OBJ_TYPE_SYMBOL) {
        error_push((Error){
            .type = ERR_TYPE_RUNTIME,
            .value = ERR_WRONG_ARG_TYPE,
        });
        return obj_err;
    }

    Object *value = eval(env, obj->cdr->car);
    if (value == obj_err) {
        return obj_err;
    }

    return env_update(env, symbol, value);
}

Here is the update function:

Object *
env_update(Environment *env, Object *symbol, Object *value) {
    while (env != NULL) {
        for (size_t i = 0; i < env->size; i++) {
            EnvEntry entry = env->buf[i];
            if (obj_eq(symbol, entry.symbol)) {
                env->buf[i].value = obj_duplicate(value);
                return obj_nil;
            }
        }
        env = env->parent;
    }
    error_push((Error){
        .type = ERR_TYPE_RUNTIME,
        .value = ERR_SYMBOL_NOT_FOUND,
    });
    return obj_err;
}
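The difference between def and set! is easier to see with nested scopes in play. The following is a rough sketch under simplifying assumptions: the Scope type, its fixed capacity and the scope_define, scope_set and scope_lookup helpers are all hypothetical, and values are plain ints rather than BDL objects.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>

#define SCOPE_CAP 8

// Hypothetical miniature environment: names map to plain ints.
typedef struct Scope {
    const char *names[SCOPE_CAP];
    int values[SCOPE_CAP];
    size_t size;
    struct Scope *parent;
} Scope;

// Like `def`: only looks at the current scope, adding the name if missing.
bool
scope_define(Scope *scope, const char *name, int value) {
    assert(scope != NULL && name != NULL);
    for (size_t i = 0; i < scope->size; i++) {
        if (strcmp(scope->names[i], name) == 0) {
            scope->values[i] = value;
            return true;
        }
    }
    if (scope->size == SCOPE_CAP) {
        return false; // The scope is full.
    }
    scope->names[scope->size] = name;
    scope->values[scope->size] = value;
    scope->size++;
    return true;
}

// Like `set!`: walks up the parent chain and fails if the name is unbound.
bool
scope_set(Scope *scope, const char *name, int value) {
    for (; scope != NULL; scope = scope->parent) {
        for (size_t i = 0; i < scope->size; i++) {
            if (strcmp(scope->names[i], name) == 0) {
                scope->values[i] = value;
                return true;
            }
        }
    }
    return false;
}

// Returns a pointer to the innermost binding, or NULL if the name is unbound.
const int *
scope_lookup(const Scope *scope, const char *name) {
    for (; scope != NULL; scope = scope->parent) {
        for (size_t i = 0; i < scope->size; i++) {
            if (strcmp(scope->names[i], name) == 0) {
                return &scope->values[i];
            }
        }
    }
    return NULL;
}
```

With a local scope whose parent is the global one, scope_define on the local scope shadows a global name without touching it, whereas scope_set reaches whichever binding is found first when walking up the chain.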

You might have noticed an obj_duplicate function. We currently need it because we free all objects that are not part of the current root expression, so we have to copy objects before storing them in an environment. We will hopefully get rid of this once we have a garbage collector, but we need it for now to keep things working:

Object *
obj_duplicate(Object *obj) {
    Object *copy = obj_err;
    switch (obj->type) {
        case OBJ_TYPE_BOOL:
        case OBJ_TYPE_NIL:
        case OBJ_TYPE_PROCEDURE:
        case OBJ_TYPE_LAMBDA: // TODO: should we duplicate everything inside?
        case OBJ_TYPE_ERR: {
            copy = obj;
        } break;
        case OBJ_TYPE_FIXNUM: {
            copy = make_fixnum(obj->fixnum);
        } break;
        case OBJ_TYPE_SYMBOL: {
            copy = make_symbol((StringView){obj->symbol, obj->symbol_n});
        } break;
        case OBJ_TYPE_STRING: {
            copy = make_string();
            append_string(copy, (StringView){obj->string, obj->string_n});
        } break;
        case OBJ_TYPE_PAIR: {
            Object *root = make_pair(obj_duplicate(obj->car), obj_nil);
            copy = root;
            obj = obj->cdr;
            while (obj != obj_nil) {
                root->cdr = make_pair(obj_duplicate(obj->car), obj_nil);
                root = root->cdr;
                obj = obj->cdr;
            }
        } break;
    }
    return copy;
}

We should be able to declare variables as follows, which should return the value 100:

(def a 20)
(def b 40)
(+ 40 a b)

Note that the value assigned to a variable will be evaluated before assignment, so you must use quotation if you want to store unevaluated code:

(def a (+ 1 2 3))
;; a == 6
(def b '(+ 1 2 3))
;; b == (+ 1 2 3)

This could come in handy for metaprogramming, but we also need a way of evaluating a series of operations stored in a variable. We can do this with the proc_eval primitive, which will evaluate a given tree in the current environment. In the previous example, (eval b) will return 6.

Object *
proc_eval(Environment *env, Object *obj) {
    if (obj == obj_nil) {
        error_push((Error){
            .type = ERR_TYPE_RUNTIME,
            .value = ERR_NOT_ENOUGH_ARGS,
        });
        return obj_err;
    }
    return eval(env, eval(env, obj->car));
}

We now have a way of assigning values to names, but how do we define functions? Well, user-defined functions will be called lambdas, and we can assign lambda functions to symbols as well. Here is a sneak peek of the syntax:

(def myfun (lambda (a b) (+ a b)))

This will assign the result of evaluating the lambda primitive procedure to the symbol named myfun. If we call (myfun 10 20) we will get a result of 30. Declaring functions is such a common operation that we define some syntactic sugar for it with the fun primitive procedure. The following is equivalent to the previous expression:

(fun myfun (a b) (+ a b))

Let’s now talk about how to implement lambdas in our language.

Lambdas and closures

Motivation

First of all, let’s describe the semantics of a lambda function. We start with the names for the variables we pass to the function, which are called “formal parameters”. This lambda function requires two formal parameters, a and b: (lambda (a b) (+ a b)). When calling a lambda function, the formal parameters are bound to the values given by the caller, which are called “arguments” or “actual parameters”. Following the previous example, we call the lambda like this: ((lambda (a b) (+ a b)) 40 50). Here, the parameter a will be bound to 40 and the parameter b will be bound to 50.

The “body” of a lambda expression is composed of an arbitrary number of expressions, and the lambda returns the value of the last evaluated expression. For example:

(lambda (a b)
    (def c 20)
    (def d 40)
    (+ a b c d))

The body of that lambda is composed of 3 expressions and the value returned by the lambda is thus the result of the sum of the two actual parameters and the values 20 and 40:

(def c 20)
(def d 40)
(+ a b c d)

So far this is pretty simple right? But what happens if we evaluate the following series of expressions?

(def a 10)
(fun myfun ()
    (display a)
    (print " --- ")
    (def a 42)
    (display a)
    (newline))
(myfun)
(myfun)

Think about it for a second. The answer actually depends on the context of our language. For example, we could get the following output:

10 --- 42
42 --- 42

This is perfectly fine, but what I actually want for our language is for lambdas to capture the current environment they are in. We want our functions to be closures. Conceptually, a lambda creates a new lexical scope, and so modifications to local variables shouldn’t affect the previous state. Thus, the output with this in mind should be:

10 --- 42
10 --- 42

Note that the captured environment, in our case, is a reference to a list of mappings, and so:

(myfun)
(def a 20)
(myfun)

Will result in:

10 --- 42
20 --- 42

Note that this behaviour might be different depending on whether the closure captures its environment by reference or by value. We will stick to the aforementioned reference capture for now, as I prefer the semantics.
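As a rough illustration of the two capture strategies, here is a small C sketch. The RefClosure and ValClosure types and their helper functions are hypothetical and only model a closure that reads a single variable a; BDL captures a whole environment rather than a single int.

```c
#include <assert.h>
#include <stddef.h>

// Reference capture: the closure points at the live binding.
typedef struct {
    const int *a;
} RefClosure;

// Value capture: the closure keeps a snapshot taken at creation time.
typedef struct {
    int a;
} ValClosure;

RefClosure
capture_by_reference(const int *a) {
    assert(a != NULL);
    return (RefClosure){.a = a};
}

ValClosure
capture_by_value(int a) {
    return (ValClosure){.a = a};
}

// "Calling" either closure just reads its captured variable.
int
call_ref(const RefClosure *closure) {
    return *closure->a;
}

int
call_val(const ValClosure *closure) {
    return closure->a;
}
```

If a is 10 when both closures are created and is then changed to 20, call_ref sees 20 while call_val still sees 10. The first behaviour is the one that matches the (def a 20) example above.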

Keep in mind that we can still modify the values of variables from within a closure. For example:

(def a 20)
(fun inc-a ()
    (set! a (+ a 1))
    a)
(inc-a) ; Returns 21
(inc-a) ; Returns 22
(inc-a) ; Returns 23

If we define an internal symbol, it will be used instead. We can still capture the previous reference if we want! Isn’t that weird and fun?

(def a 20)
(fun myfun ()
     (def a a)
     (set! a (+ a 1))
     a)
(myfun) ; Returns 21
(myfun) ; Returns 21
(def a 30)
(myfun) ; Returns 31

Closures allow us to be very expressive when programming, and can be a powerful construct. For example, look at this classic example:

(fun make-counter ()
    (def value 0)
    (fun counter ()
        (set! value (+ value 1))
        value)
    counter)
(def counter-a (make-counter))
(def counter-b (make-counter))

(counter-a) ;; -> 1
(counter-b) ;; -> 1
(counter-a) ;; -> 2
(counter-a) ;; -> 3
(counter-a) ;; -> 4
(counter-b) ;; -> 2
(counter-b) ;; -> 3
(counter-b) ;; -> 4

We have a function that creates a closure at every call. Each new closure has its own environment for the value symbol, which allows us to create as many counters as we want.
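For readers more used to C than to Lisp, the same idea can be approximated by hand: the state a closure captures becomes an explicit record that outlives the call that created it. This is only a sketch of the concept, and the Counter type and both functions are hypothetical names chosen to mirror the example above.

```c
#include <assert.h>
#include <stdlib.h>

// Hypothetical closure record: the captured `value` lives on the heap so it
// outlives the make_counter call that created it.
typedef struct Counter {
    int value;
} Counter;

// Mirrors (make-counter): every call returns a fresh, independent state.
// Returns NULL if the allocation fails. The caller owns the result and is
// responsible for calling free() on it.
Counter *
make_counter(void) {
    Counter *counter = malloc(sizeof *counter);
    if (counter != NULL) {
        counter->value = 0;
    }
    return counter;
}

// Mirrors calling the inner `counter` lambda: increment, then return.
int
counter_call(Counter *counter) {
    assert(counter != NULL);
    counter->value += 1;
    return counter->value;
}
```

Two counters created this way never interfere with each other, just like counter-a and counter-b. The difference is that in BDL the interpreter builds and tracks that hidden record for us.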

Implementation

Hopefully I managed to convince you that closures and lambdas are really cool, but how do we implement them in C? We already have nested environments, which is a good start. We can add a new object type OBJ_TYPE_LAMBDA that keeps track of the parameters and the body of the function. Note that we need to forward declare the Environment struct to get this to compile:

...
    // OBJ_TYPE_PROCEDURE
    struct Object *(*proc)(struct Environment *env, struct Object *args);

    // OBJ_TYPE_LAMBDA
    struct {
        struct Object *params;
        struct Object *body;
        struct Environment *env;
    };
...

We now can create two primitive procedures, one for the lambda function and another for the syntactic sugar for a named function:

Object *
proc_lambda(Environment *env, Object *obj) {
    if (obj == obj_nil || obj->cdr == obj_nil) {
        error_push((Error){
            .type = ERR_TYPE_RUNTIME,
            .value = ERR_NOT_ENOUGH_ARGS,
        });
        return obj_err;
    }
    Object *params = obj->car;
    if (params != obj_nil && params->type != OBJ_TYPE_PAIR) {
        error_push((Error){
            .type = ERR_TYPE_RUNTIME,
            .value = ERR_WRONG_ARG_TYPE,
        });
        return obj_err;
    }
    Object *body = obj->cdr;
    Object *fun = alloc_object(OBJ_TYPE_LAMBDA);
    fun->params = obj_duplicate(params);
    fun->body = obj_duplicate(body);
    fun->env = env;
    return fun;
}

Object *
proc_fun(Environment *env, Object *obj) {
    if (obj == obj_nil || obj->cdr == obj_nil || obj->cdr->cdr == obj_nil) {
        error_push((Error){
            .type = ERR_TYPE_RUNTIME,
            .value = ERR_NOT_ENOUGH_ARGS,
        });
        return obj_err;
    }

    Object *name = obj->car;
    if (name->type != OBJ_TYPE_SYMBOL) {
        error_push((Error){
            .type = ERR_TYPE_RUNTIME,
            .value = ERR_WRONG_ARG_TYPE,
        });
        return obj_err;
    }

    Object *params = obj->cdr->car;
    if (params != obj_nil && params->type != OBJ_TYPE_PAIR) {
        error_push((Error){
            .type = ERR_TYPE_RUNTIME,
            .value = ERR_WRONG_ARG_TYPE,
        });
        return obj_err;
    }
    Object *body = obj->cdr->cdr;
    Object *fun = alloc_object(OBJ_TYPE_LAMBDA);
    fun->params = obj_duplicate(params);
    fun->body = obj_duplicate(body);
    fun->env = env;
    env_add_or_update_current(env, name, fun);
    return obj_nil;
}

Now we can turn our attention to lambda evaluation. We have two possible cases: named and anonymous functions. For a function to be called, it needs to be the first element of a list. If that first element is a symbol, we look it up in the current environment to hopefully obtain a primitive procedure or a lambda. If we can’t find the symbol, or its value is neither of those two, we raise a “not callable” error. The first element could also be a list that evaluates to a lambda. With that in mind, we can update our evaluation of OBJ_TYPE_PAIR as follows:

...
                 if (val->type == OBJ_TYPE_PROCEDURE) {
                     return val->proc(env, root->cdr);
                 }
                 if (val->type == OBJ_TYPE_LAMBDA) {
                     goto eval_lambda;
                 }
                 error_push((Error){
                     .type = ERR_TYPE_RUNTIME,
                     .value = ERR_OBJ_NOT_CALLABLE,
                 });
                 return obj_err;
             }
             Object* lambda;
 eval_lambda:
             lambda = eval(env, root->car);
             if (lambda == obj_err) {
                 return obj_err;
             }
             if (lambda->type == OBJ_TYPE_LAMBDA) {
                 Object *fun = lambda;
                 Object *args = root->cdr;
                 Object *params = fun->params;
                 env = env_extend(fun->env, env);
                 while (params != obj_nil) {
                     if (args == obj_nil) {
                         error_push((Error){
                             .type = ERR_TYPE_RUNTIME,
                             .value = ERR_NOT_ENOUGH_ARGS,
                         });
                         return obj_err;
                     }
                     Object *symbol = params->car;
                     Object *value = eval(env, args->car);
                     if (value == obj_err) {
                         return obj_err;
                     }
                     if (value == obj_nil) {
                         error_push((Error){
                             .type = ERR_TYPE_RUNTIME,
                             .value = ERR_NOT_ENOUGH_ARGS,
                         });
                         return obj_err;
                     }
                     env_add_or_update_current(env, symbol, value);
                     args = args->cdr;
                     params = params->cdr;
                 }
                 if (args != obj_nil) {
                     error_push((Error){
                         .type = ERR_TYPE_RUNTIME,
                         .value = ERR_TOO_MANY_ARGS,
                     });
                     return obj_err;
                 }
                 root = fun->body;
                 while (root->cdr != obj_nil) {
                     if (eval(env, root->car) == obj_err) {
                         return obj_err;
                     };
                     root = root->cdr;
                 }
                 return eval(env, root->car);
             }

Oof that is a lot! First things first, yes we are using a GOTO as a way of deduplicating the code, and this has another purpose. We didn’t think much about this yet, but we should try to ensure that our function calls perform tail-call optimizations (TCO). Not that this is necessarily the case now, but in the future we will test these assumptions and ensure this is the case, and we will probably make use of GOTOs for that as well. In any case, let’s break that down:

        if (val->type == OBJ_TYPE_PROCEDURE) {
            return val->proc(env, root->cdr);
        }
        if (val->type == OBJ_TYPE_LAMBDA) {
            goto eval_lambda;
        }

This just checks whether the value associated with a symbol is a lambda function and, if so, jumps to the lambda evaluation. Note that there is a small bug in this code since:

             Object* lambda;
 eval_lambda:
             lambda = eval(env, root->car);
             if (lambda == obj_err) {
                 return obj_err;
             }

Will perform the evaluation again, meaning that proc_lambda will be called twice when we run a lambda function from within a symbol. This will be addressed in the next update, but let’s keep going for now:

                 Object *fun = lambda;
                 Object *args = root->cdr;
                 Object *params = fun->params;
                 env = env_extend(fun->env, env);

This is the most critical part for creating closures. We “extend” the environment that was captured by our lambda with the symbols from the current environment. What is actually happening is that we instantiate a new environment that has the lambda environment as a parent. Into this new environment, we copy every symbol from the calling environment that is not already defined in the lambda environment or any of its parents.

Environment *
env_extend(Environment *parent, Environment *extra) {
    Environment *env = env_create(parent);
    for (size_t i = 0; i < extra->size; i++) {
        EnvEntry entry = extra->buf[i];
        Environment *tmp = env;
        bool found = false;
        while (tmp != NULL) {
            if (env_index_current(tmp, entry.symbol) != -1) {
                found = true;
                break;
            }
            tmp = tmp->parent;
        }
        if (!found) {
            env_add_symbol(env, obj_duplicate(entry.symbol), obj_duplicate(entry.value));
        }
    }
    return env;
}

There is probably a better way of doing this process, but this seems to be doing the trick for now. Just bear in mind that recursive calls to lambda will keep creating new environments, so this is something we want to address when we work on TCO.
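To check the intuition behind env_extend, here is a stripped-down sketch of the same idea with integer values and fixed-size frames. The Frame type and the frame_chain_has, frame_extend and frame_lookup helpers are hypothetical; unlike the real function, this version returns the new frame by value, duplicates nothing and never allocates.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <string.h>

#define FRAME_CAP 8

// Hypothetical miniature environment frame mapping names to ints.
typedef struct Frame {
    const char *names[FRAME_CAP];
    int values[FRAME_CAP];
    size_t size;
    const struct Frame *parent;
} Frame;

// True if `name` is bound in `frame` or in any of its ancestors.
bool
frame_chain_has(const Frame *frame, const char *name) {
    for (; frame != NULL; frame = frame->parent) {
        for (size_t i = 0; i < frame->size; i++) {
            if (strcmp(frame->names[i], name) == 0) {
                return true;
            }
        }
    }
    return false;
}

// Returns a fresh frame whose parent is `captured`, filled with the entries
// of `caller` that the captured chain does not already define.
Frame
frame_extend(const Frame *captured, const Frame *caller) {
    assert(caller != NULL);
    Frame frame = {0};
    frame.parent = captured;
    for (size_t i = 0; i < caller->size && frame.size < FRAME_CAP; i++) {
        if (frame_chain_has(&frame, caller->names[i])) {
            continue; // The closure's own view of this name wins.
        }
        frame.names[frame.size] = caller->names[i];
        frame.values[frame.size] = caller->values[i];
        frame.size++;
    }
    return frame;
}

// Returns a pointer to the innermost binding, or NULL if the name is unbound.
const int *
frame_lookup(const Frame *frame, const char *name) {
    for (; frame != NULL; frame = frame->parent) {
        for (size_t i = 0; i < frame->size; i++) {
            if (strcmp(frame->names[i], name) == 0) {
                return &frame->values[i];
            }
        }
    }
    return NULL;
}
```

If the captured frame binds a to 10 and the caller binds a to 99 and x to 5, the extended frame only copies x. A lookup of a still resolves to the captured 10, which is exactly the shadowing we want from a closure.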

The next part binds the arguments to the formal parameters and checks if we have too many or not enough arguments.

        while (params != obj_nil) {
            if (args == obj_nil) {
                error_push((Error){
                    .type = ERR_TYPE_RUNTIME,
                    .value = ERR_NOT_ENOUGH_ARGS,
                });
                return obj_err;
            }
            Object *symbol = params->car;
            Object *value = eval(env, args->car);
            if (value == obj_err) {
                return obj_err;
            }
            if (value == obj_nil) {
                error_push((Error){
                    .type = ERR_TYPE_RUNTIME,
                    .value = ERR_NOT_ENOUGH_ARGS,
                });
                return obj_err;
            }
            env_add_or_update_current(env, symbol, value);
            args = args->cdr;
            params = params->cdr;
        }
        if (args != obj_nil) {
            error_push((Error){
                .type = ERR_TYPE_RUNTIME,
                .value = ERR_TOO_MANY_ARGS,
            });
            return obj_err;
        }
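The arity check in that loop can be isolated into a tiny standalone sketch. Here the parameter and argument lists are plain arrays instead of linked pairs, arguments are already-evaluated ints, and the BindResult, Binding and bind_params names are hypothetical.

```c
#include <assert.h>
#include <stddef.h>

typedef enum {
    BIND_OK,
    BIND_NOT_ENOUGH_ARGS,
    BIND_TOO_MANY_ARGS,
} BindResult;

// Hypothetical binding of a formal parameter to an actual argument.
typedef struct {
    const char *name;
    int value;
} Binding;

// Pairs each formal parameter with the argument in the same position.
// `out` must have room for at least `n_params` bindings.
BindResult
bind_params(const char **params, size_t n_params,
            const int *args, size_t n_args, Binding *out) {
    assert(out != NULL || n_params == 0);
    size_t i = 0;
    for (; i < n_params; i++) {
        if (i >= n_args) {
            return BIND_NOT_ENOUGH_ARGS; // Ran out of arguments first.
        }
        out[i] = (Binding){.name = params[i], .value = args[i]};
    }
    if (i < n_args) {
        return BIND_TOO_MANY_ARGS; // Arguments left over after binding.
    }
    return BIND_OK;
}
```

As in the interpreter code, running out of arguments is detected while walking the parameters, and leftover arguments are detected once the walk is over.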

Finally we switch the execution to the body of the lambda and return the result of evaluating the last expression.

        root = fun->body;
        while (root->cdr != obj_nil) {
            if (eval(env, root->car) == obj_err) {
                return obj_err;
            };
            root = root->cdr;
        }
        return eval(env, root->car);
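The shape of that last loop (evaluate everything but the last expression for its side effects, bail out on the first error, and return the value of the last expression) can be sketched on its own. In this hypothetical version expressions are modelled as C function pointers acting on a single local int, and EVAL_ERR stands in for obj_err.

```c
#include <assert.h>
#include <stddef.h>

#define EVAL_ERR (-1)

// Hypothetical stand-in for a body expression: it may read or write the
// local state and returns its value, or EVAL_ERR on failure.
typedef int (*Expr)(int *local);

// Runs all but the last expression for their side effects, stopping at the
// first error, and returns the value of the last expression.
int
eval_body(const Expr *body, size_t n, int *local) {
    assert(body != NULL && n > 0);
    for (size_t i = 0; i + 1 < n; i++) {
        if (body[i](local) == EVAL_ERR) {
            return EVAL_ERR;
        }
    }
    return body[n - 1](local);
}

// Example expressions, loosely mirroring (def c 20) and (+ c 1).
int
expr_def_c(int *local) {
    *local = 20;
    return 0;
}

int
expr_c_plus_one(int *local) {
    return *local + 1;
}

// An expression that always fails.
int
expr_fail(int *local) {
    (void)local;
    return EVAL_ERR;
}
```

Running the body {expr_def_c, expr_c_plus_one} yields 21, while a body that starts with expr_fail returns EVAL_ERR without evaluating anything after it.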

Conclusion

That is it for today! I’ve spent quite a lot of time trying to make closures work as I wanted, but I’m happy with how things turned out. As usual, check v0.5 for the current version of the code. I’ve also added a bunch of new tests and expected results, so running make tests should complete successfully.

Our language is now Turing complete and it could start being useful. One big problem, however, is that we are generating a lot of memory leaks, so the next area to focus on is building a garbage collector. See you soon!