Skip to content

Commit

Permalink
Added support for #include (#220)
Browse files Browse the repository at this point in the history
* feat: added support for #input

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
plon-Susk7 and pre-commit-ci[bot] authored Dec 26, 2024
1 parent 4dec777 commit 6fce0a1
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 0 deletions.
11 changes: 11 additions & 0 deletions crosstl/backend/DirectX/DirectxAst.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,14 @@ def __repr__(self):

def __str__(self):
return f"({self.op}{self.operand})"


class IncludeNode(ASTNode):
def __init__(self, path):
self.path = path

def __repr__(self):
return f"IncludeNode(path={self.path})"

def __str__(self):
return f"#include {self.path}"
2 changes: 2 additions & 0 deletions crosstl/backend/DirectX/DirectxCrossGLCodeGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def generate(self, ast):
for member in node.members:
code += f" {self.map_type(member.vtype)} {member.name} {self.map_semantic(member.semantic)};\n"
code += " }\n"
elif isinstance(node, IncludeNode):
code += f" #include {node.path}\n"
# Generate global variables
for node in ast.global_variables:
code += f" {self.map_type(node.vtype)} {node.name};\n"
Expand Down
4 changes: 4 additions & 0 deletions crosstl/backend/DirectX/DirectxLexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
TOKENS = [
("COMMENT_SINGLE", r"//.*"),
("COMMENT_MULTI", r"/\*[\s\S]*?\*/"),
("INCLUDE", r"\#include\b"),
("STRUCT", r"\bstruct\b"),
("CBUFFER", r"\bcbuffer\b"),
("TEXTURE2D", r"\bTexture2D\b"),
Expand Down Expand Up @@ -60,6 +61,7 @@
("MINUS", r"-"),
("EQUALS", r"="),
("WHITESPACE", r"\s+"),
("STRING", r"\"[^\"]*\""), # Added for string literals
]

KEYWORDS = {
Expand Down Expand Up @@ -95,8 +97,10 @@ def tokenize(self):
pos = 0
while pos < len(self.code):
match = None

for token_type, pattern in TOKENS:
regex = re.compile(pattern)

match = regex.match(self.code, pos)
if match:
text = match.group(0)
Expand Down
10 changes: 10 additions & 0 deletions crosstl/backend/DirectX/DirectxParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
UnaryOpNode,
VariableNode,
VectorConstructorNode,
IncludeNode,
)
from .DirectxLexer import HLSLLexer

Expand Down Expand Up @@ -66,11 +67,20 @@ def parse_shader(self):
functions.append(self.parse_function())
else:
global_variables.append(self.parse_global_variable())
elif self.current_token[0] == "INCLUDE":
structs.append(self.parse_include())

else:
self.eat(self.current_token[0]) # Skip unknown tokens

return ShaderNode(structs, functions, global_variables, cbuffers)

def parse_include(self):
self.eat("INCLUDE")
path = self.current_token[1]
self.eat("STRING")
return IncludeNode(path)

def is_function(self):
current_pos = self.pos
while self.tokens[current_pos][0] != "EOF":
Expand Down
41 changes: 41 additions & 0 deletions tests/test_backend/test_directx/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,5 +498,46 @@ def test_bitwise_ops_codgen():
pytest.fail("bitwise_op parsing or codegen not implemented.")


def test_include_codegen():
code = """
#include "common.hlsl"
struct VSInput {
float4 position : POSITION;
float4 color : TEXCOORD0;
};
struct VSOutput {
float4 out_position : TEXCOORD0;
};
VSOutput VSMain(VSInput input) {
VSOutput output;
output.out_position = input.position;
return output;
}
struct PSInput {
float4 in_position : TEXCOORD0;
};
struct PSOutput {
float4 out_color : SV_TARGET0;
};
PSOutput PSMain(PSInput input) {
PSOutput output;
output.out_color = input.in_position;
return output;
}
"""
try:
tokens = tokenize_code(code)
ast = parse_code(tokens)
generated_code = generate_code(ast)
print(generated_code)
except SyntaxError:
pytest.fail("Include statement failed to parse or generate code.")


if __name__ == "__main__":
pytest.main()

0 comments on commit 6fce0a1

Please sign in to comment.