Improve datatypes codegen

This commit is contained in:
Abdelrahman Said 2025-02-08 16:09:08 +00:00
parent a8f989b76d
commit 2bc4d3bd5e

View File

@ -5,25 +5,25 @@ from dataclasses import dataclass, field
class CType(Enum): class CType(Enum):
VOID = "void " VOID = "void"
BOOL = "bool " BOOL = "bool"
CHAR = "char " CHAR = "char"
C8 = "c8 " C8 = "c8"
C16 = "c16 " C16 = "c16"
C32 = "c32 " C32 = "c32"
I8 = "i8 " I8 = "i8"
I16 = "i16 " I16 = "i16"
I32 = "i32 " I32 = "i32"
I64 = "i64 " I64 = "i64"
U8 = "u8 " U8 = "u8"
U16 = "u16 " U16 = "u16"
U32 = "u32 " U32 = "u32"
U64 = "u64 " U64 = "u64"
F32 = "f32 " F32 = "f32"
F64 = "f64 " F64 = "f64"
F128 = "f128 " F128 = "f128"
IPTR = "iptr " IPTR = "iptr"
UPTR = "uptr " UPTR = "uptr"
def __str__(self) -> str: def __str__(self) -> str:
return self.value return self.value
@ -73,16 +73,41 @@ class CEnum:
values: list[CEnumVal] values: list[CEnumVal]
typedef: bool = False typedef: bool = False
def __str__(self) -> str:
if self.typedef:
header = "typedef enum {\n"
footer = f"}} {self.name};\n"
else:
header = f"enum {self.name} {{\n"
footer = "};\n"
values = ""
for value in self.values:
values += f" {str(value)},\n"
return header + values + footer
@dataclass @dataclass
class CStruct: class CStruct:
name: str name: str
cargs: list["CArg"] cargs: list["CArg"]
def __str__(self) -> str:
typedef = f"typedef struct {self.name} {self.name};\n"
header = f"struct {self.name} {{\n"
args = ""
for arg in self.cargs:
args += f" {str(arg)};\n"
footer = "};\n"
return typedef + header + args + footer;
CUserType = Union[CStruct, CEnum] CUserType = Union[CStruct, CEnum]
CDataType = Union[CType, CUserType] CDataType = Union[CType, CUserType]
@dataclass @dataclass
class CArg: class CArg:
name: str name: str
@ -93,7 +118,7 @@ class CArg:
def __str__(self) -> str: def __str__(self) -> str:
qualifier = str(self.qualifier) qualifier = str(self.qualifier)
_type = get_arg_type_string(self._type) _type = _get_arg_type_string(self._type)
pointer = str(self.pointer) pointer = str(self.pointer)
array = "[]" if self.array else "" array = "[]" if self.array else ""
@ -122,13 +147,13 @@ class CFunc:
if i + 1 < len(self.args): if i + 1 < len(self.args):
args += ", " args += ", "
return qualifiers + str(self.ret_type) + str(self.pointer) + self.name + f"({args})" return qualifiers + str(self.ret_type) + " " + str(self.pointer) + self.name + f"({args})"
def declare(self) -> str: def declare(self) -> str:
return f"{str(self)};\n" return f"{str(self)};\n"
def define(self) -> str: def define(self) -> str:
return f"{str(self)} {{\n{self.body}\n}}\n" return f"{str(self)} {{\n{self.body}\n}}\n\n"
@dataclass @dataclass
@ -136,13 +161,29 @@ class CInclude:
header: Union[str, "CHeader"] header: Union[str, "CHeader"]
local: bool = False local: bool = False
def __str__(self) -> str:
if isinstance(self.header, CHeader):
name = f"{self.header.name}.{self.header.extension}"
else:
name = self.header
if self.local:
open_symbol = '"'
close_symbol = '"'
else:
open_symbol = '<'
close_symbol = '>'
return f"#include {open_symbol}{name}{close_symbol}\n"
@dataclass @dataclass
class CFile: class CFile:
name: str name: str
extension: str
def save(self, output_dir: Path): def save(self, output_dir: Path):
output_file = output_dir / self.name output_file = output_dir / f"{self.name}.{self.extension}"
with open(output_file, "w+") as outfile: with open(output_file, "w+") as outfile:
outfile.write(str(self)) outfile.write(str(self))
@ -152,18 +193,19 @@ class CFile:
@dataclass @dataclass
class CHeader(CFile): class CHeader(CFile):
includes: list[CInclude] extension: str = "h"
types: list[CUserType] includes: list[CInclude] = field(default_factory=list)
funcs: list[CFunc] types: list[CUserType] = field(default_factory=list)
funcs: list[CFunc] = field(default_factory=list)
def __str__(self) -> str: def __str__(self) -> str:
pragma = "#pragma once\n\n" pragma = "#pragma once\n\n"
includes = get_includes_string(self.includes) includes = _get_includes_string(self.includes)
types = "" types = ""
for _type in self.types: for _type in self.types:
types += get_user_type_string(_type) types += str(_type) + "\n"
funcs = "" funcs = ""
for func in self.funcs: for func in self.funcs:
@ -174,17 +216,18 @@ class CHeader(CFile):
@dataclass @dataclass
class CSource(CFile): class CSource(CFile):
includes: list[CInclude] extension: str = "c"
types: list[CUserType] includes: list[CInclude] = field(default_factory=list)
internal_funcs: list[CFunc] types: list[CUserType] = field(default_factory=list)
funcs: list[CFunc] internal_funcs: list[CFunc] = field(default_factory=list)
funcs: list[CFunc] = field(default_factory=list)
def __str__(self) -> str: def __str__(self) -> str:
includes = get_includes_string(self.includes) includes = _get_includes_string(self.includes)
types = "" types = ""
for _type in self.types: for _type in self.types:
types += get_user_type_string(_type) types += str(_type) + "\n"
internal_funcs_decl = "" internal_funcs_decl = ""
internal_funcs_def = "" internal_funcs_def = ""
@ -199,65 +242,17 @@ class CSource(CFile):
return includes + types + internal_funcs_decl + funcs + internal_funcs_def return includes + types + internal_funcs_decl + funcs + internal_funcs_def
def get_user_type_string(_type: Union[CStruct, CEnum]) -> str: def _get_arg_type_string(_type: CDataType) -> str:
type_str = ""
if isinstance(_type, CStruct):
type_str += cstruct_to_string(_type) + "\n"
else:
type_str += cenum_to_string(_type) + "\n"
return type_str
def cenum_to_string(cenum: CEnum) -> str:
if cenum.typedef:
header = "typedef enum {\n"
footer = f"}} {cenum.name};\n"
else:
header = f"enum {cenum.name} {{\n"
footer = "};\n"
values = ""
for value in cenum.values:
values += f" {str(value)},\n"
return header + values + footer
def cstruct_to_string(cstruct: CStruct) -> str:
typedef = f"typedef struct {cstruct.name} {cstruct.name};\n"
header = f"struct {cstruct.name} {{\n"
args = ""
for arg in cstruct.cargs:
args += f" {str(arg)};\n"
footer = "};\n"
return typedef + header + args + footer;
def get_arg_type_string(_type: CDataType) -> str:
if isinstance(_type, CType): if isinstance(_type, CType):
return str(_type) return str(_type) + " "
elif isinstance(_type, CStruct) or isinstance(_type, CEnum): elif isinstance(_type, CStruct) or isinstance(_type, CEnum):
return _type.name return _type.name + " "
def get_includes_string(includes: list[CInclude]) -> str: def _get_includes_string(includes: list[CInclude]) -> str:
output = "" output = ""
for include in includes: for include in includes:
if isinstance(include.header, CHeader): output += str(include)
name = include.header.name
else:
name = include.header
if include.local:
open_symbol = '"'
close_symbol = '"'
else:
open_symbol = '<'
close_symbol = '>'
output += f"#include {open_symbol}{name}{close_symbol}\n"
if len(output) > 0: if len(output) > 0:
output += "\n" output += "\n"