/* -*- mode:c -*- */

#include "compiler.h"
#include "lib/std.h"
#include "plat/plat.h"

#include <dasm_proto.h>
#include <dasm_x86.h>

#include <stdlib.h>
#include <string.h>

// Size in bytes of one Lisp value; used to scale stack slot and argument
// offsets in the generated code.
#define value_size sizeof(value_t)

// Generated code targets 32-bit x86.
|.arch x86;

// Function prologue: build an ebp-based frame and reserve nvars value-sized
// local slots below ebp (compile_expression stores slot i at
// [ebp - (i + 1) * value_size]).
|.macro setup, nvars;
| push ebp;
| mov ebp, esp;
| sub esp, (value_size * nvars);
|.endmacro;

// Function epilogue: drop the frame and return.  eax is not touched, so it
// carries the function's result back to the caller.
|.macro cleanup;
| mov esp, ebp;
| pop ebp;
| ret;
|.endmacro;

// Placeholder: currently expands to nothing.
|.macro local_var, index;
|.endmacro;

// NOTE(review): neither global is referenced by the code shown here:
// compile_tl declares its own local `d`, and per-function pc label counts
// live in struct local.npc.  Confirm they are still needed elsewhere.
dasm_State *d;
unsigned int npc = 8;

// Garbage collector entry point, defined in another translation unit.
// NOTE(review): the run_gc macro pushes ebp first and the saved esp second.
// Under cdecl the first parameter therefore receives esp and the second ebp,
// which is the opposite of the parameter names declared here.  Check this
// against _do_gc's definition.
extern void _do_gc(unsigned int ebp, unsigned int esp);

// Emit a call to the garbage collector.
//
// The stack pointer is captured in eax before anything is pushed.  cdecl
// passes arguments right to left, so the value pushed last (that saved esp)
// is _do_gc's first argument and ebp is its second.
// NOTE(review): the extern declaration names its parameters (ebp, esp);
// confirm the intended order against _do_gc's definition.
//
// Under cdecl the caller removes its own arguments.  The two 4-byte words are
// therefore popped after the call.  Leaving them behind would shift every
// later stack-relative access, e.g. arguments already pushed for an enclosing
// call when (gc) appears as one of them.
//
// Clobbers eax; _do_gc may also clobber the other caller-saved registers.
|.macro run_gc;
| mov eax, esp;
| push ebp;
| push eax;
| mov eax, _do_gc;
| call eax;
| add esp, 8;
|.endmacro;

/*
 * Look up a function called NAME in ENV, walking from env->first along the
 * prev links and returning the first match.
 *
 * Returns NULL when no function has that name, including when ENV contains
 * no functions at all (env->first == NULL).
 */
struct function *find_function(struct environment *env, char *name)
{
	for (struct function *f = env->first; f != NULL; f = f->prev)
	{
		if (strcmp(f->name, name) == 0)
			return f;
	}

	return NULL;
}

/*
 * Reserve the lowest-numbered free stack slot in LOCAL and return its index.
 *
 * When every slot is taken, the slot map grows by four entries and the first
 * new slot is returned.  local->num_stack_entries is kept as the high-water
 * mark (highest slot index ever handed out + 1).  compile_tl uses it to size
 * the stack frame, so it must also be updated when the map grows.
 */
unsigned int local_alloc(struct local *local)
{
	for (int i = 0; i < local->num_stack_slots; i++)
	{
		if (!local->stack_slots[i])
		{
			local->stack_slots[i] = true;

			if (i >= local->num_stack_entries)
				local->num_stack_entries = i + 1;

			return i;
		}
	}

	// Every slot is in use: grow the map, keeping the old one valid if
	// realloc fails.
	int old_size = local->num_stack_slots;
	int new_size = old_size + 4;
	bool *slots = realloc(local->stack_slots, new_size * sizeof(bool));

	if (slots == NULL)
	{
		err("Out of memory while allocating stack slots");
		// Never continue with a NULL slot map, even if err() returns.
		abort();
	}

	local->stack_slots = slots;
	local->num_stack_slots = new_size;

	// Mark the freshly added slots as unused, then claim the first of them.
	memset(slots + old_size, 0, (new_size - old_size) * sizeof(bool));
	slots[old_size] = true;

	if (old_size >= local->num_stack_entries)
		local->num_stack_entries = old_size + 1;

	return old_size;
}

/* Release stack slot SLOT in LOCAL so that local_alloc may hand it out again. */
void local_free(struct local *local, unsigned int slot)
{
	bool *slot_map = local->stack_slots;
	slot_map[slot] = false;
}

// destroy_local is defined later in this file.  This prototype (redundant if
// compiler.h already declares it) makes it callable from compile_tl.
void destroy_local(struct local *local);

/*
 * Compile one top-level form VAL into ENV.
 *
 * Only (defun NAME (ARGS...) BODY...) is handled.  The body is assembled
 * with DynASM into a new function, linked with link(), and registered in ENV
 * with add_function().  Other list forms are currently ignored; a non-list
 * is rejected with err().
 */
void compile_tl(value_t val, struct environment *env)
{
	if (!listp(val))
		err("Top level must be a list");

	value_t form = car(val);
	value_t args = cdr(val);

	if (symstreq(form, "defun"))
	{
		dasm_State *d;
		dasm_State **Dst = &d;

		|.section code;
		dasm_init(&d, DASM_MAXSECTION);

		|.globals lbl_;
		void *labels[lbl__MAX];
		dasm_setupglobal(&d, labels, lbl__MAX);

		|.actionlist lisp_actions;
		dasm_setup(&d, lisp_actions);

		// Per-function compilation state: bound variables, dynamic pc
		// labels and the stack slot map.
		struct local local;
		local.first = NULL;
		local.num_vars = 0;
		local.npc = 8;
		local.nextpc = 0;
		local.stack_slots = malloc(sizeof(bool) * 4);
		memset(local.stack_slots, 0, sizeof(bool) * 4);
		local.num_stack_slots = 4;
		local.num_stack_entries = 0;

		dasm_growpc(&d, local.npc);

		// Generate code
		// TODO: first pass, extract bound and free variables

		value_t name = car(args);
		args = cdr(args);
		value_t arglist = car(args);
		value_t body = cdr(args);

		if ((name & HEAP_MASK) != SYMBOL_TAG)
			err("function name must be a symbol");

		// Arguments are numbered left to right.  compile_expression reads
		// argument i from [ebp + (i + 2) * value_size].
		value_t a = arglist;
		for (int i = 0; !nilp(a); a = cdr(a), i++)
		{
			if (!symbolp(car(a)))
			{
				err("defun argument must be a symbol");
			}

			add_variable(&local, V_ARGUMENT, (char *)(car(a) ^ SYMBOL_TAG), i);
		}

		// First pass: find how many stack slots the let1 bindings need so
		// that the prologue can reserve them.
		for (value_t body_ = body; !nilp(body_); body_ = cdr(body_))
		{
			walk_and_alloc(&local, car(body_));
		}

		| setup (local.num_stack_entries);

		// Code generation allocates the same slots again, so start it from an
		// empty slot map.
		memset(local.stack_slots, 0, local.num_stack_slots * sizeof(bool));
		local.num_stack_entries = 0;

		for (; !nilp(body); body = cdr(body))
		{
			compile_expression(env, &local, car(body), Dst);
		}

		| cleanup;

		add_function(env, (char *)(name ^ SYMBOL_TAG), link(Dst),
		             length(arglist));

		dasm_free(&d);
		free(local.stack_slots);
		// Release the variable records created by add_variable above.
		destroy_local(&local);
	}
}

/*
 * Pre-pass over one expression BODY.  It allocates in LOCAL every stack slot
 * the expression will need when compiled, so that local->num_stack_entries
 * ends up as the number of value slots the function prologue must reserve.
 *
 * Slots are taken and released in the same order compile_expression uses
 * for let1:
 *   1. the bound value is walked first, with no new slot held;
 *   2. the binding's slot is then held while the body forms are walked.
 * Nested let1 forms therefore get distinct slots and are counted.
 */
void walk_and_alloc(struct local *local, value_t body)
{
	if (!listp(body))
		return;

	value_t args = cdr(body);

	if (symstreq(car(body), "let1"))
	{
		// (let1 (name value) body...): the value is compiled before the
		// new slot is allocated.
		if (!nilp(args))
		{
			value_t binding = car(args);

			if (listp(binding) && length(binding) == 2)
				walk_and_alloc(local, car(cdr(binding)));
		}

		int slot = local_alloc(local);

		// The body runs while the binding's slot is still occupied.
		for (value_t rest = cdr(args); !nilp(rest); rest = cdr(rest))
		{
			walk_and_alloc(local, car(rest));
		}

		local_free(local, slot);
	}
	else
	{
		for (; !nilp(args); args = cdr(args))
		{
			walk_and_alloc(local, car(args));
		}
	}
}

/*
 * Compile every top-level form that read1 yields from IS.  Forms are compiled
 * into a fresh environment, which load_std populates first.  The resulting
 * environment is returned by value.
 */
struct environment compile_all(struct istream *is)
{
	struct environment env;
	env.first = NULL;
	load_std(&env);

	for (value_t form; read1(is, &form);)
		compile_tl(form, &env);

	return env;
}

/*
 * Return a fresh DynASM dynamic (pc) label number for the current function.
 *
 * The pc label table holds local->npc entries, i.e. valid labels are
 * 0 .. npc-1.  Label number npc itself is out of range, so the table is
 * grown (by 16) as soon as n reaches npc.
 */
int nextpc(struct local *local, dasm_State **Dst)
{
	int n = local->nextpc++;

	if (n >= local->npc)
	{
		local->npc += 16;
		dasm_growpc(Dst, local->npc);
	}

	return n;
}

/*
 * Emit code that evaluates VAL and leaves its value in eax.
 *
 * env   - environment used to resolve called functions.
 * local - per-function state: bound variables, stack slot map, pc labels.
 * val   - expression to compile.
 * Dst   - DynASM state receiving the instructions.
 *
 * Supported forms:
 *   - nil, t, integer and string literals;
 *   - (if COND THEN [ELSE]);
 *   - (let1 (NAME VALUE) BODY...);
 *   - (gc);
 *   - calls to functions found in ENV;
 *   - references to defun arguments and let1-bound names.
 * Any other kind of value emits no code.
 *
 * Apart from esp/ebp, the code emitted here writes only eax directly.  In
 * particular it does not use ebx, which cdecl callees must preserve.
 */
void compile_expression(struct environment *env, struct local *local,
                        value_t val, dasm_State **Dst)
{
	if (symstreq(val, "nil"))
	{
		| mov eax, (nil);
	}
	else if (symstreq(val, "t"))
	{
		| mov eax, (t);
	}
	else if (integerp(val) || stringp(val))
	{
		// Literals are embedded directly as immediates.
		| mov eax, val;
	}
	else if (listp(val))
	{
		value_t fsym = car(val);
		value_t args = cdr(val);
		int nargs = length(args);

		if (!symbolp(fsym))
		{
			err("function name must be a symbol");
		}

		if (symstreq(fsym, "if"))
		{
			if (nargs < 2 || nargs > 3)
				err("Must give 2 or 3 arguments to if");

			compile_expression(env, local, car(args), Dst);
			int false_label = nextpc(local, Dst),
			    after_label = nextpc(local, Dst);

			// The condition's value is in eax; anything but nil is true.
			| cmp eax, (nil);
			| je =>false_label;

			compile_expression(env, local, elt(args, 1), Dst);
			| jmp =>after_label;
			|=>false_label:;
			// Without an else branch eax still holds nil here, so the whole
			// if evaluates to nil.
			if (nargs == 3)
			    compile_expression(env, local, elt(args, 2), Dst);
			|=>after_label:
		}
		else if (symstreq(fsym, "let1"))
		{
			if (nargs < 2)
			{
				err("Must give at least 2 arguments to let1");
			}
			value_t binding = car(args);
			value_t rest = cdr(args);

			if (length(binding) != 2)
			{
				err("Binding list in let1 must contain exactly two entries");
			}

			value_t name = car(binding);
			value_t value = car(cdr(binding));

			// The name is used as a symbol pointer below; reject anything
			// else instead of producing a bogus string pointer.
			if (!symbolp(name))
			{
				err("let1 variable name must be a symbol");
			}

			// The value is evaluated before the new name comes into scope.
			compile_expression(env, local, value, Dst);

			int i = local_alloc(local);

			// Remember the enclosing scope so the binding can be removed once
			// the body has been compiled.
			struct variable *outer = local->first;
			struct variable *bound =
			    add_variable(local, V_BOUND, (char *)(name ^ SYMBOL_TAG), i);

			| mov dword [ebp - ((i + 1) * value_size)], eax;

			for (; !nilp(rest); rest = cdr(rest))
			{
				compile_expression(env, local, car(rest), Dst);
			}

			// End of scope.  Nested let1 forms have already popped their own
			// bindings, so `bound` is on top again.  Unbind it before its slot
			// is released; otherwise later code could still resolve the name
			// to a slot that gets reused.
			local->first = outer;
			free(bound);
			local_free(local, i);
		}
		else if (symstreq(fsym, "gc"))
		{
			if (nargs)
			{
				err("gc takes no arguments");
			}

			| run_gc;
			// _do_gc returns nothing, so give (gc) a well-defined value.
			| mov eax, (nil);
		}
		else
		{
			struct function *func =
			    find_function(env, (char *)(fsym ^ SYMBOL_TAG));

			if (func == NULL)
				err("Function undefined");

			if (nargs != func->nargs)
				err("wrong number of args");

			// cdecl: push arguments right to left so the first argument ends
			// up at the lowest address.
			for (int i = nargs - 1; i >= 0; i--)
			{
				compile_expression(env, local, elt(args, i), Dst);
				| push eax;
			}

			// eax is free once all arguments are pushed.  ebx is callee-saved
			// under cdecl, and these functions may be called from C, so it
			// must not be clobbered.
			| mov eax, (func->code_addr);
			| call eax;
			| add esp, (nargs * value_size);
			// result in eax
		}
	}
	else if (symbolp(val))
	{
		// For now ignore global variables, only search locally
		struct variable *v = find_variable(local, (char *)(val ^ SYMBOL_TAG));

		if (!v)
		{
			fprintf(stderr, "var: %s\n", (char *)(val ^ SYMBOL_TAG));
			err("Variable unbound");
		}

		switch (v->type)
		{
		case V_ARGUMENT:
			// Skip the saved ebp and the return address.
			| mov eax, dword [ebp + (value_size * (v->number + 2))];
			break;
		case V_BOUND:
			| mov eax, dword [ebp - ((v->number + 1) * value_size)];
			break;
		default:
			err("Sorry, can only access V_ARGUMENT and V_BOUND variables for now :(");
		}
	}
}

// destroy_local is defined later in this file.  This prototype (redundant if
// compiler.h already declares it) makes it callable here.
void destroy_local(struct local *local);

/*
 * Compile the single expression VAL into a zero-argument function and
 * register it in ENV under NAME.
 *
 * Dst is used as given: this function neither initialises nor frees it.
 * NOTE(review): presumably the caller has already run dasm_init/dasm_setup
 * on it; confirm at the call sites.
 */
void compile_expr_to_func(struct environment *env, char *name, value_t val,
                          dasm_State **Dst)
{
	// Per-function compilation state, initialised the same way compile_tl
	// does.  compile_expression reads all of these fields.
	struct local local;
	local.first = NULL;
	local.num_vars = 0;
	local.npc = 8;
	local.nextpc = 0;
	local.stack_slots = malloc(sizeof(bool) * 4);
	memset(local.stack_slots, 0, sizeof(bool) * 4);
	local.num_stack_slots = 4;
	local.num_stack_entries = 0;

	// Make sure the pc label table has room for local.npc labels (growing
	// never shrinks an existing table).
	dasm_growpc(Dst, local.npc);

	// Size the frame for any let1 bindings in VAL before emitting the
	// prologue, then reset the slot map for code generation.
	walk_and_alloc(&local, val);

	| setup (local.num_stack_entries);

	memset(local.stack_slots, 0, local.num_stack_slots * sizeof(bool));
	local.num_stack_entries = 0;

	compile_expression(env, &local, val, Dst);

	| cleanup;

	add_function(env, name, link(Dst), 0);

	free(local.stack_slots);
	destroy_local(&local);
}

/*
 * Push a new variable record onto the front of LOCAL's variable list and
 * return it.
 *
 * NAME is stored by pointer, not copied.  NUMBER is interpreted according to
 * TYPE: callers pass the argument index for V_ARGUMENT and the stack slot
 * for V_BOUND.
 */
struct variable *add_variable(struct local *local, enum var_type type,
                              char *name, int number)
{
	struct variable *node = malloc(sizeof *node);

	node->name = name;
	node->type = type;
	node->number = number;
	node->prev = local->first;

	local->first = node;
	return node;
}

/*
 * Free every variable record in LOCAL's variable list and leave the list
 * empty.  Other fields of LOCAL (e.g. the stack slot map) are not touched.
 */
void destroy_local(struct local *local)
{
	struct variable *v = local->first;

	while (v != NULL)
	{
		struct variable *prev = v->prev;
		free(v);
		v = prev;
	}

	// Do not leave local->first pointing at freed memory.
	local->first = NULL;
}

/*
 * Find the most recently added variable called NAME in LOCAL's variable
 * list.  add_variable pushes onto the front, so a newer binding shadows an
 * older one.  Returns NULL if NAME is not present.
 */
struct variable *find_variable(struct local *local, char *name)
{
	struct variable *cur = local->first;

	while (cur != NULL)
	{
		if (!strcmp(cur->name, name))
			return cur;
		cur = cur->prev;
	}

	return NULL;
}
