Skip to content

Commit

Permalink
fix:ast PR
Browse files Browse the repository at this point in the history
  • Loading branch information
samthakur587 committed Jul 31, 2024
1 parent 3df4ad9 commit 53f539c
Show file tree
Hide file tree
Showing 10 changed files with 214 additions and 250 deletions.
4 changes: 2 additions & 2 deletions crosstl/src/backend/DirectX/DirectxCrossGLCodeGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def generate_expression(self, expr, is_main=False):
right = self.generate_expression(expr.right, is_main)
return f"({left} {expr.op} {right})"
elif isinstance(expr, UnaryOpNode):
operand = self.generate_expression(expr.operand, is_main)
return f"({expr.operator}{operand})"
operand = self.generate_expression(expr.op, is_main)
return f"({expr.op}{operand})"
elif isinstance(expr, FunctionCallNode):
args = ", ".join(
self.generate_expression(arg, is_main) for arg in expr.args
Expand Down
36 changes: 2 additions & 34 deletions crosstl/src/backend/Opengl/OpenglLexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
("INT", r"int"),
("SAMPLER2D", r"sampler2D"),
("PRE_INCREMENT", r"\+\+(?=\w)"), # Lookahead to match pre-increment
("PRE_DECREMENT", r"--(?=\w)"), # Lookahead to match pre-decrement
("PRE_DECREMENT", r"--(?=\w)"), # Lookahead to match pre-decrement
("POST_INCREMENT", r"(?<=\w)\+\+"), # Lookbehind to match post-increment
("POST_DECREMENT", r"(?<=\w)--"), # Lookbehind to match post-decrement
("POST_DECREMENT", r"(?<=\w)--"), # Lookbehind to match post-decrement
("IDENTIFIER", r"[a-zA-Z_][a-zA-Z_0-9]*"),
("LBRACE", r"\{"),
("RBRACE", r"\}"),
Expand Down Expand Up @@ -116,35 +116,3 @@ def tokenize(self):
)

self.tokens.append(("EOF", None)) # End of file token


if __name__ == "__main__":
code = """
#version 330 core
// Vertex Shader
layout (location = 0) in vec3 position;
layout (location = 1) in vec2 texCoord;
out vec2 fragTexCoord;
void main() {
gl_Position = vec4(position, 1.0);
fragTexCoord = texCoord;
}
// Fragment Shader
in vec2 fragTexCoord;
out vec4 color;
uniform sampler2D textureSampler;
void main() {
color = texture(textureSampler, fragTexCoord);
}
"""
lexer = Lexer(code)
for token in lexer.tokens:
print(token)
137 changes: 60 additions & 77 deletions crosstl/src/backend/Opengl/OpenglParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@
)
from .OpenglLexer import Lexer


class Parser:
def __init__(self, tokens):
self.tokens = tokens
self.pos = 0
self.current_token = self.tokens[self.pos]



def skip_comments(self):
while self.current_token[0] in ["COMMENT_MULTI"]:
self.eat(self.current_token[0])
Expand Down Expand Up @@ -62,41 +61,59 @@ def parse_version_directive(self):
self.eat("CORE")
return VersionDirectiveNode(number, version_identifier)
else:
raise SyntaxError(f"Expected NUMBER after VERSION, got {self.current_token[0]}")
raise SyntaxError(
f"Expected NUMBER after VERSION, got {self.current_token[0]}"
)
else:
raise SyntaxError(f"Expected VERSION directive, got {self.current_token[0]}")
raise SyntaxError(
f"Expected VERSION directive, got {self.current_token[0]}"
)

def parse_layout(self, current_section):
self.eat("LAYOUT")
self.eat("LPAREN")

if self.current_token[0] == "IDENTIFIER" and self.current_token[1] == "location":

if (
self.current_token[0] == "IDENTIFIER"
and self.current_token[1] == "location"
):
self.eat("IDENTIFIER")
self.eat("EQUALS")
location_number = self.current_token[1]
self.eat("NUMBER")

self.eat("RPAREN")
self.skip_comments()

if self.current_token[0] == "IN":
self.eat("IN")
dtype = self.parse_type()
name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("SEMICOLON")
return LayoutNode(section=current_section, location_number=location_number, dtype=dtype, name=name)
return LayoutNode(
section=current_section,
location_number=location_number,
dtype=dtype,
name=name,
)
elif self.current_token[0] == "OUT":
self.eat("OUT")
dtype = self.parse_type()
name = self.current_token[1]
self.eat("IDENTIFIER")
self.eat("SEMICOLON")
return LayoutNode(section=current_section, location_number=location_number, dtype=dtype, name=name)
return LayoutNode(
section=current_section,
location_number=location_number,
dtype=dtype,
name=name,
)
else:
raise SyntaxError("Expected 'IN' or 'OUT' after location in LAYOUT")
else:
raise SyntaxError("Expected IDENTIFIER 'location' in LAYOUT")

def parse_shader(self, version_node):
global_inputs = []
global_outputs = []
Expand All @@ -107,19 +124,16 @@ def parse_shader(self, version_node):

while self.current_token[0] != "EOF":
if self.current_token[0] == "COMMENT_SINGLE":
comment_content = self.current_token[1].strip().lower() # Normalize content
print(f"Comment content: '{comment_content}'")

comment_content = (
self.current_token[1].strip().lower()
) # Normalize content
if "vertex shader" in comment_content:
current_section = "VERTEX"
print("Switched to VERTEX section")
elif "fragment shader" in comment_content:
current_section = "FRAGMENT"
print("Switched to FRAGMENT section")
else:
current_section = "VERTEX"
print("Defaulting to VERTEX section")


self.eat("COMMENT_SINGLE")

if self.current_token[0] == "LAYOUT":
Expand Down Expand Up @@ -153,7 +167,7 @@ def parse_shader(self, version_node):
elif self.current_token[0] == "UNIFORM":
self.skip_comments()
uniforms.extend(self.parse_uniforms())

elif self.current_token[0] == "VERSION":
self.parse_version_directive()

Expand Down Expand Up @@ -189,8 +203,8 @@ def parse_shader(self, version_node):
else:
raise SyntaxError(f"Unexpected token {self.current_token[0]}")

#print(f"Final vertex section: {vertex_section}")
#print(f"Final fragment section: {fragment_section}")
# print(f"Final vertex section: {vertex_section}")
# print(f"Final fragment section: {fragment_section}")

return ShaderNode(
version=version_node,
Expand All @@ -199,7 +213,7 @@ def parse_shader(self, version_node):
uniforms=uniforms,
vertex_section=vertex_section,
fragment_section=fragment_section,
functions=[]
functions=[],
)

def parse_shader_section(self, current_section):
Expand All @@ -220,33 +234,34 @@ def parse_shader_section(self, current_section):
elif self.current_token[0] == "IN":
self.skip_comments()
inputs.extend(self.parse_inputs())
#print(f"Inputs collected: {inputs}")
# print(f"Inputs collected: {inputs}")

elif self.current_token[0] == "OUT":
self.skip_comments()
outputs.extend(self.parse_outputs())
#print(f"Outputs collected: {outputs}")
# print(f"Outputs collected: {outputs}")

elif self.current_token[0] == "UNIFORM":
self.skip_comments()
uniforms.extend(self.parse_uniforms())
#print(f"Uniforms collected: {uniforms}")
# print(f"Uniforms collected: {uniforms}")

elif self.current_token[0] in ["VOID", "FLOAT", "VECTOR"]:
self.skip_comments()
functions.append(self.parse_function())
#print(f"Functions collected: {functions}")
# print(f"Functions collected: {functions}")

elif self.current_token[0] == "RBRACE":
self.eat("RBRACE")
return (inputs, outputs, uniforms, layout_qualifiers, functions)

else:
raise SyntaxError(f"Unexpected token {self.current_token[0]} in shader section")
raise SyntaxError(
f"Unexpected token {self.current_token[0]} in shader section"
)

raise SyntaxError("Unexpected end of input in shader section")


def parse_inputs(self):
inputs = []
while self.current_token[0] == "IN":
Expand Down Expand Up @@ -280,7 +295,6 @@ def parse_uniforms(self):
uniforms.append(UniformNode(vtype, name))
return uniforms


def parse_variable(self, type_name):
name = self.current_token[1]
self.eat("IDENTIFIER")
Expand Down Expand Up @@ -321,7 +335,7 @@ def parse_variable(self, type_name):
raise SyntaxError(
f"Unexpected token in variable declaration: {self.current_token[0]}"
)

def parse_assignment_or_function_call(self):
type_name = ""
if self.current_token[0] in ["VECTOR", "FLOAT", "INT", "MATRIX"]:
Expand Down Expand Up @@ -353,7 +367,7 @@ def parse_assignment_or_function_call(self):
raise SyntaxError(
f"Unexpected token after identifier: {self.current_token[0]}"
)

def parse_function_call(self, name):
self.eat("LPAREN")
args = []
Expand All @@ -364,7 +378,7 @@ def parse_function_call(self, name):
args.append(self.parse_expression())
self.eat("RPAREN")
return FunctionCallNode(name, args)

def parse_function(self):
return_type = self.parse_type()
if self.current_token[0] == "MAIN":
Expand All @@ -384,7 +398,7 @@ def parse_function(self):
body = self.parse_body()
self.eat("RBRACE")
return FunctionNode(return_type, fname, params, body)

def parse_body(self):
body = []
while self.current_token[0] not in ["RBRACE", "EOF"]:
Expand Down Expand Up @@ -419,7 +433,14 @@ def parse_type(self):
if self.current_token[0] == "VOID":
self.eat("VOID")
return "void"
elif self.current_token[0] in ["VECTOR", "FLOAT", "INT", "MATRIX", "BOOLEAN","SAMPLER2D"]:
elif self.current_token[0] in [
"VECTOR",
"FLOAT",
"INT",
"MATRIX",
"BOOLEAN",
"SAMPLER2D",
]:
dtype = self.current_token[1]
self.eat(self.current_token[0])
return dtype
Expand Down Expand Up @@ -492,6 +513,7 @@ def parse_assignment(self):
expr = self.parse_expression()
self.eat("SEMICOLON")
return AssignmentNode(var_name, expr)

def parse_function_call_or_identifier(self):
if self.current_token[0] == "VECTOR":
func_name = self.current_token[1]
Expand All @@ -505,7 +527,7 @@ def parse_function_call_or_identifier(self):
elif self.current_token[0] == "DOT":
return self.parse_member_access(func_name)
return VariableNode("", func_name)

def parse_additive(self):
left = self.parse_multiplicative()
while self.current_token[0] in ["PLUS", "MINUS"]:
Expand All @@ -514,7 +536,7 @@ def parse_additive(self):
right = self.parse_multiplicative()
left = BinaryOpNode(left, op, right)
return left

def parse_primary(self):
if self.current_token[0] == "MINUS":
self.eat("MINUS")
Expand Down Expand Up @@ -571,6 +593,7 @@ def parse_expression(self):
left = TernaryOpNode(left, true_expr, false_expr)

return left

def parse_return(self):
self.eat("RETURN")
expr = self.parse_expression()
Expand All @@ -593,7 +616,7 @@ def parse_if(self):
return IfNode(condition, body, else_body)
else:
return IfNode(condition, body)

def parse_for(self):
self.eat("FOR")
self.eat("LPAREN")
Expand All @@ -606,7 +629,7 @@ def parse_for(self):
body = self.parse_body()
self.eat("RBRACE")
return ForNode(init, condition, update, body)

def parse_member_access(self, object):
self.eat("DOT")
if self.current_token[0] != "IDENTIFIER":
Expand All @@ -621,43 +644,3 @@ def parse_member_access(self, object):
return self.parse_member_access(MemberAccessNode(object, member))

return MemberAccessNode(object, member)


if __name__ == "__main__":
code = """
#version 330 core
// Vertex Shader
layout (location = 0) in vec3 position;
layout (location = 1) in vec2 texCoord;
out vec2 fragTexCoord;
void main() {
gl_Position = vec4(position, 1.0);
fragTexCoord = texCoord;
}
// Fragment Shader
in vec2 fragTexCoord;
out vec4 color;
uniform sampler2D textureSampler;
void main() {
color = texture(textureSampler, fragTexCoord);
}
"""
lexer = Lexer(code)
parser = Parser(lexer.tokens)
for token in lexer.tokens:
print(token)
parser = Parser(lexer.tokens)
ast = parser.parse()

print("Parsing completed successfully!")
#print(ast)
#print("Parsed AST:")
#print(parser.parse())
Loading

0 comments on commit 53f539c

Please sign in to comment.