From 2bc4d3bd5e2e412db7410aa27bf30196168ac356 Mon Sep 17 00:00:00 2001 From: Abdelrahman Date: Sat, 8 Feb 2025 16:09:08 +0000 Subject: [PATCH] Improve datatypes codegen --- codegen/datatypes.py | 169 +++++++++++++++++++++---------------------- 1 file changed, 82 insertions(+), 87 deletions(-) diff --git a/codegen/datatypes.py b/codegen/datatypes.py index 35688a1..b7168ce 100644 --- a/codegen/datatypes.py +++ b/codegen/datatypes.py @@ -5,25 +5,25 @@ from dataclasses import dataclass, field class CType(Enum): - VOID = "void " - BOOL = "bool " - CHAR = "char " - C8 = "c8 " - C16 = "c16 " - C32 = "c32 " - I8 = "i8 " - I16 = "i16 " - I32 = "i32 " - I64 = "i64 " - U8 = "u8 " - U16 = "u16 " - U32 = "u32 " - U64 = "u64 " - F32 = "f32 " - F64 = "f64 " - F128 = "f128 " - IPTR = "iptr " - UPTR = "uptr " + VOID = "void" + BOOL = "bool" + CHAR = "char" + C8 = "c8" + C16 = "c16" + C32 = "c32" + I8 = "i8" + I16 = "i16" + I32 = "i32" + I64 = "i64" + U8 = "u8" + U16 = "u16" + U32 = "u32" + U64 = "u64" + F32 = "f32" + F64 = "f64" + F128 = "f128" + IPTR = "iptr" + UPTR = "uptr" def __str__(self) -> str: return self.value @@ -73,16 +73,41 @@ class CEnum: values: list[CEnumVal] typedef: bool = False + def __str__(self) -> str: + if self.typedef: + header = "typedef enum {\n" + footer = f"}} {self.name};\n" + else: + header = f"enum {self.name} {{\n" + footer = "};\n" + + values = "" + for value in self.values: + values += f" {str(value)},\n" + + return header + values + footer + @dataclass class CStruct: name: str cargs: list["CArg"] + def __str__(self) -> str: + typedef = f"typedef struct {self.name} {self.name};\n" + header = f"struct {self.name} {{\n" + args = "" + for arg in self.cargs: + args += f" {str(arg)};\n" + footer = "};\n" + + return typedef + header + args + footer; + CUserType = Union[CStruct, CEnum] CDataType = Union[CType, CUserType] + @dataclass class CArg: name: str @@ -93,7 +118,7 @@ class CArg: def __str__(self) -> str: qualifier = str(self.qualifier) - _type = get_arg_type_string(self._type) + _type = _get_arg_type_string(self._type) pointer = str(self.pointer) array = "[]" if self.array else "" @@ -122,13 +147,13 @@ class CFunc: if i + 1 < len(self.args): args += ", " - return qualifiers + str(self.ret_type) + str(self.pointer) + self.name + f"({args})" + return qualifiers + str(self.ret_type) + " " + str(self.pointer) + self.name + f"({args})" def declare(self) -> str: return f"{str(self)};\n" def define(self) -> str: - return f"{str(self)} {{\n{self.body}\n}}\n" + return f"{str(self)} {{\n{self.body}\n}}\n\n" @dataclass @@ -136,13 +161,29 @@ class CInclude: header: Union[str, "CHeader"] local: bool = False + def __str__(self) -> str: + if isinstance(self.header, CHeader): + name = f"{self.header.name}.{self.header.extension}" + else: + name = self.header + + if self.local: + open_symbol = '"' + close_symbol = '"' + else: + open_symbol = '<' + close_symbol = '>' + + return f"#include {open_symbol}{name}{close_symbol}\n" + @dataclass class CFile: name: str + extension: str def save(self, output_dir: Path): - output_file = output_dir / self.name + output_file = output_dir / f"{self.name}.{self.extension}" with open(output_file, "w+") as outfile: outfile.write(str(self)) @@ -152,18 +193,19 @@ class CFile: @dataclass class CHeader(CFile): - includes: list[CInclude] - types: list[CUserType] - funcs: list[CFunc] + extension: str = "h" + includes: list[CInclude] = field(default_factory=list) + types: list[CUserType] = field(default_factory=list) + funcs: list[CFunc] = field(default_factory=list) def __str__(self) -> str: pragma = "#pragma once\n\n" - includes = get_includes_string(self.includes) + includes = _get_includes_string(self.includes) types = "" for _type in self.types: - types += get_user_type_string(_type) + types += str(_type) + "\n" funcs = "" for func in self.funcs: @@ -174,17 +216,18 @@ class CHeader(CFile): @dataclass class CSource(CFile): - includes: list[CInclude] - types: list[CUserType] - internal_funcs: list[CFunc] - funcs: list[CFunc] + extension: str = "c" + includes: list[CInclude] = field(default_factory=list) + types: list[CUserType] = field(default_factory=list) + internal_funcs: list[CFunc] = field(default_factory=list) + funcs: list[CFunc] = field(default_factory=list) def __str__(self) -> str: - includes = get_includes_string(self.includes) + includes = _get_includes_string(self.includes) types = "" for _type in self.types: - types += get_user_type_string(_type) + types += str(_type) + "\n" internal_funcs_decl = "" internal_funcs_def = "" @@ -199,65 +242,17 @@ class CSource(CFile): return includes + types + internal_funcs_decl + funcs + internal_funcs_def -def get_user_type_string(_type: Union[CStruct, CEnum]) -> str: - type_str = "" - if isinstance(_type, CStruct): - type_str += cstruct_to_string(_type) + "\n" - else: - type_str += cenum_to_string(_type) + "\n" - - return type_str - - -def cenum_to_string(cenum: CEnum) -> str: - if cenum.typedef: - header = "typedef enum {\n" - footer = f"}} {cenum.name};\n" - else: - header = f"enum {cenum.name} {{\n" - footer = "};\n" - - values = "" - for value in cenum.values: - values += f" {str(value)},\n" - - return header + values + footer - - -def cstruct_to_string(cstruct: CStruct) -> str: - typedef = f"typedef struct {cstruct.name} {cstruct.name};\n" - header = f"struct {cstruct.name} {{\n" - args = "" - for arg in cstruct.cargs: - args += f" {str(arg)};\n" - footer = "};\n" - - return typedef + header + args + footer; - - -def get_arg_type_string(_type: CDataType) -> str: +def _get_arg_type_string(_type: CDataType) -> str: if isinstance(_type, CType): - return str(_type) + return str(_type) + " " elif isinstance(_type, CStruct) or isinstance(_type, CEnum): - return _type.name + return _type.name + " " -def get_includes_string(includes: list[CInclude]) -> str: +def _get_includes_string(includes: list[CInclude]) -> str: output = "" for include in includes: - if isinstance(include.header, CHeader): - name = include.header.name - else: - name = include.header - - if include.local: - open_symbol = '"' - close_symbol = '"' - else: - open_symbol = '<' - close_symbol = '>' - - output += f"#include {open_symbol}{name}{close_symbol}\n" + output += str(include) if len(output) > 0: output += "\n"