From 6f0d2a68198ea398d1ac0f54b4d0d360c121d1c1 Mon Sep 17 00:00:00 2001 From: Yogesh <126279793+themaverick@users.noreply.github.com> Date: Sat, 4 Jan 2025 19:24:19 +0530 Subject: [PATCH] add half datatype support for directx backend (#245) * add half datatype support for directx backend * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix merge issue * Fix again --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Nripesh Niketan Co-authored-by: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> --- crosstl/backend/DirectX/DirectxLexer.py | 57 ++++++++++--------- crosstl/backend/DirectX/DirectxParser.py | 15 ++++- .../test_backend/test_directx/test_codegen.py | 25 ++++++++ tests/test_backend/test_directx/test_lexer.py | 22 +++++++ .../test_backend/test_directx/test_parser.py | 23 ++++++++ 5 files changed, 111 insertions(+), 31 deletions(-) diff --git a/crosstl/backend/DirectX/DirectxLexer.py b/crosstl/backend/DirectX/DirectxLexer.py index e6c4c28a..cf6835e0 100644 --- a/crosstl/backend/DirectX/DirectxLexer.py +++ b/crosstl/backend/DirectX/DirectxLexer.py @@ -5,34 +5,6 @@ # using sets for faster lookup SKIP_TOKENS = {"WHITESPACE", "COMMENT_SINGLE", "COMMENT_MULTI"} -# define keywords dictionary -KEYWORDS = { - "struct": "STRUCT", - "cbuffer": "CBUFFER", - "Texture2D": "TEXTURE2D", - "SamplerState": "SAMPLER_STATE", - "float": "FLOAT", - "float2": "FVECTOR", - "float3": "FVECTOR", - "float4": "FVECTOR", - "double": "DOUBLE", - "int": "INT", - "uint": "UINT", - "bool": "BOOL", - "void": "VOID", - "return": "RETURN", - "if": "IF", - "else": "ELSE", - "for": "FOR", - "while": "WHILE", - "do": "DO", - "register": "REGISTER", - "switch": "SWITCH", - "case": "CASE", - "default": "DEFAULT", - "break": "BREAK", -} - # use tuple for immutable token types that won't change TOKENS = tuple( [ @@ -103,10 +75,39 @@ ("DEFAULT", r"\bdefault\b"), ("BREAK", r"\bbreak\b"), ("MOD", r"%"), + ("HALF", r"\bhalf\b"), ("BITWISE_AND", r"&"), ] ) +KEYWORDS = { + "struct": "STRUCT", + "cbuffer": "CBUFFER", + "Texture2D": "TEXTURE2D", + "SamplerState": "SAMPLER_STATE", + "float": "FLOAT", + "float2": "FVECTOR", + "float3": "FVECTOR", + "float4": "FVECTOR", + "double": "DOUBLE", + "half": "HALF", + "int": "INT", + "uint": "UINT", + "bool": "BOOL", + "void": "VOID", + "return": "RETURN", + "if": "IF", + "else": "ELSE", + "for": "FOR", + "while": "WHILE", + "do": "DO", + "register": "REGISTER", + "switch": "SWITCH", + "case": "CASE", + "default": "DEFAULT", + "break": "BREAK", +} + class HLSLLexer: def __init__(self, code: str): diff --git a/crosstl/backend/DirectX/DirectxParser.py b/crosstl/backend/DirectX/DirectxParser.py index 1c469521..15c5aa77 100644 --- a/crosstl/backend/DirectX/DirectxParser.py +++ b/crosstl/backend/DirectX/DirectxParser.py @@ -61,6 +61,7 @@ def parse_shader(self): "VOID", "FLOAT", "DOUBLE", + "HALF", "FVECTOR", "IDENTIFIER", "TEXTURE2D", @@ -198,6 +199,7 @@ def parse_statement(self): if self.current_token[0] in [ "FLOAT", "DOUBLE", + "HALF", "FVECTOR", "INT", "UINT", @@ -224,6 +226,7 @@ def parse_variable_declaration_or_assignment(self): if self.current_token[0] in [ "FLOAT", "DOUBLE", + "HALF", "FVECTOR", "INT", "UINT", @@ -355,7 +358,7 @@ def parse_for_statement(self): self.eat("LPAREN") # Parse initialization - if self.current_token[0] in ["INT", "FLOAT", "FVECTOR", "DOUBLE"]: + if self.current_token[0] in ["INT", "FLOAT", "FVECTOR", "DOUBLE", "HALF"]: type_name = self.current_token[1] self.eat(self.current_token[0]) var_name = self.current_token[1] @@ -518,7 +521,13 @@ def parse_unary(self): return self.parse_primary() def parse_primary(self): - if self.current_token[0] in ["IDENTIFIER", "FLOAT", "FVECTOR", "DOUBLE"]: + if self.current_token[0] in [ + "IDENTIFIER", + "FLOAT", + "FVECTOR", + "DOUBLE", + "HALF", + ]: if self.current_token[0] == "IDENTIFIER": name = self.current_token[1] self.eat("IDENTIFIER") @@ -527,7 +536,7 @@ def parse_primary(self): elif self.current_token[0] == "DOT": return self.parse_member_access(name) return VariableNode("", name) - if self.current_token[0] in ["FLOAT", "FVECTOR", "DOUBLE"]: + if self.current_token[0] in ["FLOAT", "FVECTOR", "DOUBLE", "HALF"]: type_name = self.current_token[1] self.eat(self.current_token[0]) if self.current_token[0] == "LPAREN": diff --git a/tests/test_backend/test_directx/test_codegen.py b/tests/test_backend/test_directx/test_codegen.py index e82bc25c..e2d5cc8a 100644 --- a/tests/test_backend/test_directx/test_codegen.py +++ b/tests/test_backend/test_directx/test_codegen.py @@ -627,5 +627,30 @@ def test_double_dtype_codegen(): pytest.fail("double dtype parsing or code generation not implemented.") +def test_half_dtype_codegen(): + code = """ + PSOutput PSMain(PSInput input) { + PSOutput output; + output.out_color = float4(0.0, 0.0, 0.0, 1.0); + half value1 = 3.14159; // First half value + half value2 = 2.71828; // Second half value + half result = value1 + value2; // Adding them + if (result > 6.0) { + output.out_color = float4(1.0, 0.0, 0.0, 1.0); // Set color to red + } else { + output.out_color = float4(0.0, 1.0, 0.0, 1.0); // Set color to green + } + return output; + } + """ + try: + tokens = tokenize_code(code) + ast = parse_code(tokens) + generated_code = generate_code(ast) + print(generated_code) + except SyntaxError: + pytest.fail("half dtype parsing or code generation not implemented.") + + if __name__ == "__main__": pytest.main() diff --git a/tests/test_backend/test_directx/test_lexer.py b/tests/test_backend/test_directx/test_lexer.py index 661ec9dd..02f78a2a 100644 --- a/tests/test_backend/test_directx/test_lexer.py +++ b/tests/test_backend/test_directx/test_lexer.py @@ -277,5 +277,27 @@ def test_mod_tokenization(): assert has_mod, "Modulus operator (%) not tokenized correctly" +def test_half_dtype_tokenization(): + code = """ + PSOutput PSMain(PSInput input) { + PSOutput output; + output.out_color = float4(0.0, 0.0, 0.0, 1.0); + half value1 = 3.14159; // First half value + half value2 = 2.71828; // Second half value + half result = value1 + value2; // Adding them + if (result > 6.0) { + output.out_color = float4(1.0, 0.0, 0.0, 1.0); // Set color to red + } else { + output.out_color = float4(0.0, 1.0, 0.0, 1.0); // Set color to green + } + return output; + } + """ + try: + tokenize_code(code) + except SyntaxError: + pytest.fail("half dtype tokenization is not implemented.") + + if __name__ == "__main__": pytest.main() diff --git a/tests/test_backend/test_directx/test_parser.py b/tests/test_backend/test_directx/test_parser.py index abad0b42..7af3a86f 100644 --- a/tests/test_backend/test_directx/test_parser.py +++ b/tests/test_backend/test_directx/test_parser.py @@ -392,5 +392,28 @@ def test_mod_parsing(): pytest.fail("Modulus operator parsing not implemented") +def test_double_dtype_parsing(): + code = """ + PSOutput PSMain(PSInput input) { + PSOutput output; + output.out_color = float4(0.0, 0.0, 0.0, 1.0); + half value1 = 3.14159; // First half value + half value2 = 2.71828; // Second half value + half result = value1 + value2; // Adding them + if (result > 6.0) { + output.out_color = float4(1.0, 0.0, 0.0, 1.0); // Set color to red + } else { + output.out_color = float4(0.0, 1.0, 0.0, 1.0); // Set color to green + } + return output; + } + """ + try: + tokens = tokenize_code(code) + parse_code(tokens) + except SyntaxError: + pytest.fail("half dtype not implemented.") + + if __name__ == "__main__": pytest.main()