from enum import Enum
from pathlib import Path
from typing import Optional, Union
from dataclasses import dataclass, field


class CType(Enum):
    """Built-in scalar types known to the generator.

    Every value carries a trailing space so it can be concatenated directly
    with a pointer marker or an identifier (``"u8 " + "*" + "buf"``).
    Most spellings (e.g. ``u8``, ``f128``, ``iptr``) are not built-in C
    types; presumably they are typedefs supplied by the included
    ``aliases.h`` -- confirm.
    """

    VOID = "void "
    BOOL = "bool "
    CHAR = "char "

    C8 = "c8 "
    C16 = "c16 "
    C32 = "c32 "

    I8 = "i8 "
    I16 = "i16 "
    I32 = "i32 "
    I64 = "i64 "

    U8 = "u8 "
    U16 = "u16 "
    U32 = "u32 "
    U64 = "u64 "

    F32 = "f32 "
    F64 = "f64 "
    F128 = "f128 "

    IPTR = "iptr "
    UPTR = "uptr "

    def __str__(self) -> str:
        # Render as the C spelling (including its separator space).
        return f"{self.value}"


class CQualifier(Enum):
    """Declaration qualifiers, each rendered with a trailing separator space.

    ``NONE`` renders as the empty string. ``external``/``internal``/
    ``persistent`` are not C keywords; presumably they are macros from the
    project's alias header -- confirm.
    """

    NONE = ""
    CONST = "const "
    EXTERNAL = "external "
    INTERNAL = "internal "
    PERSISTENT = "persistent "

    def __str__(self) -> str:
        # Render as the literal C text for this qualifier.
        return f"{self.value}"


class CPointerType(Enum):
    """Level of pointer indirection, rendered as zero, one or two ``*``."""

    NONE = ""
    SINGLE = "*"
    DOUBLE = "**"

    def __str__(self) -> str:
        # Render as the literal star sequence.
        return f"{self.value}"


@dataclass
class CPointer:
    """Pointer decoration placed between a declaration's type and its name.

    Renders as the stars followed by the optional qualifier, e.g. ``*`` or
    ``*const ``; the default renders as the empty string.
    """

    _type: CPointerType = CPointerType.NONE
    qualifier: CQualifier = CQualifier.NONE

    def __str__(self) -> str:
        return "".join(map(str, (self._type, self.qualifier)))


@dataclass
class CEnumVal:
    """A single enumerator inside a C ``enum``, optionally with an explicit value."""

    name: str
    # Explicit integer value; when None the enumerator is emitted bare and C
    # assigns the next value in sequence.
    value: Optional[int] = None

    def __str__(self) -> str:
        """Render as ``NAME`` or ``NAME = VALUE``."""
        # Branch explicitly: a conditional expression binds looser than `+`,
        # so `name + "" if ... else " = v"` would drop the name entirely.
        if self.value is None:
            return self.name
        return f"{self.name} = {self.value}"


@dataclass
class CEnum:
    """A C enumeration.

    With ``typedef=True`` it is rendered as an anonymous enum typedef'd to
    ``name``; otherwise as a tagged ``enum name { ... };``.
    """

    name: str
    # Enumerators in declaration order.
    values: list[CEnumVal]
    typedef: bool = field(default=False)


@dataclass
class CStruct:
    """A C struct; it is rendered together with a same-named ``typedef``."""

    name: str
    # Members in declaration order; each is rendered as one `  <member>;` line.
    cargs: list["CArg"]


# Types whose definitions are generated (structs and enums), as opposed to CType scalars.
CUserType = Union[CStruct, CEnum]
# Anything usable as an argument/member type or a function return type.
CDataType = Union[CType, CUserType]

@dataclass
class CArg:
    """A named, typed declaration: a function parameter or a struct member.

    Renders as ``<qualifier><type><pointer><name>[<[]>]``, e.g.
    ``const char **argv`` or ``u8 buf[]``.
    """

    name: str
    _type: CDataType
    array: bool = False
    pointer: CPointer = field(default_factory=CPointer)
    qualifier: CQualifier = CQualifier.NONE

    def __str__(self) -> str:
        pieces = [
            str(self.qualifier),
            get_arg_type_string(self._type),
            str(self.pointer),
            self.name,
            "[]" if self.array else "",
        ]
        return "".join(pieces)


@dataclass
class CFunc:
    """A C function: a signature plus a verbatim body.

    ``declare()`` renders the prototype and ``define()`` the full definition.
    """

    name: str
    ret_type: CDataType
    args: list[CArg]
    # Inserted verbatim between the braces by define().
    body: str
    # Pointer decoration on the return type, e.g. ``u8 *name(...)``.
    pointer: CPointer = field(default_factory=CPointer)
    # Rendered, in order, before the return type.
    qualifiers: list[CQualifier] = field(default_factory=list)

    def __str__(self) -> str:
        """Return the signature, e.g. ``internal i32 f(i32 argc, const char **argv)``."""
        # Qualifier values already end in a separator space ("const ") and
        # CQualifier.NONE renders as "", so plain concatenation spaces correctly.
        qualifiers = "".join(str(qualifier) for qualifier in self.qualifiers)

        # Use the shared type helper so struct/enum return types render as their
        # C name rather than the dataclass repr. CType values end in a space but
        # a user type name may not, so normalize to exactly one separator.
        ret_type = get_arg_type_string(self.ret_type).rstrip() + " "

        args = ", ".join(str(arg) for arg in self.args)

        return qualifiers + ret_type + str(self.pointer) + self.name + f"({args})"

    def declare(self) -> str:
        """Return the prototype, terminated by ``;`` and a newline."""
        return f"{str(self)};\n"

    def define(self) -> str:
        """Return the full definition with ``body`` placed inside braces."""
        return f"{str(self)} {{\n{self.body}\n}}\n"


@dataclass
class CInclude:
    """One ``#include`` directive.

    ``header`` is either a literal header path or a CHeader whose ``name`` is
    used. ``local=True`` selects the quoted form; otherwise the angle-bracket
    form is emitted.
    """

    header: Union[str, "CHeader"]
    local: bool = field(default=False)


@dataclass
class CFile:
    """Base class for a generated C file; subclasses supply the text via ``__str__``."""

    # File name, joined onto the output directory by save().
    name: str

    def save(self, output_dir: Union[str, Path]) -> None:
        """Write ``str(self)`` to ``output_dir / self.name``, replacing any existing file.

        Args:
            output_dir: Directory to write into; a ``str`` or ``Path``.
        """
        output_file = Path(output_dir) / self.name
        # Render before touching the file so a rendering error cannot leave a
        # truncated file behind; write UTF-8 explicitly so the output does not
        # depend on the platform's locale encoding.
        content = str(self)
        output_file.write_text(content, encoding="utf-8")

    def __str__(self) -> str:
        # The base file has no content of its own.
        return ""


@dataclass
class CHeader(CFile):
    """A generated C header.

    Layout: ``#pragma once``, include directives, type definitions, then one
    prototype per function.
    """

    includes: list[CInclude]
    types: list[CUserType]
    funcs: list[CFunc]

    def __str__(self) -> str:
        sections = [
            "#pragma once\n\n",
            get_includes_string(self.includes),
            "".join(get_user_type_string(user_type) for user_type in self.types),
            "".join(func.declare() for func in self.funcs),
        ]
        return "".join(sections)


@dataclass
class CSource(CFile):
    """A generated C source file.

    Layout: includes, type definitions, prototypes of the internal functions,
    definitions of the public functions, then definitions of the internal
    functions.
    """

    includes: list[CInclude]
    types: list[CUserType]
    internal_funcs: list[CFunc]
    funcs: list[CFunc]

    def __str__(self) -> str:
        include_block = get_includes_string(self.includes)
        type_block = "".join(get_user_type_string(t) for t in self.types)

        # Internal functions are forward-declared before the public definitions
        # and defined after them, so they are visible regardless of ordering.
        forward_decls: list[str] = []
        late_defs: list[str] = []
        for helper in self.internal_funcs:
            forward_decls.append(helper.declare())
            late_defs.append(helper.define())

        public_defs = "".join(func.define() for func in self.funcs)

        return "".join([include_block, type_block, *forward_decls, public_defs, *late_defs])


def main():
    """Build a small demo header/source pair and write ``str.h`` and ``str.c``
    into the current working directory."""
    str8_struct = CStruct(
        name="Str8",
        cargs=[
            CArg(name="size", _type=CType.U64),
            CArg(name="capacity", _type=CType.U64),
            CArg(
                name="buf",
                _type=CType.U8,
                pointer=CPointer(_type=CPointerType.SINGLE),
            ),
        ],
    )

    os_enum = CEnum(
        name="OS",
        values=[CEnumVal(name=label) for label in ("OS_LINUX", "OS_WINDOWS", "OS_MACOS")],
    )

    compiler_enum = CEnum(
        name="Compiler",
        values=[
            CEnumVal(name=label)
            for label in ("COMPILER_GCC", "COMPILER_CLANG", "COMPILER_MSVC")
        ],
        typedef=True,
    )

    entry = CFunc(
        name="my_custom_func",
        ret_type=CType.I32,
        args=[
            CArg(name="argc", _type=CType.I32),
            CArg(
                name="argv",
                _type=CType.CHAR,
                pointer=CPointer(_type=CPointerType.DOUBLE),
                qualifier=CQualifier.CONST,
            ),
        ],
        body="  return 0;",
    )

    str_header = CHeader(
        name="str.h",
        includes=[CInclude(header="aliases.h", local=True)],
        types=[str8_struct, os_enum, compiler_enum],
        funcs=[entry],
    )

    str_source = CSource(
        name="str.c",
        # The source includes its own header first, then the shared aliases.
        includes=[
            CInclude(header=str_header, local=True),
            CInclude(header="aliases.h", local=True),
        ],
        types=[],
        internal_funcs=[],
        funcs=[entry],
    )

    for generated in (str_header, str_source):
        generated.save(Path("."))


def get_user_type_string(_type: Union[CStruct, CEnum]) -> str:
    """Render a struct or enum definition followed by a blank separator line.

    Anything that is not a CStruct is rendered as an enum.
    """
    render = cstruct_to_string if isinstance(_type, CStruct) else cenum_to_string
    return render(_type) + "\n"


def cenum_to_string(cenum: CEnum) -> str:
    """Render a C enum definition.

    ``typedef=True`` yields ``typedef enum { ... } Name;``; otherwise a tagged
    ``enum Name { ... };``. Each enumerator goes on its own line with a
    trailing comma.
    """
    enumerators = "".join(f"  {str(value)},\n" for value in cenum.values)

    if not cenum.typedef:
        return f"enum {cenum.name} {{\n" + enumerators + "};\n"

    return "typedef enum {\n" + enumerators + f"}} {cenum.name};\n"


def cstruct_to_string(cstruct: CStruct) -> str:
    """Render a same-named typedef followed by the struct definition.

    The typedef comes first so the bare struct name is usable as a type.
    """
    tag = cstruct.name
    lines = [f"typedef struct {tag} {tag};", f"struct {tag} {{"]
    lines.extend(f"  {str(member)};" for member in cstruct.cargs)
    lines.append("};")
    return "\n".join(lines) + "\n"


def get_arg_type_string(_type: CDataType) -> str:
    """Return the C spelling of ``_type`` followed by one separator space.

    The trailing space matches the CType values, so callers can concatenate
    the result directly with a pointer marker or an identifier.

    Raises:
        TypeError: if ``_type`` is not a CType, CStruct or CEnum.
    """
    if isinstance(_type, CType):
        # CType values already carry their separator space ("u64 ").
        return str(_type)
    if isinstance(_type, CStruct):
        # cstruct_to_string emits `typedef struct X X;`, so the bare name is a type.
        return f"{_type.name} "
    if isinstance(_type, CEnum):
        # Only typedef'd enums can be named bare; a tagged enum needs `enum`.
        return f"{_type.name} " if _type.typedef else f"enum {_type.name} "
    raise TypeError(f"unsupported C data type: {_type!r}")


def get_includes_string(includes: list[CInclude]) -> str:
    """Render one ``#include`` line per entry, plus a trailing blank line.

    A CHeader entry contributes its ``name``. Local includes use quotes, the
    rest angle brackets. An empty list yields the empty string (no blank line).
    """
    lines = []
    for include in includes:
        target = include.header
        header_name = target.name if isinstance(target, CHeader) else target
        opener, closer = ('"', '"') if include.local else ("<", ">")
        lines.append(f"#include {opener}{header_name}{closer}\n")

    rendered = "".join(lines)
    return rendered + "\n" if rendered else rendered


# Generate the demo str.h / str.c only when run as a script, not on import.
if __name__ == "__main__":
    main()