Add recursive self-calls, the (function ...) special form and the #' reader macro
diff --git a/src/lisp/compiler.dasc b/src/lisp/compiler.dasc
index 8d96a52..9a64801 100644
--- a/src/lisp/compiler.dasc
+++ b/src/lisp/compiler.dasc
@@ -16,6 +16,7 @@
|.arch x86;
|.macro setup, nvars;
+|->function_start:
| push ebp;
| mov ebp, esp;
| sub esp, (value_size * nvars);
@@ -88,7 +89,7 @@
struct dasm_State *compile_function(value_t args, enum namespace namespace,
struct environment *env, struct local *local_out,
- struct local *local_parent, int *nargs)
+ struct local *local_parent, int *nargs, char *name)
{
dasm_State *d;
dasm_State **Dst = &d;
@@ -115,15 +116,15 @@
local.num_stack_entries = 0;
local.num_closure_slots = 0;
local.parent = local_parent;
+ local.current_function_name = name;
dasm_growpc(&d, local.npc);
- // Generate code
- // TODO: first pass, extract bound and free variables
-
value_t arglist = car(args);
value_t body = cdr(args);
+ local.num_args = length(arglist);
+
value_t a = arglist;
for (int i = 0; !nilp(a); a = cdr(a), i++)
{
@@ -181,9 +182,10 @@
struct local local;
int nargs;
- dasm_State *d = compile_function(cdr(args), namespace, env, &local, NULL, &nargs);
+ char *name = (char *)(car(args) ^ SYMBOL_TAG);
+ dasm_State *d = compile_function(cdr(args), namespace, env, &local, NULL, &nargs, name);
- add_function(env, (char *)(car(args) ^ SYMBOL_TAG), link(&d),
+ add_function(env, name, link(&d),
nargs, namespace);
dasm_free(&d);
@@ -394,6 +396,34 @@
compile_backquote(env, local, car(args), Dst);
}
+ else if (symstreq(fsym, "function"))
+ {
+ if (nargs != 1)
+ {
+ err("function should take exactly 1 argument");
+ }
+
+ if (!symbolp(car(args)))
+ {
+ err("argument to function should be a symbol resolvable at compile time");
+ }
+
+ char *fname = (char *)(car(args) ^ SYMBOL_TAG);
+ struct function *f = find_function(env, fname);
+
+ if (f == NULL)
+ {
+ fprintf(stderr, "Function reference: %s at %s:%d\n", fname, cons_file(val), cons_line(val));
+ err("argument to function is not a defined function");
+ }
+
+ // NOTE(review): the closure is created once, at compile time, and is
+ // only referenced from the generated code; confirm the GC cannot
+ // collect it while this code is still reachable.
+ value_t closure = create_closure(f->code_ptr, f->nargs, 0);
+
+ | mov eax, (closure);
+ }
else if (symstreq(fsym, "list"))
{
| push (nil);
@@ -420,17 +450,17 @@
// Compile the function with this as the parent scope
struct local new_local;
int nargs_out;
- dasm_State *d = compile_function(args, NS_ANONYMOUS, env, &new_local, local, &nargs_out);
+ dasm_State *d = compile_function(args, NS_ANONYMOUS, env, &new_local, local, &nargs_out, "recurse");
// Link the function
void *func_ptr = link(&d);
// Create a closure object with the correct number of captures at
// runtime
- | mov ebx, (create_closure);
| push (new_local.num_closure_slots);
| push (nargs_out);
| push (func_ptr);
+ | mov ebx, (create_closure);
| call ebx;
| add esp, 12;
@@ -448,9 +478,9 @@
compile_variable(find_variable(local, var->name), Dst);
| push eax;
- | mov ebx, (set_closure_capture_variable);
// The capture offset
| push (var->number);
+ | mov ebx, (set_closure_capture_variable);
| call ebx;
// Skip the value and index
| add esp, 8;
@@ -466,28 +496,55 @@
}
else
{
- struct function *func =
- find_function(env, (char *)(fsym ^ SYMBOL_TAG));
+ char *name = (char *)(fsym ^ SYMBOL_TAG);
+ struct function *func = find_function(env, name);
+
+ bool is_recursive = false;
+ int nargs_needed = 0;
- if (func == NULL)
- err("Function undefined");
-
- if (nargs != func->nargs)
+ // A named function is only added to env after its body is compiled,
+ // and lambdas are linked without being added to env, so calls to the
+ // function currently being compiled are recognised by name here.
+ if (local->current_function_name != NULL && symstreq(fsym, local->current_function_name))
{
- fprintf(stderr, "Function: %s at %s:%d\n", func->name, cons_file(val), cons_line(val));
+ is_recursive = true;
+ nargs_needed = local->num_args;
+ }
+ else
+ {
+ if (func == NULL)
+ {
+ fprintf(stderr, "Function call: %s at %s:%d\n", name, cons_file(val), cons_line(val));
+ err("Function undefined");
+ }
+
+ nargs_needed = func->nargs;
+ }
+
+ if (nargs != nargs_needed)
+ {
+ fprintf(stderr, "Function call: %s at %s:%d, want %d args but given %d\n",
+ name, cons_file(val), cons_line(val), nargs_needed, nargs);
err("wrong number of args");
}
- if (func->namespace == NS_FUNCTION)
+ if (is_recursive || func->namespace == NS_FUNCTION)
{
for (int i = length(args) - 1; i >= 0; i--)
{
compile_expression(env, local, elt(args, i), Dst);
| push eax;
}
-
- | mov ebx, (func->code_addr);
- | call ebx;
+
+ if (is_recursive)
+ {
+ | call ->function_start;
+ }
+ else
+ {
+ | mov ebx, (func->code_addr);
+ | call ebx;
+ }
| add esp, (nargs * value_size);
// result in eax
}
diff --git a/src/lisp/compiler.h b/src/lisp/compiler.h
index 07928d2..16cc35e 100644
--- a/src/lisp/compiler.h
+++ b/src/lisp/compiler.h
@@ -61,7 +61,14 @@
/// Parent environment, NULL if none (root).
struct local *parent;
- int num_vars;
+ /// Name by which the function being compiled can call itself: the
+ /// definition's name for named functions, `recurse` for a lambda.
+ char *current_function_name;
+
+ int num_vars;
+ /// Number of arguments the function being compiled takes; recursive
+ /// calls through `current_function_name` are arity-checked against it.
+ int num_args;
/// Most recently defined variable
struct variable *first;
int npc;
@@ -95,7 +102,7 @@
*/
struct dasm_State *compile_function(value_t args, enum namespace namespace,
struct environment *env, struct local *local_out,
- struct local *local_parent, int *nargs);
+ struct local *local_parent, int *nargs, char *name);
void compile_variable(struct variable *v, dasm_State *Dst);
diff --git a/src/lisp/lisp.c b/src/lisp/lisp.c
index b619033..c606e29 100644
--- a/src/lisp/lisp.c
+++ b/src/lisp/lisp.c
@@ -294,16 +294,28 @@
char c = is->peek(is);
- if (c == '\'' || c == '`' || c == ',')
+ if (c == '\'' || c == '`' || c == ',' || c == '#')
{
is->get(is);
- if (c == '`' && is->peek(is) == '@')
+ if (c == ',' && is->peek(is) == '@')
{
// This is actually a splice
is->get(is);
c = '@';
}
+ else if (c == '#')
+ {
+ // Only #' (function quote) is supported. Any other character after
+ // # is an error rather than being silently read as #'.
+ if (is->peek(is) != '\'')
+ {
+ is->showpos(is, stderr);
+ err("Unsupported reader macro: expected ' after #");
+ }
+
+ is->get(is);
+ }
// Read the next form and wrap it in the appropriate function
@@ -334,6 +346,12 @@
case '@':
symbol = symval("unquote-splice");
break;
+ case '#':
+ symbol = symval("function");
+ break;
+ default:
+ is->showpos(is, stderr);
+ err("Something went wrong parsing a reader macro");
}
*val = cons(symbol, cons(wrapped, nil));
diff --git a/src/lisp/test-closures.lisp b/src/lisp/test-closures.lisp
index 1311a98..573024d 100644
--- a/src/lisp/test-closures.lisp
+++ b/src/lisp/test-closures.lisp
@@ -1,5 +1,14 @@
+;; mapcar calls itself by name (direct recursive calls), and main passes
+;; #'double to it, which is then invoked through apply.
+(defun mapcar (func list)
+ (if list
+ (cons (apply func (list (car list)))
+ (mapcar func (cdr list)))
+ nil))
+
+(defun double (n)
+ (+ n n))
+
(defun main ()
- (let1 (number 3)
- (let1 (adds-3 (lambda (n)
- (+ n number)))
- (print (apply adds-3 '(4))))))
+ (print (mapcar #'double
+ (list 1 2 3 4 5))))