from enum import Enum
from pathlib import Path
from typing import Optional, Union
from dataclasses import dataclass, field


class CType(Enum):
    """Builtin C scalar types.

    Each member's value is the exact spelling emitted into generated code.
    NOTE(review): the short spellings (``i32``, ``u8``, ``f128``, ``iptr``,
    ...) are not standard C type names; presumably a project header typedefs
    them — confirm.
    """

    VOID = "void"
    BOOL = "bool"
    CHAR = "char"
    # Character types.
    C8 = "c8"
    C16 = "c16"
    C32 = "c32"
    # Signed integers.
    I8 = "i8"
    I16 = "i16"
    I32 = "i32"
    I64 = "i64"
    # Unsigned integers.
    U8 = "u8"
    U16 = "u16"
    U32 = "u32"
    U64 = "u64"
    # Floating point.
    F32 = "f32"
    F64 = "f64"
    F128 = "f128"
    # Pointer-sized integers.
    IPTR = "iptr"
    UPTR = "uptr"

    def __str__(self) -> str:
        return f"{self.value}"


class CQualifier(Enum):
    """Declaration qualifiers, each rendered as a keyword plus one trailing space.

    ``NONE`` renders as the empty string, so every member can be concatenated
    directly in front of a type or declarator.
    NOTE(review): ``external``, ``internal`` and ``persistent`` are not C
    keywords; presumably they are macros from a project header — confirm.
    """

    NONE = ""
    CONST = "const "
    EXTERNAL = "external "
    INTERNAL = "internal "
    PERSISTENT = "persistent "

    def __str__(self) -> str:
        rendered: str = self.value
        return rendered


class CPointerType(Enum):
    """Level of pointer indirection, rendered as the matching run of ``*``."""

    NONE = ""
    SINGLE = "*"
    DOUBLE = "**"

    def __str__(self) -> str:
        stars: str = self.value
        return stars


@dataclass
class CPointer:
    """Pointer decoration for a declarator: the stars, then an optional qualifier.

    For example, a single pointer with ``CQualifier.CONST`` renders as
    ``"*const "``; the default renders as ``""``.
    """

    _type: CPointerType = CPointerType.NONE
    qualifier: CQualifier = CQualifier.NONE

    def __str__(self) -> str:
        stars = str(self._type)
        trailing = str(self.qualifier)
        return f"{stars}{trailing}"


@dataclass
class CEnumVal:
    """A single enumerator inside a C ``enum``.

    Renders as ``NAME`` when no explicit value is given, otherwise as
    ``NAME = VALUE``.
    """

    name: str
    value: Optional[int] = None

    def __str__(self) -> str:
        # Compare against None explicitly: 0 is a legitimate enumerator value.
        if self.value is None:
            return self.name
        # The name must appear in both branches; an unparenthesised
        # `name + "" if ... else " = v"` silently drops it.
        return f"{self.name} = {self.value}"


@dataclass
class CEnum:
    """A C ``enum`` definition.

    With ``typedef=False`` it renders as ``enum NAME { ... };``. With
    ``typedef=True`` it renders as an anonymous ``typedef enum { ... } NAME;``.
    Every enumerator is indented two spaces and followed by a comma.
    """

    name: str
    values: list[CEnumVal]
    typedef: bool = False

    def __str__(self) -> str:
        body = "".join(f"  {str(entry)},\n" for entry in self.values)

        if not self.typedef:
            return f"enum {self.name} {{\n" + body + "};\n"

        return "typedef enum {\n" + body + f"}} {self.name};\n"


@dataclass
class CStruct:
    name: str
    cargs: list["CArg"]
    typedef_name: str | None = None

    def __str__(self) -> str:
        return self.declare() + self.define()

    def declare(self) -> str:
        declaration = f"typedef struct {self.name} {self.typedef_name if self.typedef_name is not None else self.name};\n"
        return declaration

    def define(self):
        definition = f"struct {self.name} {{\n"
        args = ""
        for arg in self.cargs:
            args += f"  {str(arg)};\n"
        footer = "};\n"

        return definition + args + footer;


# User-defined C types this module can render: structs and enums.
CUserType = Union[CStruct, CEnum]
# Anything accepted where a C data type is expected: a builtin CType, a
# user-defined type (rendered by its ``name``), or a raw type-spelling string.
CDataType = Union[CType, CUserType, str]


@dataclass
class CArg:
    """A single C declaration, used for function parameters and struct members.

    Renders as ``<qualifier><type> <pointer><name>``, with ``[]`` appended when
    ``array`` is true.
    """

    name: str
    _type: CDataType
    array: bool = False
    pointer: CPointer = field(default_factory=CPointer)
    qualifier: CQualifier = CQualifier.NONE

    def __str__(self) -> str:
        leading = (
            str(self.qualifier),
            get_datatype_string(self._type) + " ",
            str(self.pointer),
        )
        suffix = "[]" if self.array else ""
        return "".join(leading) + self.name + suffix


@dataclass
class CFunc:
    """A C function that can be emitted as a prototype or as a definition.

    ``str()`` renders only the signature; ``declare()`` appends ``;`` and
    ``define()`` wraps ``body`` in braces.
    """

    name: str
    ret_type: CDataType
    args: list[CArg]
    # C source inserted verbatim between the braces by ``define()``.
    body: str
    # Pointer decoration placed between the return type and the name.
    pointer: CPointer = field(default_factory=CPointer)
    qualifiers: list[CQualifier] = field(default_factory=list)

    def __str__(self) -> str:
        # Each qualifier's value already ends in a space, and NONE renders as
        # "", so plain concatenation separates them correctly. Adding another
        # " " between them would produce doubled spaces ("internal  const ").
        qualifiers = "".join(str(qualifier) for qualifier in self.qualifiers)
        args = ", ".join(str(arg) for arg in self.args)

        return (
            qualifiers
            + get_datatype_string(self.ret_type)
            + " "
            + str(self.pointer)
            + self.name
            + f"({args})"
        )

    def declare(self) -> str:
        """Return the function prototype terminated by ``;`` and a newline."""
        return f"{str(self)};\n"

    def define(self) -> str:
        """Return the full definition followed by a blank line."""
        return f"{str(self)} {{\n{self.body}\n}}\n\n"


@dataclass
class CInclude:
    """An ``#include`` directive.

    ``header`` is either a literal file name or a ``CHeader``, whose
    ``name.extension`` is used. Local includes use quotes, and are optionally
    prefixed with ``./`` when ``same_dir`` is set. Non-local includes use angle
    brackets, and ``same_dir`` is ignored for them.
    """

    header: Union[str, "CHeader"]
    local: bool = False
    same_dir: bool = False

    def __str__(self) -> str:
        target = self.header
        path = f"{target.name}.{target.extension}" if isinstance(target, CHeader) else target

        if not self.local:
            return f"#include <{path}>\n"

        if self.same_dir:
            path = f"./{path}"
        return f'#include "{path}"\n'


@dataclass
class CFile:
    name: str
    extension: str
    decl_types: list[CStruct] = field(default_factory=list)

    def save(self, output_dir: Path):
        output_file = output_dir / f"{self.name}.{self.extension}"
        with open(output_file, "w+") as outfile:
            outfile.write(str(self))

    def __str__(self) -> str:
        return """\
/**
 * THIS FILE IS AUTOMATICALLY GENERATED. ANY MODIFICATIONS TO IT WILL BE OVERWRITTEN
 */

"""


@dataclass
class CHeader(CFile):
    """A generated C header file (default extension ``h``).

    The rendered text is, in order:
    - the ``CFile`` banner;
    - an include guard;
    - the includes;
    - an ``#ifdef __cplusplus`` block invoking ``BEGIN_C_LINKAGE``;
    - forward typedefs for ``decl_types``;
    - full definitions of ``types``;
    - prototypes for ``funcs``;
    - the matching ``END_C_LINKAGE`` block and the guard's ``#endif``.
    """

    extension: str = "h"
    includes: list[CInclude] = field(default_factory=list)
    types: list[CUserType] = field(default_factory=list)
    funcs: list[CFunc] = field(default_factory=list)

    def __str__(self) -> str:
        # Preprocessor identifiers may only contain ASCII letters, digits and
        # '_'. Map anything else (e.g. '-', '.' or ' ' in a file name) to '_'
        # so the guard macro is always valid.
        guard_stem = "".join(
            ch if ch == "_" or (ch.isascii() and ch.isalnum()) else "_"
            for ch in self.name.upper()
        )
        header_guard_name = f"{guard_stem}_H"
        header_guard_open = f"#ifndef {header_guard_name}\n#define {header_guard_name}\n\n"
        header_guard_close = f"#endif // !{header_guard_name}\n"

        # NOTE(review): BEGIN_C_LINKAGE / END_C_LINKAGE are not defined here;
        # presumably a project header maps them to `extern "C" {` / `}` — confirm.
        c_linkage_open = "#ifdef __cplusplus\nBEGIN_C_LINKAGE\n#endif // !__cplusplus\n\n"
        c_linkage_close = "\n#ifdef __cplusplus\nEND_C_LINKAGE\n#endif // !__cplusplus\n\n"

        includes = _get_includes_string(self.includes)

        forward_declarations = "".join(_type.declare() for _type in self.decl_types)
        if forward_declarations:
            forward_declarations += "\n"

        types = "".join(str(_type) + "\n" for _type in self.types)
        funcs = "".join(func.declare() for func in self.funcs)

        return (
            super().__str__() +
            header_guard_open +
            includes +
            c_linkage_open +
            forward_declarations +
            types +
            funcs +
            c_linkage_close +
            header_guard_close
        )


@dataclass
class CSource(CFile):
    """A generated C source file (default extension ``c``).

    The rendered text is, in order:
    - the ``CFile`` banner;
    - the includes;
    - forward typedefs for ``decl_types``;
    - full definitions of ``types``;
    - prototypes for ``internal_funcs``;
    - definitions of ``funcs``;
    - definitions of ``internal_funcs``.

    The internal helpers are prototyped up front so they can be defined after
    the public functions that call them.
    """

    extension: str = "c"
    includes: list[CInclude] = field(default_factory=list)
    types: list[CUserType] = field(default_factory=list)
    internal_funcs: list[CFunc] = field(default_factory=list)
    funcs: list[CFunc] = field(default_factory=list)

    def __str__(self) -> str:
        include_block = _get_includes_string(self.includes)

        typedef_block = "".join(struct.declare() for struct in self.decl_types)
        if typedef_block:
            typedef_block += "\n"

        type_block = "".join(str(user_type) + "\n" for user_type in self.types)

        private_protos = "".join(fn.declare() for fn in self.internal_funcs)
        private_bodies = "".join(fn.define() for fn in self.internal_funcs)
        if private_protos:
            private_protos += "\n"

        public_bodies = "".join(fn.define() for fn in self.funcs)

        sections = [
            super().__str__(),
            include_block,
            typedef_block,
            type_block,
            private_protos,
            public_bodies,
            private_bodies,
        ]
        return "".join(sections)


def get_datatype_string(_type: CDataType) -> str:
    """Return the C spelling of a data type.

    ``CType`` members render as their value, ``CStruct``/``CEnum`` as their
    ``name``, and plain strings are passed through unchanged.

    Raises:
        TypeError: if ``_type`` is none of the supported kinds. Without this,
            the function would return ``None`` and callers would fail later
            with an unrelated-looking concatenation error.
    """
    if isinstance(_type, CType):
        return str(_type)
    if isinstance(_type, (CStruct, CEnum)):
        return _type.name
    if isinstance(_type, str):
        return _type
    raise TypeError(f"unsupported C data type: {_type!r}")


def _get_includes_string(includes: list[CInclude]) -> str:
    """Render ``includes`` as consecutive ``#include`` lines.

    Local (quoted) includes come before system (angle-bracket) ones. Because
    ``sorted`` is stable (also with ``reverse=True``), each group keeps its
    original relative order. A blank line follows the block unless there are
    no includes.
    """
    ordered = sorted(includes, key=lambda inc: inc.local, reverse=True)
    rendered = "".join(str(inc) for inc in ordered)
    return f"{rendered}\n" if rendered else rendered