#include "compiler.hpp"

#include <cassert>
#include <stack>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <rovdefs.h>

#include "bytebuf.hpp"
#include "parser.hpp"

// Reference-name -> id table consulted by GetReferenceId when compiling reference
// literals; declared extern, so it is defined outside this file.
extern std::unordered_map<std::string, int> gReferenceIds;

// File-local helpers, defined at the end of this file.
// WriteImmediateValue: leave `type` at its default of -1 to let ImmTypeForInteger
// choose the encoding from the value.
static ImmType ImmTypeForInteger(int);
static void WriteImmediateValue(ByteBuffer& bytebuf, int value, int type = -1);
static void WriteImmediateFloat(ByteBuffer& bytebuf, float value);
static int GetReferenceId(const std::string& name);

/// Kind of argument a built-in script function expects in a given position.
enum class ParamType
{
	None = 0,          ///< Placeholder; not used by any entry in the function table.
	OptionalFloat = 1, ///< Trailing float that may be omitted; the compiler then pushes a default.
	Number = 2,        ///< Numeric expression (registers resolve as numbers).
	String = 3,        ///< String argument.
	Entity = 4,        ///< Entity argument.
	State = 5,         ///< State argument.
	Item = 6           ///< Item argument.
};
/// Signature of one built-in function callable from scripts.
struct FunctionDefinition
{
	BytecodeOp             op;     ///< Opcode emitted after the arguments have been pushed.
	const char*            name;   ///< Name as written in scripts (compared case-sensitively).
	std::vector<ParamType> params; ///< Expected parameters, first to last.
};
// Built-in functions callable from scripts: the opcode to emit, the name as written
// in scripts, and the expected parameter list.
//
// - Names are looked up by exact, case-sensitive comparison (GetFunctionByName), so
//   e.g. "addModifier" must be written with that capitalisation.
// - EmitFunction checks the argument count against `params`. A trailing
//   OptionalFloat may be omitted, in which case EmitFunction pushes 1.0
//   (0.0 for "shoot").
// - A string or reference literal passed as the first argument is encoded as an
//   immediate operand after the opcode rather than pushed on the stack.
// Entries are grouped by opcode prefix: INSN_ACT_* first, then INSN_CND_*.
static FunctionDefinition functions[] = {
	{ INSN_ACT_SETSTATE,	"setstate", {ParamType::State} },
	{ INSN_ACT_NEXTSTATE,	"nextstate", {} },
	{ INSN_ACT_STOPMOVING,	"stopmoving", {} },
	{ INSN_ACT_DELETEME,	"deleteme", {} },
	{ INSN_ACT_MOVE,		"move", {ParamType::Number, ParamType::Number, ParamType::OptionalFloat} },
	{ INSN_ACT_MOVERANDOM,	"moverandomly", {ParamType::Number, ParamType::OptionalFloat} },
	{ INSN_ACT_MOVEALLY,	"moveally", {ParamType::Number, ParamType::Number, ParamType::OptionalFloat} },
	{ INSN_ACT_HITSCAN,		"hitscan", {ParamType::Number, ParamType::Number, ParamType::Number} },
	{ INSN_ACT_HIT_SWEEP,	"hit_sweep", {ParamType::Number, ParamType::Number} },
	{ INSN_ACT_HIT_RADIUS,	"hit_radius", {ParamType::Number, ParamType::Number} },
	{ INSN_ACT_SHOOT,		"shoot", {ParamType::Entity, ParamType::Number, ParamType::Number, ParamType::OptionalFloat} },
	{ INSN_ACT_PLAYSOUND,	"playsound", {ParamType::String} },
	{ INSN_ACT_HURTOTHER,	"hurtother", {ParamType::Number} },
	{ INSN_ACT_DELETEOTHER,	"deleteother", {} },
	{ INSN_ACT_THROWOTHER,	"throwother", {ParamType::Number} },
	{ INSN_ACT_STATMODOTHER,"effectother", {ParamType::Number, ParamType::Number, ParamType::Number} },
	{ INSN_ACT_DROPITEM,	"dropitem", {ParamType::Item} },
	{ INSN_ACT_DROPLOOT,	"droploot", {} },
	{ INSN_ACT_HEAL,		"heal", {ParamType::Number} },
	{ INSN_ACT_WARP,		"warp", {ParamType::Number} },
	{ INSN_ACT_WARPTO,		"warpto", {ParamType::Number, ParamType::Number, ParamType::Number} },
	{ INSN_ACT_AWARDXP,		"awardxp", {ParamType::Number} },
	{ INSN_ACT_PARTICLEBURST,"particles", {ParamType::Number, ParamType::Number, ParamType::Number} },
	{ INSN_ACT_ADDSTATMOD,	"addModifier", {ParamType::Number, ParamType::Number, ParamType::Number} },
	{ INSN_ACT_SUMMONBURST,	"summonburst", {ParamType::Entity, ParamType::Number, ParamType::Number, ParamType::Number} },
	{ INSN_ACT_SAY,			"say", {ParamType::String} },
	{ INSN_CND_CHANCE,		"chance", {ParamType::Number} },
	{ INSN_CND_SEESENEMY,	"seesenemy", {} },
	{ INSN_CND_ENEMYWITHIN, "enemywithin", {ParamType::Number} },
	{ INSN_CND_SEESALLY,	"seesally", {} },
	{ INSN_CND_ALLYWITHIN,	"allywithin", {ParamType::Number} },
	{ INSN_CND_STATEIS,		"stateis", {ParamType::State} },
	{ INSN_CND_OTHERIS,		"otheris", {ParamType::Entity} },
};
// Number of entries in `functions`, computed from the array size.
constexpr int NUM_FUNCTIONS = sizeof(functions) / sizeof(functions[0]);

/// Looks up a built-in function by its script name.
/// @param name Function name exactly as written in the script (case-sensitive).
/// @return Pointer into the static `functions` table, or nullptr if no entry matches.
const FunctionDefinition* GetFunctionByName(const std::string& name)
{
	for (const FunctionDefinition& candidate : functions)
	{
		if (name == candidate.name)
			return &candidate;
	}

	// Not a built-in.
	return nullptr;
}

/// Maps a register name to its Register value.
/// The register's number is its index in `registerNames`.
/// @return The matching register, or NUM_REGISTERS when `name` is not a known register.
Register GetRegisterByName(const std::string& name)
{
	int index = 0;
	while (index < NUM_REGISTERS && !(name == registerNames[index]))
		++index;

	// Ran off the end of the name list: report failure with the sentinel value.
	if (index == NUM_REGISTERS)
		return NUM_REGISTERS;

	return static_cast<Register>(index);
}

// Code emitters. They call each other recursively (expressions contain function
// calls and function arguments are expressions), so they are declared up front.
static void EmitSequence(const ASTNode&, ByteBuffer&);
static void EmitExpression(const ASTNode&, ByteBuffer&);
static void EmitFunction(const ASTNode&, ByteBuffer&);

// Byte offsets of 32-bit jump operands whose targets are not yet known while
// emitting if/else blocks; each is overwritten once the target offset is reached.
static std::stack<int> labelPatches;
// Strings used by the compiled code; bytecode refers to them by index.
// Serialized by WriteStringTable.
// NOTE(review): nothing in this file clears it, so it accumulates across
// CompileScript calls if several scripts are compiled in one process.
static std::vector<std::string> stringTable;

/// Compiles a parsed script into bytecode appended to `bytebuf`.
/// @param root Top-level node of the script; must be a statement sequence.
/// @param bytebuf Output buffer; the code is terminated with INSN_LOGI_ENDCODE.
void CompileScript(const ASTNode& root, ByteBuffer& bytebuf)
{
	// Only a statement sequence is accepted at the top level.
	assert(root.type == ASTNODE_SEQ);

	EmitSequence(root, bytebuf);

	// Close the code section.
	const auto endMarker = INSN_LOGI_ENDCODE;
	bytebuf.WriteUInt8(endMarker);
}

/// Appends the string table to `bytebuf`.
///
/// Layout: [u32 count] [count NUL-terminated strings] [u32 offset of the count field].
/// The trailing offset presumably lets a reader locate the table from the end of
/// the buffer — confirm against the loader.
void WriteStringTable(ByteBuffer& bytebuf)
{
	const int tableStart = bytebuf.size();
	bytebuf.WriteUInt32(stringTable.size());

	for (std::size_t index = 0; index < stringTable.size(); ++index)
	{
		const std::string& entry = stringTable[index];
		// length() + 1 so the terminating NUL is written too.
		bytebuf.WriteBytes(const_cast<void*>(static_cast<const void*>(entry.c_str())), entry.length() + 1);
	}

	bytebuf.WriteUInt32(tableStart);
}

static void EmitSequence(const ASTNode& root, ByteBuffer& bytebuf)
{
	for (const auto& node : root.subNodes) {
		switch (node.type)
		{
		case ASTNODE_IFELSE:
			EmitExpression(node.subNodes[0], bytebuf); // condition
			// When the previous condition is true, execution should fall through to following code.
			// Otherwise, it should jump to either the associated else or endif statement.
			// Since the location isn't known at this point, a dummy value is written and is added to a
			// stack to be overwritten once the else or endif is found.

			bytebuf.WriteUInt8(INSN_LOGI_NOT); // Invert the value, since we only want to jump if the statement is false
			bytebuf.WriteUInt8(INSN_LOGI_JMPIF);
			// Write a dummy value to be patched once the else/endif offset is known.
			labelPatches.push(bytebuf.size() + 1);
			WriteImmediateValue(bytebuf, 0xABCDEF00, IMMTYPE_U32);

			EmitSequence(node.subNodes[1], bytebuf);

			if (node.subNodes.size() > 2 // If there is an 'else' branch
				&& node.subNodes[2].subNodes.size() > 0) // Quick fix: Ignore empty 'else' blocks. TODO: The tree should be optimized beforehand, then this could be removed.
			{
				// If the condition was true and execution reached the truthy branch, there needs to be
				// a jump to the corresponding endif to skip the false branch.
				bytebuf.WriteUInt8(INSN_LOGI_JMP);

				int patchPos = labelPatches.top();
				labelPatches.pop();

				// Write a dummy value to be patched once the endif offset is known.
				labelPatches.push(bytebuf.size() + 1);
				WriteImmediateValue(bytebuf, 0xABCDEF00, IMMTYPE_U32);

				bytebuf.OverwriteUInt32(patchPos, bytebuf.size());

				// Now that the jump has been set up, write the sequence for the false branch
				EmitSequence(node.subNodes[2], bytebuf);
			}

			bytebuf.OverwriteUInt32(labelPatches.top(), bytebuf.size());
			labelPatches.pop();
			break;
		case ASTNODE_FUNC:
			EmitFunction(node, bytebuf);
			break;
		case ASTNODE_ASSIGNMENT: {
			Register reg = GetRegisterByName(node.subNodes[0].value.contents.c_str());
			if (reg == NUM_REGISTERS) {
				PrintErrorMsg(node.value.filepos, ERR_LVL_ERROR, "Assignment to unknown register");
				return;
			}

			EmitExpression(node.subNodes[1], bytebuf);

			if (node.value.type != TOK_ASSIGN) {
				// Add- or subtract-assignment
				bytebuf.WriteUInt8(INSN_CORE_GET);
				WriteImmediateValue(bytebuf, reg);
				if (node.value.type == TOK_PLUSEQU)
					bytebuf.WriteUInt8(INSN_CORE_ADD);
				if (node.value.type == TOK_MINUSEQU)
					bytebuf.WriteUInt8(INSN_CORE_SUB);
			}

			bytebuf.WriteUInt8(INSN_CORE_SET);
			WriteImmediateValue(bytebuf, reg);
			break;
		}
		default:
			break;
		}
	}
}

static void EmitExpression(const ASTNode& node, ByteBuffer& bytebuf)
{
	switch (node.type) {
	case ASTNODE_UNARYOP:
		if (node.value.type == TOK_NAME && node.value.contents == "not")
		{
			EmitExpression(node.subNodes[0], bytebuf);
			bytebuf.WriteUInt8(INSN_LOGI_NOT);
		}
		else if (node.value.type == TOK_MINUS)
		{
			// TODO: Add negation instruction to bytecode
			PrintErrorMsg(node.value.filepos, ERR_LVL_ERROR, "Minus operator unimplemented");
		}
		break;
	case ASTNODE_BINARYOP:
		EmitExpression(node.subNodes[1], bytebuf);
		EmitExpression(node.subNodes[0], bytebuf);
		switch (node.value.type)
		{
		case TOK_PLUS:
			bytebuf.WriteUInt8(INSN_CORE_ADD);
			break;
		case TOK_MINUS:
			bytebuf.WriteUInt8(INSN_CORE_ADD);
			break;
		case TOK_COMPARE:
			bytebuf.WriteUInt8(INSN_CORE_EQUALS);
			break;
		case TOK_LESSTHAN:
			bytebuf.WriteUInt8(INSN_CORE_LESS);
			break;
		case TOK_GREATERTHAN:
			bytebuf.WriteUInt8(INSN_CORE_GREATER);
			break;
		case TOK_LESSEQU:
			bytebuf.WriteUInt8(INSN_CORE_LESSEQU);
			break;
		case TOK_GREATEREQU:
			bytebuf.WriteUInt8(INSN_CORE_GREATEREQU);
			break;
		case TOK_NAME:
			if (node.value.contents == "and")
			{
				bytebuf.WriteUInt8(INSN_LOGI_AND);
			}
			else if (node.value.contents == "or")
			{
				bytebuf.WriteUInt8(INSN_LOGI_OR);
			}
			break;
		}
		break;
	case ASTNODE_FUNC:
		EmitFunction(node, bytebuf);
		break;
	case ASTNODE_LITERAL:
		switch (node.value.type) {
		case TOK_NUMBER:
			bytebuf.WriteUInt8(INSN_CORE_PUSH);
			if (node.value.contents.find(".") != std::string::npos)
				WriteImmediateFloat(bytebuf, std::stof(node.value.contents));
			else
				WriteImmediateValue(bytebuf, std::stoi(node.value.contents, nullptr, 0));
			break;
		case TOK_STRING:
			// TODO: strings are immediate
			stringTable.push_back(node.value.contents);
			bytebuf.WriteUInt8(INSN_CORE_PUSH);
			WriteImmediateValue(bytebuf, stringTable.size() - 1);
			break;
		case TOK_REFERENCE:
			int id = GetReferenceId(node.value.contents);
			if (id == -1)
			{
				PrintErrorMsg(node.value.filepos, ERR_LVL_ERROR, "Unknown reference '%s'", node.value.contents.c_str());
			}
			bytebuf.WriteUInt8(INSN_CORE_PUSH);
			WriteImmediateValue(bytebuf, id);
			break;
		}
		break;
	case ASTNODE_REGISTER: {
		Register reg = GetRegisterByName(node.value.contents);
		if (reg == NUM_REGISTERS) {
			PrintErrorMsg(node.value.filepos, ERR_LVL_ERROR, "Unknown register '%s'", node.value.contents.c_str());
			return;
		}
		bytebuf.WriteUInt8(INSN_CORE_GET);
		WriteImmediateValue(bytebuf, reg);
		break;
	}
	default:
		assert(false);
		break;
	}
}

static void EmitFunction(const ASTNode& node, ByteBuffer& bytebuf)
{
	const FunctionDefinition* func = GetFunctionByName(node.value.contents);
	if (func == nullptr) {
		PrintErrorMsg(node.value.filepos, ERR_LVL_ERROR, "Unknown function '%s'", node.value.contents.c_str());
		return;
	}

	int numParamsNeeded = func->params.size();
	int numParamsGiven = node.subNodes.size();
	bool hasOptionalParameter = false;
	if (numParamsGiven != numParamsNeeded) {
		if (func->params.back() == ParamType::OptionalFloat && numParamsGiven == numParamsNeeded - 1) {
			hasOptionalParameter = true;
		}
		else {
			PrintErrorMsg(node.value.filepos, ERR_LVL_ERROR, "Expected %d parameters, but got %d", numParamsNeeded, numParamsGiven);
			return;
		}
	}

	// Iterate function parameters in reverse order so they
	// can be popped by the interpreter from first to last.
	if (hasOptionalParameter) {
		float value = func->op == INSN_ACT_SHOOT ? 0.0f : 1.0f;
		bytebuf.WriteUInt8(INSN_CORE_PUSH);
		bytebuf.WriteUInt8(IMMTYPE_FLOAT32);
		bytebuf.WriteFloat(value);
	}

	bool hasImmediate = false;
	int immediate = -1;

	for (int i = numParamsGiven; i > 0; i--) {
		const ASTNode& subNode = node.subNodes[i - 1];
		if (subNode.type == ASTNODE_LITERAL && i == 1) {
			// Special handling for strings and references, which are immediate values when in the first parameter
			if (subNode.value.type == TOK_REFERENCE)
			{
				hasImmediate = true;
				immediate = GetReferenceId(subNode.value.contents);
				if (immediate == -1)
				{
					PrintErrorMsg(subNode.value.filepos, ERR_LVL_ERROR, "Unknown reference '%s'", subNode.value.contents.c_str());
				}
				continue;
			}
			else if (subNode.value.type == TOK_STRING)
			{
				hasImmediate = true;
				immediate = stringTable.size();
				stringTable.emplace_back(subNode.value.contents);
				continue;
			}
			// TODO: Check that type matches expected function parameter type
		}
		EmitExpression(subNode, bytebuf);
	}

	bytebuf.WriteUInt8(func->op);
	if (hasImmediate) {
		WriteImmediateValue(bytebuf, immediate);
	}
}

// Chooses the narrowest immediate encoding that represents `v` exactly.
// Non-negative values use the unsigned types; negative values use the signed ones.
//
// `v` is compared against the type limits directly rather than negated first.
// Negating INT_MIN overflows (undefined behavior). Strict `<` comparisons also
// pushed the boundary values 255, 65535, -127, -128, -32767 and -32768 into
// wider types than needed.
static ImmType ImmTypeForInteger(int v)
{
	if (v < 0)
	{
		if (v >= INT8_MIN)
			return IMMTYPE_S8;
		if (v >= INT16_MIN)
			return IMMTYPE_S16;
		return IMMTYPE_S32;
	}

	if (v <= UINT8_MAX)
		return IMMTYPE_U8;
	if (v <= UINT16_MAX)
		return IMMTYPE_U16;
	// Any non-negative int fits in 32 bits.
	return IMMTYPE_U32;
}

// Writes a tagged integer immediate: a one-byte type tag followed by `value`
// encoded with the width and signedness that tag names.
// `type` == -1 lets ImmTypeForInteger pick the encoding from the value.
// Throws std::runtime_error if `type` is not an integer immediate type.
static void WriteImmediateValue(ByteBuffer& bytebuf, int value, int type)
{
	const int immType = (type == -1) ? ImmTypeForInteger(value) : type;

	// Tag first...
	bytebuf.WriteUInt8(immType);

	// ...then the payload.
	if (immType == IMMTYPE_U8)
		bytebuf.WriteUInt8(value);
	else if (immType == IMMTYPE_S8)
		bytebuf.WriteInt8(value);
	else if (immType == IMMTYPE_U16)
		bytebuf.WriteUInt16(value);
	else if (immType == IMMTYPE_S16)
		bytebuf.WriteInt16(value);
	else if (immType == IMMTYPE_U32)
		bytebuf.WriteUInt32(value);
	else if (immType == IMMTYPE_S32)
		bytebuf.WriteInt32(value);
	else
		throw std::runtime_error("Attempted to write non-integer data type as integer.");
}

// Writes a tagged float immediate: the IMMTYPE_FLOAT32 tag byte, then the float.
static void WriteImmediateFloat(ByteBuffer& bytebuf, float value)
{
	const auto tag = IMMTYPE_FLOAT32;
	bytebuf.WriteUInt8(tag);
	bytebuf.WriteFloat(value);
}

// Returns the id registered for reference `name` in gReferenceIds, or -1 if the
// name is unknown.
// Uses a single find(): operator[] would cost a second hash lookup and inserts a
// default entry for any key it is ever called with.
static int GetReferenceId(const std::string& name)
{
	const auto it = gReferenceIds.find(name);
	return it != gReferenceIds.end() ? it->second : -1;
}