Skip to content

Commit

Permalink
add half datatype support for directx backend (#245)
Browse files Browse the repository at this point in the history
* add half datatype support for directx backend

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix merge issue

* Fix again

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nripesh Niketan <[email protected]>
Co-authored-by: Nripesh Niketan <[email protected]>
  • Loading branch information
4 people authored Jan 4, 2025
1 parent a80766d commit 6f0d2a6
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 31 deletions.
57 changes: 29 additions & 28 deletions crosstl/backend/DirectX/DirectxLexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,6 @@
# using sets for faster lookup
SKIP_TOKENS = {"WHITESPACE", "COMMENT_SINGLE", "COMMENT_MULTI"}

# define keywords dictionary
KEYWORDS = {
"struct": "STRUCT",
"cbuffer": "CBUFFER",
"Texture2D": "TEXTURE2D",
"SamplerState": "SAMPLER_STATE",
"float": "FLOAT",
"float2": "FVECTOR",
"float3": "FVECTOR",
"float4": "FVECTOR",
"double": "DOUBLE",
"int": "INT",
"uint": "UINT",
"bool": "BOOL",
"void": "VOID",
"return": "RETURN",
"if": "IF",
"else": "ELSE",
"for": "FOR",
"while": "WHILE",
"do": "DO",
"register": "REGISTER",
"switch": "SWITCH",
"case": "CASE",
"default": "DEFAULT",
"break": "BREAK",
}

# use tuple for immutable token types that won't change
TOKENS = tuple(
[
Expand Down Expand Up @@ -103,10 +75,39 @@
("DEFAULT", r"\bdefault\b"),
("BREAK", r"\bbreak\b"),
("MOD", r"%"),
("HALF", r"\bhalf\b"),
("BITWISE_AND", r"&"),
]
)

KEYWORDS = {
"struct": "STRUCT",
"cbuffer": "CBUFFER",
"Texture2D": "TEXTURE2D",
"SamplerState": "SAMPLER_STATE",
"float": "FLOAT",
"float2": "FVECTOR",
"float3": "FVECTOR",
"float4": "FVECTOR",
"double": "DOUBLE",
"half": "HALF",
"int": "INT",
"uint": "UINT",
"bool": "BOOL",
"void": "VOID",
"return": "RETURN",
"if": "IF",
"else": "ELSE",
"for": "FOR",
"while": "WHILE",
"do": "DO",
"register": "REGISTER",
"switch": "SWITCH",
"case": "CASE",
"default": "DEFAULT",
"break": "BREAK",
}


class HLSLLexer:
def __init__(self, code: str):
Expand Down
15 changes: 12 additions & 3 deletions crosstl/backend/DirectX/DirectxParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def parse_shader(self):
"VOID",
"FLOAT",
"DOUBLE",
"HALF",
"FVECTOR",
"IDENTIFIER",
"TEXTURE2D",
Expand Down Expand Up @@ -198,6 +199,7 @@ def parse_statement(self):
if self.current_token[0] in [
"FLOAT",
"DOUBLE",
"HALF",
"FVECTOR",
"INT",
"UINT",
Expand All @@ -224,6 +226,7 @@ def parse_variable_declaration_or_assignment(self):
if self.current_token[0] in [
"FLOAT",
"DOUBLE",
"HALF",
"FVECTOR",
"INT",
"UINT",
Expand Down Expand Up @@ -355,7 +358,7 @@ def parse_for_statement(self):
self.eat("LPAREN")

# Parse initialization
if self.current_token[0] in ["INT", "FLOAT", "FVECTOR", "DOUBLE"]:
if self.current_token[0] in ["INT", "FLOAT", "FVECTOR", "DOUBLE", "HALF"]:
type_name = self.current_token[1]
self.eat(self.current_token[0])
var_name = self.current_token[1]
Expand Down Expand Up @@ -518,7 +521,13 @@ def parse_unary(self):
return self.parse_primary()

def parse_primary(self):
if self.current_token[0] in ["IDENTIFIER", "FLOAT", "FVECTOR", "DOUBLE"]:
if self.current_token[0] in [
"IDENTIFIER",
"FLOAT",
"FVECTOR",
"DOUBLE",
"HALF",
]:
if self.current_token[0] == "IDENTIFIER":
name = self.current_token[1]
self.eat("IDENTIFIER")
Expand All @@ -527,7 +536,7 @@ def parse_primary(self):
elif self.current_token[0] == "DOT":
return self.parse_member_access(name)
return VariableNode("", name)
if self.current_token[0] in ["FLOAT", "FVECTOR", "DOUBLE"]:
if self.current_token[0] in ["FLOAT", "FVECTOR", "DOUBLE", "HALF"]:
type_name = self.current_token[1]
self.eat(self.current_token[0])
if self.current_token[0] == "LPAREN":
Expand Down
25 changes: 25 additions & 0 deletions tests/test_backend/test_directx/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,5 +627,30 @@ def test_double_dtype_codegen():
pytest.fail("double dtype parsing or code generation not implemented.")


def test_half_dtype_codegen():
code = """
PSOutput PSMain(PSInput input) {
PSOutput output;
output.out_color = float4(0.0, 0.0, 0.0, 1.0);
half value1 = 3.14159; // First half value
half value2 = 2.71828; // Second half value
half result = value1 + value2; // Adding them
if (result > 6.0) {
output.out_color = float4(1.0, 0.0, 0.0, 1.0); // Set color to red
} else {
output.out_color = float4(0.0, 1.0, 0.0, 1.0); // Set color to green
}
return output;
}
"""
try:
tokens = tokenize_code(code)
ast = parse_code(tokens)
generated_code = generate_code(ast)
print(generated_code)
except SyntaxError:
pytest.fail("half dtype parsing or code generation not implemented.")


if __name__ == "__main__":
pytest.main()
22 changes: 22 additions & 0 deletions tests/test_backend/test_directx/test_lexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,5 +277,27 @@ def test_mod_tokenization():
assert has_mod, "Modulus operator (%) not tokenized correctly"


def test_half_dtype_tokenization():
code = """
PSOutput PSMain(PSInput input) {
PSOutput output;
output.out_color = float4(0.0, 0.0, 0.0, 1.0);
half value1 = 3.14159; // First half value
half value2 = 2.71828; // Second half value
half result = value1 + value2; // Adding them
if (result > 6.0) {
output.out_color = float4(1.0, 0.0, 0.0, 1.0); // Set color to red
} else {
output.out_color = float4(0.0, 1.0, 0.0, 1.0); // Set color to green
}
return output;
}
"""
try:
tokenize_code(code)
except SyntaxError:
pytest.fail("half dtype tokenization is not implemented.")


if __name__ == "__main__":
pytest.main()
23 changes: 23 additions & 0 deletions tests/test_backend/test_directx/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,5 +392,28 @@ def test_mod_parsing():
pytest.fail("Modulus operator parsing not implemented")


def test_double_dtype_parsing():
code = """
PSOutput PSMain(PSInput input) {
PSOutput output;
output.out_color = float4(0.0, 0.0, 0.0, 1.0);
half value1 = 3.14159; // First half value
half value2 = 2.71828; // Second half value
half result = value1 + value2; // Adding them
if (result > 6.0) {
output.out_color = float4(1.0, 0.0, 0.0, 1.0); // Set color to red
} else {
output.out_color = float4(0.0, 1.0, 0.0, 1.0); // Set color to green
}
return output;
}
"""
try:
tokens = tokenize_code(code)
parse_code(tokens)
except SyntaxError:
pytest.fail("half dtype not implemented.")


if __name__ == "__main__":
pytest.main()

0 comments on commit 6f0d2a6

Please sign in to comment.