wizapp-stdlib/codegen/datatypes.py

320 lines
7.8 KiB
Python

from enum import Enum
from pathlib import Path
from typing import Optional, Union
from dataclasses import dataclass, field
class CType(Enum):
    """Scalar type names the code generator can emit.

    ``str(member)`` yields the exact spelling written into generated code.
    Apart from ``void``/``bool``/``char``, the short spellings (``i32``,
    ``u8``, ``f64``, ``iptr``, ...) are not C keywords; they are presumably
    typedefs provided by the wizapp stdlib headers -- TODO confirm.
    """

    VOID = "void"
    BOOL = "bool"
    CHAR = "char"
    C8 = "c8"
    C16 = "c16"
    C32 = "c32"
    I8 = "i8"
    I16 = "i16"
    I32 = "i32"
    I64 = "i64"
    U8 = "u8"
    U16 = "u16"
    U32 = "u32"
    U64 = "u64"
    F32 = "f32"
    F64 = "f64"
    F128 = "f128"
    IPTR = "iptr"
    UPTR = "uptr"

    def __str__(self) -> str:
        # Emit the C spelling instead of Enum's default "CType.NAME".
        spelling = self.value
        return str(spelling)
class CQualifier(Enum):
    """Qualifiers that can be placed in front of a declaration.

    Every non-empty value carries its own trailing space, so a qualifier can
    be concatenated directly in front of a type name; ``NONE`` renders as an
    empty string. ``external``/``internal``/``persistent`` are not C
    keywords -- presumably project macros, TODO confirm.
    """

    NONE = ""
    CONST = "const "
    EXTERNAL = "external "
    INTERNAL = "internal "
    PERSISTENT = "persistent "

    def __str__(self) -> str:
        spelling = self.value
        return str(spelling)
class CPointerType(Enum):
    """Level of pointer indirection in a declarator, rendered as asterisks."""

    NONE = ""
    SINGLE = "*"
    DOUBLE = "**"

    def __str__(self) -> str:
        stars = self.value
        return str(stars)
@dataclass
class CPointer:
    """Pointer part of a declarator: the asterisks, then an optional qualifier.

    For example a SINGLE pointer with CONST renders as ``"*const "``, which
    places the qualifier on the pointer itself rather than on the pointee.
    """

    _type: CPointerType = CPointerType.NONE
    qualifier: CQualifier = CQualifier.NONE

    def __str__(self) -> str:
        # !s forces str() on both enum members.
        return f"{self._type!s}{self.qualifier!s}"
@dataclass
class CEnumVal:
    """A single enumerator inside a C enum.

    Renders as ``NAME`` when ``value`` is None (the C compiler assigns the
    value implicitly), otherwise as ``NAME = value``.
    """

    name: str
    value: Optional[int] = None

    def __str__(self) -> str:
        # The name is always emitted; the initializer is appended only when an
        # explicit value was given. Compare against None so 0 still counts.
        if self.value is None:
            return self.name
        return f"{self.name} = {self.value}"
@dataclass
class CEnum:
    """A C enumeration.

    With ``typedef`` set, renders as an anonymous ``typedef enum { ... } Name;``.
    Otherwise renders as a tagged ``enum Name { ... };``. Each enumerator is
    written on its own line followed by a comma.
    """

    name: str
    values: list[CEnumVal]
    typedef: bool = False

    def __str__(self) -> str:
        body = "".join(f" {str(value)},\n" for value in self.values)
        if not self.typedef:
            return f"enum {self.name} {{\n" + body + "};\n"
        return "typedef enum {\n" + body + f"}} {self.name};\n"
@dataclass
class CStruct:
name: str
cargs: list["CArg"]
typedef_name: str | None = None
def __str__(self) -> str:
return self.declare() + self.define()
def declare(self) -> str:
declaration = f"typedef struct {self.name} {self.typedef_name if self.typedef_name is not None else self.name};\n"
return declaration
def define(self):
definition = f"struct {self.name} {{\n"
args = ""
for arg in self.cargs:
args += f" {str(arg)};\n"
footer = "};\n"
return definition + args + footer;
# User-defined C types the generator can both emit and refer to by name.
CUserType = Union[CStruct, CEnum]
# Anything accepted where a C type is expected: a built-in CType member, a
# user-defined struct/enum, or a raw type-name string used as-is.
CDataType = Union[CType, CUserType, str]
@dataclass
class CArg:
    """A typed, named declarator, used for function parameters and struct members.

    Renders as ``<qualifier><type> <pointer><name>`` followed by ``[]`` when
    ``array`` is set (an unsized array).
    """

    name: str
    _type: CDataType
    array: bool = False
    pointer: CPointer = field(default_factory=CPointer)
    qualifier: CQualifier = CQualifier.NONE

    def __str__(self) -> str:
        head = str(self.qualifier)
        head += get_datatype_string(self._type) + " "
        head += str(self.pointer)
        suffix = "[]" if self.array else ""
        return head + self.name + suffix
@dataclass
class CFunc:
    """A C function that can be rendered as a prototype or as a definition."""

    name: str
    ret_type: CDataType
    args: list[CArg]
    body: str  # inserted verbatim between the braces by define()
    pointer: CPointer = field(default_factory=CPointer)  # pointer part of the return type
    qualifiers: list[CQualifier] = field(default_factory=list)

    def __str__(self) -> str:
        """Return the signature ``<qualifiers><ret> <pointer><name>(<args>)``.

        No terminating ``;`` or body is included; see ``declare``/``define``.
        """
        # Each non-empty CQualifier value already ends in a space, so the
        # qualifiers are concatenated with no extra separator; adding one would
        # emit double spaces such as "internal  const ". NONE renders as "".
        qualifiers = "".join(str(qualifier) for qualifier in self.qualifiers)
        args = ", ".join(str(arg) for arg in self.args)
        return (
            qualifiers
            + get_datatype_string(self.ret_type)
            + " "
            + str(self.pointer)
            + self.name
            + f"({args})"
        )

    def declare(self) -> str:
        """Return the prototype: the signature followed by ``;`` and a newline."""
        return f"{str(self)};\n"

    def define(self) -> str:
        """Return the signature, then ``body`` inside braces, then a blank line."""
        return f"{str(self)} {{\n{self.body}\n}}\n\n"
@dataclass
class CInclude:
    """An ``#include`` directive.

    ``header`` is either a literal header path or a CHeader, in which case
    ``<name>.<extension>`` is used. ``local`` selects quotes instead of angle
    brackets, and ``same_dir`` (only honoured for local includes) prefixes
    the path with ``./``.
    """

    header: Union[str, "CHeader"]
    local: bool = False
    same_dir: bool = False

    def __str__(self) -> str:
        target = (
            f"{self.header.name}.{self.header.extension}"
            if isinstance(self.header, CHeader)
            else self.header
        )
        if not self.local:
            return f"#include <{target}>\n"
        if self.same_dir:
            target = f"./{target}"
        return f'#include "{target}"\n'
@dataclass
class CFile:
    """Base class for generated C files.

    Holds the output file name and extension. The base ``__str__`` yields
    only the auto-generated banner comment; subclasses prepend it to their
    own content.
    """

    name: str
    extension: str
    # Structs whose typedef'd forward declaration (CStruct.declare) subclasses emit.
    decl_types: list[CStruct] = field(default_factory=list)

    def save(self, output_dir: Path):
        """Write the rendered file to ``output_dir/<name>.<extension>``.

        Any existing file is overwritten. ``output_dir`` may be a ``Path`` or a
        ``str`` and must already exist; ``open`` raises FileNotFoundError
        otherwise.
        """
        output_file = Path(output_dir) / f"{self.name}.{self.extension}"
        # Explicit UTF-8 so the generated bytes do not depend on the host
        # locale; plain "w" since the handle is never read back.
        with open(output_file, "w", encoding="utf-8") as outfile:
            outfile.write(str(self))

    def __str__(self) -> str:
        return """\
/**
* THIS FILE IS AUTOMATICALLY GENERATED. ANY MODIFICATIONS TO IT WILL BE OVERWRITTEN
*/
"""
@dataclass
class CHeader(CFile):
    """A generated C header (``.h`` by default).

    Output order: banner, include guard open, includes, C-linkage open,
    forward declarations, types, function prototypes, C-linkage close,
    include guard close. BEGIN_C_LINKAGE/END_C_LINKAGE are presumably macros
    supplied by one of the included project headers -- TODO confirm.
    """

    extension: str = "h"
    includes: list[CInclude] = field(default_factory=list)
    types: list[CUserType] = field(default_factory=list)
    funcs: list[CFunc] = field(default_factory=list)

    def __str__(self) -> str:
        # Map every character that is not legal in a C identifier to "_" so
        # file names such as "my-lib" or "gfx.api" still yield a valid macro.
        guard_stem = "".join(
            ch if (ch.isascii() and ch.isalnum()) or ch == "_" else "_"
            for ch in self.name.upper()
        )
        header_guard_name = f"{guard_stem}_H"
        header_guard_open = f"#ifndef {header_guard_name}\n#define {header_guard_name}\n\n"
        header_guard_close = f"#endif // !{header_guard_name}\n"
        c_linkage_open = "#ifdef __cplusplus\nBEGIN_C_LINKAGE\n#endif // !__cplusplus\n\n"
        c_linkage_close = "\n#ifdef __cplusplus\nEND_C_LINKAGE\n#endif // !__cplusplus\n\n"

        includes = _get_includes_string(self.includes)
        forward_declarations = "".join(t.declare() for t in self.decl_types)
        if forward_declarations:
            forward_declarations += "\n"
        types = "".join(str(t) + "\n" for t in self.types)
        funcs = "".join(func.declare() for func in self.funcs)

        return (
            super().__str__()
            + header_guard_open
            + includes
            + c_linkage_open
            + forward_declarations
            + types
            + funcs
            + c_linkage_close
            + header_guard_close
        )
@dataclass
class CSource(CFile):
    """A generated C source file (``.c`` by default).

    Output order: banner, includes, forward declarations, types, prototypes
    of the internal functions, definitions of the public functions, and
    finally definitions of the internal functions. Prototyping the internal
    functions up front lets the public definitions call them.
    """

    extension: str = "c"
    includes: list[CInclude] = field(default_factory=list)
    types: list[CUserType] = field(default_factory=list)
    internal_funcs: list[CFunc] = field(default_factory=list)
    funcs: list[CFunc] = field(default_factory=list)

    def __str__(self) -> str:
        forward = "".join(t.declare() for t in self.decl_types)
        if forward:
            forward += "\n"
        type_defs = "".join(str(t) + "\n" for t in self.types)

        private_protos = "".join(fn.declare() for fn in self.internal_funcs)
        if private_protos:
            private_protos += "\n"
        private_bodies = "".join(fn.define() for fn in self.internal_funcs)
        public_bodies = "".join(fn.define() for fn in self.funcs)

        return "".join((
            super().__str__(),
            _get_includes_string(self.includes),
            forward,
            type_defs,
            private_protos,
            public_bodies,
            private_bodies,
        ))
def get_datatype_string(_type: CDataType) -> str:
    """Return the C spelling used to refer to ``_type`` in generated code.

    - CType members render as their value (e.g. ``i32``).
    - A CStruct renders as its typedef alias (``typedef_name``, falling back
      to ``name``), because CStruct.declare always introduces that alias and
      a bare struct tag is not a type name in C.
    - A typedef'd CEnum renders as its name; an untypedef'd one as
      ``enum <name>``, since only the tag exists in that case.
    - Plain strings are passed through unchanged.

    Raises:
        TypeError: if ``_type`` is none of the supported kinds.
    """
    if isinstance(_type, CType):
        return str(_type)
    if isinstance(_type, CStruct):
        return _type.name if _type.typedef_name is None else _type.typedef_name
    if isinstance(_type, CEnum):
        return _type.name if _type.typedef else f"enum {_type.name}"
    if isinstance(_type, str):
        return _type
    # Fail here with a clear message instead of returning None, which callers
    # would only trip over later when concatenating the result.
    raise TypeError(f"unsupported C data type: {_type!r}")
def _get_includes_string(includes: list[CInclude]) -> str:
    """Render include directives, local ("...") before system (<...>) ones.

    ``sorted`` is stable even with ``reverse=True``, so the relative order
    within each group is preserved. A blank line follows the directives when
    there are any; an empty input yields an empty string.
    """
    ordered = sorted(includes, key=lambda inc: inc.local, reverse=True)
    rendered = "".join(map(str, ordered))
    return f"{rendered}\n" if rendered else rendered